Source code for mlx_lattice.core.relations.specs

from __future__ import annotations

from collections.abc import Sequence
from dataclasses import dataclass

from mlx_lattice.core.types import Triple, triple


[docs] @dataclass(frozen=True, slots=True, init=False) class KernelSpec: """Normalized 3D kernel geometry for sparse relation builders. ``KernelSpec`` accepts integers or three-item sequences for size, stride, padding, and dilation and stores all values as triples. The object is hashable, so coordinate managers can use it directly as part of relation cache keys. """ size: Triple stride: Triple padding: Triple dilation: Triple def __init__( self, size: int | Sequence[int] = 3, stride: int | Sequence[int] = 1, padding: int | Sequence[int] = 0, dilation: int | Sequence[int] = 1, ) -> None: normalized_size = triple(size, name='kernel_size') normalized_stride = triple(stride, name='stride') normalized_padding = triple(padding, name='padding') normalized_dilation = triple(dilation, name='dilation') _require_positive(normalized_size, 'kernel_size') _require_positive(normalized_stride, 'stride') _require_nonnegative(normalized_padding, 'padding') _require_positive(normalized_dilation, 'dilation') object.__setattr__(self, 'size', normalized_size) object.__setattr__(self, 'stride', normalized_stride) object.__setattr__(self, 'padding', normalized_padding) object.__setattr__(self, 'dilation', normalized_dilation) @property def volume(self) -> int: """Number of offsets in the dense kernel footprint.""" return self.size[0] * self.size[1] * self.size[2] @property def is_pointwise(self) -> bool: """Whether this spec is an identity 1x1x1 pointwise mapping.""" return ( self.size == (1, 1, 1) and self.stride == (1, 1, 1) and self.padding == (0, 0, 0) and self.dilation == (1, 1, 1) ) @property def is_centered_submanifold(self) -> bool: """Whether this spec is an odd, stride-1 submanifold footprint.""" return ( self.stride == (1, 1, 1) and self.padding == (0, 0, 0) and self.dilation == (1, 1, 1) and all(value % 2 == 1 for value in self.size) )
# MARK: - helpers


def _require_positive(values: Triple, name: str) -> None:
    """Raise ``ValueError`` if any component of ``values`` is zero or negative.

    Args:
        values: Components to check.
        name: Label used at the start of the error message.
    """
    for component in values:
        if component <= 0:
            raise ValueError(f'{name} values must be positive.')


def _require_nonnegative(values: Triple, name: str) -> None:
    """Raise ``ValueError`` if any component of ``values`` is negative.

    Args:
        values: Components to check.
        name: Label used at the start of the error message.
    """
    for component in values:
        if component < 0:
            raise ValueError(f'{name} values must be non-negative.')