Source code for MDMC.common.factory

"""
Generalised factory class.
"""

from abc import ABC
from collections.abc import Callable, Iterable, Sequence
from importlib import import_module
from inspect import getmembers, isabstract, isclass
from pathlib import Path
from typing import Generic, TypeVar, get_args

T = TypeVar('T')

# Inheritors define their own registry
# pylint: disable=no-member


class Factory(ABC):  # noqa: B024 - Abstract class no abs meth.
    """
    General factory class.

    This base class does not define ``registry`` itself; every method
    below reads it from ``cls``, so inheritors must provide it.

    Attributes
    ----------
    registry : dict[str, Callable]
        Dictionary mapping lookup keys to the callables they construct.
    """

    @classmethod
    def get(cls, key: str) -> Callable[..., T]:
        """
        Return the callable registered under ``key``.

        Parameters
        ----------
        key : str
            Name to look up in the registry.

        Returns
        -------
        ~collections.abc.Callable
            The object stored in ``registry`` for ``key``.

        Raises
        ------
        KeyError
            If ``key`` is not present in the registry. The message lists
            the names that are available.
        """
        try:
            return cls.registry[key]
        except KeyError:
            known = ", ".join(map(repr, cls.registry))
            # Still a KeyError so existing handlers keep working, but say
            # which factory was asked and what it could have provided.
            raise KeyError(
                f"{key!r} is not registered with {cls.__name__}; "
                f"available names: {known or 'none'}"
            ) from None

    @classmethod
    def create(cls, key: str, *args, **kwargs) -> T:
        """
        Return an instance of given class.

        Parameters
        ----------
        key : str
            Name to look up in the registry.
        *args
            Positional arguments forwarded to the registered callable.
        **kwargs
            Keyword arguments forwarded to the registered callable.

        Returns
        -------
        T
            Whatever the registered callable returns.

        Raises
        ------
        KeyError
            If ``key`` is not present in the registry.
        """
        return cls.get(key)(*args, **kwargs)

    @classmethod
    def available_names(cls) -> Sequence[str]:
        """
        Known types supported by factory.

        Returns
        -------
        ~collections.abc.Sequence[str]
            Available keys to load, as a snapshot of the registry keys
            in registry order.
        """
        # A dict keys view is not a Sequence (it cannot be indexed);
        # materialise it so the documented return type holds.
        return tuple(cls.registry)

    @classmethod
    def supported_types(cls) -> tuple[type, ...]:
        """
        Return the types supported by this factory.

        The types are the type arguments of the first entry of
        ``cls.__orig_bases__``.

        Returns
        -------
        tuple[type, ...]
            Parent classes allowed by this factory.
        """
        return get_args(cls.__orig_bases__[0])
class ModuleFactory(Factory, ABC, Generic[T]):
    """
    Factory which fills its registry by scanning a directory of modules.

    Supports:

    - scanning of files for relevant classes.
    - exclusion of certain files.

    Attributes
    ----------
    curr_path : Path
        Path to scan, usually ``Path(__file__).parent``.
    curr_pack : str
        Current package to import relative to. Usually ``__package__``.
    exclude : Sequence[Path]
        Paths to exclude from search.
    """

    # Both default to None and are expected to be set by inheritors
    # before ``scan`` is called.
    curr_path: Path | None = None
    curr_pack: str | None = None
    exclude: Sequence[Path] = ()

    @classmethod
    def scan(cls) -> None:
        """
        Scan ``curr_path`` for any valid types.

        Every ``*.py`` file directly inside ``curr_path`` is imported
        relative to ``curr_pack``, unless its stem starts with ``"."`` or
        it is the same file as an entry of ``exclude``. Each non-abstract
        class in the module that is a subclass of one of
        ``supported_types()`` is added to the registry under its own name
        and is then available to be loaded.

        If a module contributes exactly one such class and the module
        name differs from the class name, the class is also registered
        under the module name.
        """
        # Invariant for the whole scan; no need to recompute per member.
        supported = cls.supported_types()

        def is_supported(member) -> bool:
            return (
                isclass(member)
                and not isabstract(member)
                and issubclass(member, supported)
            )

        # Sorted so that, when two modules provide the same key, the
        # winner does not depend on directory listing order.
        for path in sorted(cls.curr_path.glob("*.py")):
            if path.stem.startswith(".") or any(
                path.samefile(other) for other in cls.exclude
            ):
                continue

            module = import_module("." + path.stem, cls.curr_pack)
            classes = getmembers(module, is_supported)

            for name, type_ in classes:
                cls.registry[name] = type_

            # Get name of module too if only one class exists and
            # module name does not match
            if len(classes) == 1 and path.stem != classes[0][0]:
                cls.registry[path.stem] = classes[0][1]
class RegisterFactory(Factory, ABC, Generic[T]):
    """
    Factory requiring manual registration to data.

    See Also
    --------
    RegisterFactory.register : Registration mechanism.
    """

    @classmethod
    def register(cls, names: str | Iterable[str]) -> Callable[[type], type]:
        """
        A class level decorator for registering classes.

        The names with which the class is registered should be the
        parameter passed to the decorator.

        Parameters
        ----------
        names : str or ~collections.abc.Iterable[str]
            The name, or names, under which the class is registered.

        Returns
        -------
        ~collections.abc.Callable
            Decorator which stores the decorated class in the registry
            under every given name and returns the class unchanged.

        Example
        -------
        To register the ``SQw`` class with ``RegisterFactory``:

        .. code-block:: python

            @RegisterFactory.register('SQw')
            class SQw(Observable):
        """
        # Normalise once, up front: a one-shot iterable would otherwise be
        # exhausted by the first decoration, and any later reuse of the
        # same decorator would silently register nothing.
        keys = (names,) if isinstance(names, str) else tuple(names)

        def class_wrapper(wrapped_class: type) -> type:
            for key in keys:
                cls.registry[key] = wrapped_class
            return wrapped_class

        return class_wrapper