Union Aliases

Union Aliases#

To understand what union aliases are and what problem they solve, consider the following example. Suppose that we would want to implement a special addition function, and we would want to implement it for all NumPy scalar types:

import numpy as np

from typing import Union
from plum import dispatch


scalar_types = sum(np.core.sctypes.values(), [])  # All NumPy scalar types
Scalar = Union[tuple(scalar_types)]  # Union of all NumPy scalar types


@dispatch
def add(x: Scalar, y: Scalar):
    return x + y

This looks all fine, until you look at the documentation. In particular, help(add) prints

Help on Function in module __main__:

add(x: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void], y: Union[numpy.int8, numpy.int16, numpy.int32, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.float16, numpy.float32, numpy.float64, numpy.float128, numpy.complex64, numpy.complex128, numpy.complex256, bool, object, bytes, str, numpy.void])

While the documentation is accurate, it is not at all helpful to expand the union in its many elements, because it obscures the key message: add(x, y) is implemented for all scalars. A better option would be to print add(x: Scalar, y: Scalar). This is precisely what union aliases do: by aliasing a union, you change the way it is displayed. Union aliases must be activated explicitly, because the feature monkeypatches Union.__str__ and Union.__repr__.

>>> from plum import activate_union_aliases, set_union_alias

>>> activate_union_aliases()

>>> set_union_alias(Scalar, alias="Scalar")
typing.Union[Scalar]

After this, help(add) now prints the following:

Help on Function in module __main__:

add(x: Union[Scalar], y: Union[Scalar])

Hurray! Note that the documentation prints Union[Scalar] rather than just Scalar. This is intentional: it is to prevent breaking code that depends on how unions print. For example, printing just Scalar would omit the type parameter(s).

Let’s see with a few more examples how this works:

>>> Scalar
typing.Union[Scalar]

>>> Union[tuple(scalar_types)]
typing.Union[Scalar]

>>> Union[tuple(scalar_types) + (tuple,)]       # Scalar or tuple
typing.Union[Scalar, tuple]

>>> Union[tuple(scalar_types) + (tuple, list)]  # Scalar or tuple or list
typing.Union[Scalar, tuple, list]

If we don’t include all of scalar_types, we won’t see Scalar, as desired:

>>> Union[tuple(scalar_types[:-1])]
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool]

You can deactivate union aliases with deactivate_union_aliases:

>>> from plum import deactivate_union_aliases

>>> deactivate_union_aliases()

% skip: next "Result depends on NumPy version."
>>> Scalar
typing.Union[numpy.int8, numpy.int16, numpy.int32, numpy.longlong, numpy.int64, numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64, numpy.ulonglong, numpy.float16, numpy.float32, numpy.float64, numpy.longdouble, numpy.complex64, numpy.complex128, numpy.clongdouble, numpy.str_, numpy.bytes_, numpy.void, numpy.bool, numpy.object_]