Source code for plum.dispatcher

import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload

from .function import Function
from .overload import get_overloads
from .signature import Signature
from .util import Callable, TypeHint, get_class, is_in_class

__all__ = ["Dispatcher", "dispatch", "clear_all_cache"]

T = TypeVar("T", bound=Callable[..., Any])


_dataclass_kw_args: Dict[str, Any] = {}
if sys.version_info >= (3, 10):  # pragma: specific no cover 3.8 3.9
    _dataclass_kw_args |= {"slots": True}


[docs]@dataclass(frozen=True, **_dataclass_kw_args) class Dispatcher: """A namespace for functions. Args: warn_redefinition (bool, optional): Throw a warning whenever a method is redefined. Defaults to `False`. Attributes: functions (dict[str, :class:`.function.Function`]): Functions by name. classes (dict[str, dict[str, :class:`.function.Function`]]): Methods of all classes by the qualified name of a class. warn_redefinition (bool): Throw a warning whenever a method is redefined. """ warn_redefinition: bool = False functions: Dict[str, Function] = field(default_factory=dict) classes: Dict[str, Dict[str, Function]] = field(default_factory=dict) @overload def __call__(self, method: T, precedence: int = ...) -> T: ... @overload def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ... def __call__( self, method: Optional[T] = None, precedence: int = 0 ) -> Union[T, Callable[[T], T]]: """Decorator to register for a particular signature. Args: precedence (int, optional): Precedence of the signature. Defaults to `0`. Returns: function: Decorator. """ if method is None: return partial(self.__call__, precedence=precedence) # If `method` has overloads, assume that those overloads need to be registered # and that `method` is not an implementation. overloads = get_overloads(method) if overloads: for overload_method in overloads: # All `f` returned by `self._add_method` are the same. f = self._add_method(overload_method, None, precedence=precedence) # We do not need to register `method`, because it is not an implementation. return f # The signature will be automatically derived from `method`, so we can safely # set the signature argument to `None`. return self._add_method(method, None, precedence=precedence)
[docs] def multi( self, *signatures: Union[Signature, Tuple[TypeHint, ...]] ) -> Callable[[Callable], Function]: """Decorator to register multiple signatures at once. Args: *signatures (tuple or :class:`.signature.Signature`): Signatures to register. Returns: function: Decorator. """ resolved_signatures = [] for signature in signatures: if isinstance(signature, Signature): resolved_signatures.append(signature) elif isinstance(signature, tuple): resolved_signatures.append(Signature(*signature)) else: raise ValueError( f"Signature `{signature}` must be a tuple or of type " f"`plum.signature.Signature`." ) def decorator(method: Callable) -> Function: # The precedence will not be used, so we can safely set it to `None`. return self._add_method(method, *resolved_signatures, precedence=None) return decorator
[docs] def abstract(self, method: Callable) -> Function: """Decorator for an abstract function definition. The abstract function definition does not implement any methods.""" return self._get_function(method)
def _get_function(self, method: Callable) -> Function: # If a class is the owner, use a namespace specific for that class. Otherwise, # use the global namespace. if is_in_class(method): owner = get_class(method) if owner not in self.classes: self.classes[owner] = {} namespace = self.classes[owner] else: owner = None namespace = self.functions # Create a new function only if the function does not already exist. name = method.__name__ if name not in namespace: namespace[name] = Function( method, owner=owner, warn_redefinition=self.warn_redefinition, ) return namespace[name] def _add_method( self, method: Callable, *signatures: Optional[Signature], precedence: Optional[int], ) -> Function: f = self._get_function(method) for signature in signatures: f.register(method, signature, precedence) return f
[docs] def clear_cache(self): """Clear cache.""" for f in self.functions.values(): f.clear_cache()
[docs]def clear_all_cache(): """Clear all cache, including the cache of subclass checks. This should be called if types are modified.""" for f in Function._instances: f.clear_cache()
dispatch = Dispatcher() #: A default dispatcher for convenience purposes.