Source code for shmpipeline.registry

"""Kernel registry used by the manager and worker runtime."""

from __future__ import annotations

from functools import partial
from importlib.util import find_spec
from typing import Callable, Mapping

from shmpipeline.config import KernelConfig, SharedMemoryConfig
from shmpipeline.errors import ConfigValidationError
from shmpipeline.kernel import Kernel, KernelContext
from shmpipeline.kernels.cpu import (
    AddConstantCpuKernel,
    AffineTransformCpuKernel,
    CopyCpuKernel,
    CustomOperationCpuKernel,
    ElementwiseAddCpuKernel,
    ElementwiseDivideCpuKernel,
    ElementwiseMultiplyCpuKernel,
    ElementwiseSubtractCpuKernel,
    FlattenCpuKernel,
    LeakyIntegratorCpuKernel,
    RaiseErrorCpuKernel,
    ScaleCpuKernel,
    ScaleOffsetCpuKernel,
    ShackHartmannCentroidCpuKernel,
)

_TORCH_AVAILABLE = find_spec("torch") is not None

if _TORCH_AVAILABLE:
    _DEFAULT_GPU_KINDS = (
        "gpu.add_constant",
        "gpu.affine_transform",
        "gpu.copy",
        "gpu.custom_operation",
        "gpu.elementwise_add",
        "gpu.elementwise_divide",
        "gpu.elementwise_multiply",
        "gpu.elementwise_subtract",
        "gpu.flatten",
        "gpu.leaky_integrator",
        "gpu.raise_error",
        "gpu.scale",
        "gpu.scale_offset",
        "gpu.shack_hartmann_centroid",
    )


def _load_default_gpu_kernel(kind: str) -> type[Kernel]:
    """Return the built-in GPU kernel class whose ``kind`` equals ``kind``.

    The GPU kernel module is imported inside this function rather than at
    module level, so importing the registry module does not import
    ``shmpipeline.kernels.gpu``; that import first happens when this function
    is called.

    Args:
        kind: Kernel kind string, compared against the ``kind`` attribute of
            each built-in GPU kernel class.

    Returns:
        The matching GPU kernel class.

    Raises:
        KeyError: If none of the built-in GPU kernel classes declares ``kind``.

    Any error raised while importing the GPU kernel module propagates
    unchanged.
    """
    from shmpipeline.kernels.gpu import (
        AddConstantGpuKernel,
        AffineTransformGpuKernel,
        CopyGpuKernel,
        CustomOperationGpuKernel,
        ElementwiseAddGpuKernel,
        ElementwiseDivideGpuKernel,
        ElementwiseMultiplyGpuKernel,
        ElementwiseSubtractGpuKernel,
        FlattenGpuKernel,
        LeakyIntegratorGpuKernel,
        RaiseErrorGpuKernel,
        ScaleGpuKernel,
        ScaleOffsetGpuKernel,
        ShackHartmannCentroidGpuKernel,
    )

    gpu_classes = (
        AddConstantGpuKernel,
        AffineTransformGpuKernel,
        CopyGpuKernel,
        CustomOperationGpuKernel,
        ElementwiseAddGpuKernel,
        ElementwiseDivideGpuKernel,
        ElementwiseMultiplyGpuKernel,
        ElementwiseSubtractGpuKernel,
        FlattenGpuKernel,
        LeakyIntegratorGpuKernel,
        RaiseErrorGpuKernel,
        ScaleGpuKernel,
        ScaleOffsetGpuKernel,
        ShackHartmannCentroidGpuKernel,
    )
    # Key each class by its own declared ``kind`` attribute.
    lookup = {gpu_cls.kind: gpu_cls for gpu_cls in gpu_classes}
    return lookup[kind]


class KernelRegistry:
    """Map kernel kind strings to kernel implementation classes.

    A registry is the extension point for third-party kernels. The default
    registry is built from the built-in CPU kernels and, when the optional
    torch dependency is installed, GPU kernels that are loaded lazily.

    Each registry keeps two tables: ``_kernels`` maps a kind to an
    already-resolved class, and ``_lazy_kernels`` maps a kind to a
    zero-argument loader returning the class. Resolving a lazy kind, or
    registering a class under a kind, removes that kind from
    ``_lazy_kernels``.
    """

    def __init__(
        self,
        kernels: Mapping[str, type[Kernel]],
        lazy_kernels: Mapping[str, Callable[[], type[Kernel]]] | None = None,
    ) -> None:
        """Copy the eager and lazy kernel mappings into this registry.

        Args:
            kernels: Kind -> kernel class, available immediately.
            lazy_kernels: Kind -> zero-argument loader returning the kernel
                class. A loader is invoked on the first ``get`` of its kind,
                and its result is cached once it returns.
        """
        # Both inputs are copied, so later mutation of the caller's mappings
        # does not affect this registry (and vice versa).
        self._kernels = dict(kernels)
        self._lazy_kernels = dict(lazy_kernels or {})

    def get(self, kind: str) -> type[Kernel]:
        """Resolve ``kind`` to its implementation class.

        Eagerly registered classes are returned directly. Otherwise the lazy
        loader for ``kind`` is called, and its result is cached in the eager
        table while the loader is discarded.

        Raises:
            ConfigValidationError: If ``kind`` is registered in neither table.
        """
        resolved = self._kernels.get(kind)
        if resolved is None:
            load = self._lazy_kernels.get(kind)
            if load is None:
                raise ConfigValidationError(f"unknown kernel kind: {kind!r}")
            resolved = load()
            # Cache the loaded class and retire its loader so subsequent
            # lookups hit the eager table.
            self._kernels[kind] = resolved
            self._lazy_kernels.pop(kind, None)
        return resolved

    def kinds(self) -> tuple[str, ...]:
        """Return every registered kind, eager or lazy, sorted ascending."""
        every_kind = set(self._kernels) | set(self._lazy_kernels)
        return tuple(sorted(every_kind))

    def register(
        self,
        kernel_cls: type[Kernel],
        *,
        replace: bool = False,
    ) -> None:
        """Add ``kernel_cls`` to this registry under its ``kind``.

        Args:
            kernel_cls: A ``Kernel`` subclass with a non-empty string ``kind``
                (surrounding whitespace is stripped).
            replace: When true, overwrite an existing eager or lazy entry for
                the same kind instead of failing.

        Raises:
            TypeError: If ``kernel_cls`` is not a ``Kernel`` subclass.
            ValueError: If ``kind`` is missing/empty, or already registered
                and ``replace`` is false.
        """
        kind = _kernel_kind(kernel_cls)
        if not replace:
            already_known = kind in self._kernels or kind in self._lazy_kernels
            if already_known:
                raise ValueError(f"kernel kind {kind!r} is already registered")
        self._kernels[kind] = kernel_cls
        # An explicit registration supersedes any pending lazy loader.
        self._lazy_kernels.pop(kind, None)

    def extended(
        self,
        *kernel_classes: type[Kernel],
        replace: bool = False,
    ) -> "KernelRegistry":
        """Return a copy of this registry with ``kernel_classes`` registered.

        ``self`` is left unmodified; both tables are copied into the new
        registry before the extra classes are added.
        """
        derived = KernelRegistry(self._kernels, self._lazy_kernels)
        for extra_cls in kernel_classes:
            derived.register(extra_cls, replace=replace)
        return derived

    def validate(
        self,
        config: KernelConfig,
        shared_memory: Mapping[str, SharedMemoryConfig],
    ) -> None:
        """Check one kernel binding against the shared-memory definitions.

        Delegates to ``validate_config`` on the class registered for
        ``config.kind``.
        """
        kernel_cls = self.get(config.kind)
        kernel_cls.validate_config(config, shared_memory)

    def create(
        self,
        config: KernelConfig,
        shared_memory: Mapping[str, SharedMemoryConfig],
    ) -> Kernel:
        """Validate ``config`` and return a new instance of its kernel class.

        The instance is constructed with a ``KernelContext`` wrapping
        ``config`` and ``shared_memory``.
        """
        kernel_cls = self.get(config.kind)
        kernel_cls.validate_config(config, shared_memory)
        context = KernelContext(config=config, shared_memory=shared_memory)
        return kernel_cls(context)
def _kernel_kind(kernel_cls: type[Kernel]) -> str:
    """Validate ``kernel_cls`` and return its ``kind`` with whitespace stripped.

    Raises:
        TypeError: If ``kernel_cls`` is not a class deriving from ``Kernel``.
        ValueError: If ``kernel_cls.kind`` is absent, not a string, or blank.
    """
    is_kernel_subclass = isinstance(kernel_cls, type) and issubclass(
        kernel_cls, Kernel
    )
    if not is_kernel_subclass:
        raise TypeError("kernel_cls must be a Kernel subclass")
    raw_kind = getattr(kernel_cls, "kind", None)
    # Non-strings collapse to "" so a single emptiness check covers both the
    # wrong-type and the blank-string cases.
    stripped = raw_kind.strip() if isinstance(raw_kind, str) else ""
    if not stripped:
        raise ValueError("kernel_cls.kind must be a non-empty string")
    return stripped


# Built-in CPU kernels, keyed by each class's own ``kind`` attribute.
_DEFAULT_KERNELS: dict[str, type[Kernel]] = {
    cpu_cls.kind: cpu_cls
    for cpu_cls in (
        AddConstantCpuKernel,
        AffineTransformCpuKernel,
        CopyCpuKernel,
        CustomOperationCpuKernel,
        ElementwiseAddCpuKernel,
        ElementwiseDivideCpuKernel,
        ElementwiseMultiplyCpuKernel,
        ElementwiseSubtractCpuKernel,
        FlattenCpuKernel,
        LeakyIntegratorCpuKernel,
        RaiseErrorCpuKernel,
        ScaleCpuKernel,
        ScaleOffsetCpuKernel,
        ShackHartmannCentroidCpuKernel,
    )
}

# GPU kinds are only offered when torch was located at import time. Each kind
# gets a loader that defers importing the GPU kernel module until first use.
# The conditional expression only evaluates ``_DEFAULT_GPU_KINDS`` on the
# torch branch.
_DEFAULT_LAZY_KERNELS: dict[str, Callable[[], type[Kernel]]] = (
    {
        gpu_kind: partial(_load_default_gpu_kernel, gpu_kind)
        for gpu_kind in _DEFAULT_GPU_KINDS
    }
    if _TORCH_AVAILABLE
    else {}
)

_DEFAULT_REGISTRY = KernelRegistry(_DEFAULT_KERNELS, _DEFAULT_LAZY_KERNELS)
def get_default_registry() -> KernelRegistry:
    """Return the shared built-in kernel registry.

    Every call returns the same module-level ``KernelRegistry`` instance, so
    classes added to it with ``register`` become visible to all other callers
    of this function in the same interpreter. Use ``extended`` on the result
    to obtain an independent registry instead.
    """
    registry = _DEFAULT_REGISTRY
    return registry