Source code for mlx_lattice.core.coords.manager

from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from itertools import count

import mlx.core as mx

from mlx_lattice.core.coords.builders import (
    build_generative_relation,
    build_kernel_relation,
    build_submanifold_kernel_relation,
    build_target_kernel_relation,
    build_transposed_kernel_relation,
)
from mlx_lattice.core.coords.set_ops import inverse_map
from mlx_lattice.core.coords.validation import validate_coords
from mlx_lattice.core.relations import KernelRelation, KernelSpec
from mlx_lattice.core.types import Triple, triple

# Process-wide counter; each CoordinateManager draws its identifier from it.
_manager_ids = count()


def _next_manager_id() -> int:
    """Return a fresh manager identifier, never reused within this process."""
    identifier = next(_manager_ids)
    return identifier


[docs] @dataclass(frozen=True, slots=True) class CoordinateMapKey: """Opaque coordinate identity key owned by a ``CoordinateManager``. Keys are lightweight metadata objects. A key is valid only for the manager whose ``manager_id`` it stores. The ``stride`` field is part of coordinate identity because the same integer rows represent different lattice cells at different spatial strides. """ id: int stride: Triple manager_id: int
@dataclass(slots=True)
class CoordinateManager:
    """Owns coordinate arrays and caches sparse relations by identity.

    Managers are normally created by ``SparseTensor``. Reusing a manager and
    key lets convolutions, pooling, and aligned sparse algebra reuse native
    relation metadata instead of rebuilding it for every call.

    Cache entries are keyed by coordinate identity, optional target identity,
    kernel geometry, and relation kind.

    Registration is by Python object identity (``id``), not by value: equal
    coordinate arrays that are distinct objects receive distinct keys. The
    manager holds a reference to every registered coordinate and active-row
    array, so the ``id`` values used as cache keys cannot be recycled for
    other objects while those entries exist.
    """

    # Process-unique identifier stamped into every key this manager issues.
    _manager_id: int = field(default_factory=_next_manager_id, init=False)
    # Next per-manager key id to hand out.
    _next_id: int = 0
    # Coordinate array registered for each key.
    _coords: dict[CoordinateMapKey, mx.array] = field(default_factory=dict)
    # (id(coords), id(active_rows), stride) -> key.
    _identity_keys: dict[tuple[int, int, Triple], CoordinateMapKey] = field(
        default_factory=dict
    )
    # (id(coords), stride) -> key, for registrations without active_rows.
    _default_identity_keys: dict[tuple[int, Triple], CoordinateMapKey] = (
        field(default_factory=dict)
    )
    # Active-row array registered for each key.
    _active_rows: dict[CoordinateMapKey, mx.array] = field(
        default_factory=dict
    )
    # (input key, target key or None, kernel spec, relation kind) -> relation.
    _kernel_relations: dict[
        tuple[CoordinateMapKey, CoordinateMapKey | None, KernelSpec, str],
        KernelRelation,
    ] = field(default_factory=dict)

    def insert_coords(
        self,
        coords: mx.array,
        stride: int | Sequence[int] = 1,
        active_rows: mx.array | None = None,
    ) -> CoordinateMapKey:
        """Register a coordinate array by object identity and stride.

        Args:
            coords: Integer coordinates with shape ``(N, 4)``.
            stride: Spatial lattice stride for the coordinate rows.
            active_rows: Optional ``int32`` scalar ``(1,)`` active-row count.
                When omitted, a one-element ``int32`` array holding
                ``coords.shape[0]`` is created and stored for the new key.

        Returns:
            A manager-owned key. Re-inserting the same coordinate object,
            active-row object, and stride returns the existing key.

        Note:
            ``coords`` is passed to ``validate_coords`` on every call,
            including re-insertions that return an existing key.
        """
        validate_coords(coords)
        normalized = triple(stride, name='stride')
        default_key = (id(coords), normalized)

        if active_rows is None:
            existing = self._default_identity_keys.get(default_key)
            if existing is not None:
                return existing
            # A freshly allocated array cannot share an ``id`` with any array
            # already recorded in ``_identity_keys`` (those are kept alive in
            # ``_active_rows``), so no identity lookup is needed on this path.
            active = mx.array([coords.shape[0]], dtype=mx.int32)
        else:
            active = active_rows
            existing = self._identity_keys.get(
                (id(coords), id(active), normalized)
            )
            if existing is not None:
                return existing

        key = CoordinateMapKey(self._next_id, normalized, self._manager_id)
        self._next_id += 1
        # Storing both arrays keeps them alive, which keeps their ``id``
        # values unique for as long as the identity entries below exist.
        self._coords[key] = coords
        self._active_rows[key] = active
        self._identity_keys[(id(coords), id(active), normalized)] = key
        if active_rows is None:
            self._default_identity_keys[default_key] = key
        return key

    def owns(self, key: CoordinateMapKey) -> bool:
        """Return whether ``key`` belongs to this manager."""
        return key.manager_id == self._manager_id and key in self._coords

    def _require_owned(self, key: CoordinateMapKey) -> None:
        """Raise ``ValueError`` unless ``key`` was issued by this manager."""
        if not self.owns(key):
            raise ValueError(
                'coordinate key does not belong to this manager.'
            )

    def coords(self, key: CoordinateMapKey) -> mx.array:
        """Return the coordinate array registered for ``key``.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        self._require_owned(key)
        return self._coords[key]

    def active_rows(self, key: CoordinateMapKey) -> mx.array:
        """Return the active-row scalar registered for ``key``.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        self._require_owned(key)
        return self._active_rows[key]

    def inverse_map(
        self,
        source: CoordinateMapKey,
        target: CoordinateMapKey,
    ) -> mx.array:
        """Return row indices that gather ``target`` coordinates from source.

        Missing target rows are encoded by the native set-operation contract.
        Both keys must belong to this manager.

        Raises:
            ValueError: If either key does not belong to this manager.
        """
        # Inside this method body ``inverse_map`` resolves to the module-level
        # set-operation helper: method names are not in function scope.
        return inverse_map(self.coords(source), self.coords(target))

    def _cached_relation(
        self,
        cache_key: tuple[
            CoordinateMapKey, CoordinateMapKey | None, KernelSpec, str
        ],
        build: Callable[[], KernelRelation],
    ) -> KernelRelation:
        """Return the relation cached under ``cache_key``, building on a miss.

        ``build`` is only invoked on a miss, so ownership checks performed
        inside it (via ``coords``/``active_rows``) run only when a relation
        is actually constructed.
        """
        try:
            return self._kernel_relations[cache_key]
        except KeyError:
            pass
        relation = build()
        self._kernel_relations[cache_key] = relation
        return relation

    def kernel_relation(
        self,
        key: CoordinateMapKey,
        *,
        kernel_size: int | Sequence[int] = 3,
        stride: int | Sequence[int] = 1,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> KernelRelation:
        """Build or reuse a forward sparse kernel relation.

        The relation maps input rows to output rows generated by applying
        ``kernel_size``, ``stride``, ``padding``, and ``dilation`` to the
        input coordinate support. The resulting relation is cached for the
        exact normalized geometry.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        return self._cached_relation(
            (key, None, spec, 'forward'),
            lambda: build_kernel_relation(
                self.coords(key),
                active_rows=self.active_rows(key),
                kernel_size=spec.size,
                stride=spec.stride,
                padding=spec.padding,
                dilation=spec.dilation,
            ),
        )

    def generative_relation(
        self,
        key: CoordinateMapKey,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
    ) -> KernelRelation:
        """Build or reuse a generative transpose-convolution relation.

        Generative relations create output support from each input coordinate
        and transpose-convolution stride without requiring an existing target
        coordinate set.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        spec = KernelSpec(size=kernel_size, stride=stride)
        return self._cached_relation(
            (key, None, spec, 'generative'),
            lambda: build_generative_relation(
                self.coords(key),
                active_rows=self.active_rows(key),
                kernel_size=spec.size,
                stride=spec.stride,
            ),
        )

    def submanifold_kernel_relation(
        self,
        key: CoordinateMapKey,
        *,
        kernel_size: int | Sequence[int] = 3,
        dilation: int | Sequence[int] = 1,
    ) -> KernelRelation:
        """Build or reuse a submanifold sparse kernel relation.

        The relation fixes output support to the input coordinate identity.
        It is the relation semantic used by submanifold convolution rather
        than a generic forward relation with output coordinates discarded.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        # Stride and padding are pinned so output support equals input support.
        spec = KernelSpec(
            size=kernel_size,
            stride=1,
            padding=0,
            dilation=dilation,
        )
        return self._cached_relation(
            (key, key, spec, 'submanifold'),
            lambda: build_submanifold_kernel_relation(
                self.coords(key),
                active_rows=self.active_rows(key),
                kernel_size=spec.size,
                dilation=spec.dilation,
            ),
        )

    def transposed_kernel_relation(
        self,
        key: CoordinateMapKey,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> KernelRelation:
        """Build or reuse a transpose-convolution relation.

        The relation represents the sparse transpose of the forward
        convolution geometry and carries output coordinates plus CSR views
        for native execution.

        Raises:
            ValueError: If ``key`` does not belong to this manager.
        """
        spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        return self._cached_relation(
            (key, None, spec, 'transpose'),
            lambda: build_transposed_kernel_relation(
                self.coords(key),
                active_rows=self.active_rows(key),
                kernel_size=spec.size,
                stride=spec.stride,
                padding=spec.padding,
                dilation=spec.dilation,
            ),
        )

    def target_kernel_relation(
        self,
        key: CoordinateMapKey,
        target_key: CoordinateMapKey,
        *,
        kernel_size: int | Sequence[int] = 3,
        stride: int | Sequence[int] = 1,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> KernelRelation:
        """Build or reuse a relation from one key to explicit target coords.

        ``target_key`` fixes the output support. Only input rows that connect
        to those target rows through the kernel geometry contribute edges.

        Raises:
            ValueError: If either key does not belong to this manager. This
                is checked eagerly, before the cache is consulted.
        """
        if not self.owns(key) or not self.owns(target_key):
            raise ValueError(
                'input and target coordinate keys must belong to this manager.'
            )
        spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        return self._cached_relation(
            (key, target_key, spec, 'target'),
            lambda: build_target_kernel_relation(
                self.coords(key),
                self.coords(target_key),
                active_rows=self.active_rows(key),
                target_active_rows=self.active_rows(target_key),
                kernel_size=spec.size,
                stride=spec.stride,
                padding=spec.padding,
                dilation=spec.dilation,
            ),
        )