Source code for mlx_lattice.core.coords.builders

from __future__ import annotations

import math
from collections.abc import Sequence

import mlx.core as mx

from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coords
from mlx_lattice.core.relations import (
    KernelRelation,
    KernelSpec,
    NeighborRelation,
    RelationImplicitGemmView,
)
from mlx_lattice.core.relations.views import RelationKind
from mlx_lattice.core.types import Triple, triple

# Raw tuple handed back by the native kernel-relation builders. The slot
# names below follow the order in which ``_kernel_relation_from_native``
# unpacks the tuple.
type NativeKernelRelation = tuple[
    mx.array,  # in_rows
    mx.array,  # out_rows
    mx.array,  # kernel_ids
    mx.array,  # row_offsets
    mx.array,  # native_out_coords (used unless the caller overrides out_coords)
    mx.array,  # counts
    mx.array,  # in_row_offsets
    mx.array,  # in_edge_ids
    mx.array,  # kernel_row_offsets
    mx.array,  # kernel_edge_ids
]
# Raw tuple handed back by the native neighbor-relation builders. The slot
# names follow the unpack order in ``_neighbor_relation_from_native``.
type NativeNeighborRelation = tuple[
    mx.array,  # query_rows
    mx.array,  # source_rows
    mx.array,  # neighbor_ids
    mx.array,  # distances
    mx.array,  # row_offsets
    mx.array,  # counts
]
# Pair returned by ``ext.build_relation_implicit_gemm_view``; it is unpacked
# as ``(out_in_map, row_masks)`` in ``build_relation_implicit_gemm_view``.
type NativeImplicitGemmView = tuple[mx.array, mx.array]


def build_relation_implicit_gemm_view(
    relation: KernelRelation,
) -> RelationImplicitGemmView:
    """Materialize the dense output-to-input map consumed by implicit-GEMM kernels.

    Args:
        relation: Kernel relation whose ``contract`` supplies the kind,
            coordinate context, kernel offsets, stride and padding.

    Returns:
        A ``RelationImplicitGemmView`` wrapping the map and row masks
        produced by the native extension.

    Raises:
        ValueError: If the relation kind is not ``forward``, ``target`` or
            ``submanifold``; if source coords, source active rows or output
            coords are missing; or if the output activity is missing.
    """
    meta = relation.contract
    relation_kind = meta.kind

    if relation_kind not in ('forward', 'target', 'submanifold'):
        raise ValueError(
            'implicit GEMM view currently supports forward, target, and '
            'submanifold relations.'
        )

    coordinate_context = (
        meta.source_coords,
        meta.source_active_rows,
        meta.out_coords,
    )
    if any(item is None for item in coordinate_context):
        raise ValueError(
            'kernel relation is missing coordinate context for implicit GEMM.'
        )

    # Forward relations report output activity through ``out_count``; the
    # target-driven kinds carry it in ``target_active_rows`` instead.
    if relation_kind == 'forward':
        active_outputs = meta.out_count
    else:
        active_outputs = meta.target_active_rows
    if active_outputs is None:
        raise ValueError(
            'kernel relation is missing output activity for implicit GEMM.'
        )

    native_view = ext.build_relation_implicit_gemm_view(
        meta.source_coords,
        meta.source_active_rows,
        meta.out_coords,
        active_outputs,
        _offset_array(meta.kernel_offsets),
        meta.stride,
        meta.padding,
    )
    dense_map, masks = native_view
    return RelationImplicitGemmView(dense_map, row_masks=masks)
def build_target_kernel_relation(
    coords: mx.array,
    target_coords: mx.array,
    *,
    active_rows: mx.array | None = None,
    target_active_rows: mx.array | None = None,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Relate source coordinates to an explicitly supplied set of targets.

    Args:
        coords: Source coordinate array.
        target_coords: Target coordinate array; must share ``coords``' dtype.
        active_rows: Optional active-row count for ``coords``. When omitted,
            a one-element int32 array holding ``coords.shape[0]`` is used.
        target_active_rows: Same as ``active_rows`` but for ``target_coords``.
        kernel_size: Kernel extent forwarded to ``KernelSpec``.
        stride: Stride forwarded to ``KernelSpec``.
        padding: Padding forwarded to ``KernelSpec``.
        dilation: Dilation forwarded to ``KernelSpec``.

    Returns:
        A ``KernelRelation`` of kind ``'target'``.

    Raises:
        ValueError: If the coordinate dtypes differ or an active-row array
            does not have shape ``(1,)`` and int32 dtype.
    """
    for candidate in (coords, target_coords):
        validate_coords(candidate)
    _require_matching_coord_dtype(coords, target_coords)

    geometry = KernelSpec(
        size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    footprint = kernel_offsets(geometry.size, geometry.dilation)
    src_active = _active_rows(active_rows, coords)
    dst_active = _active_rows(target_active_rows, target_coords)

    raw_relation = ext.build_target_kernel_relation(
        coords,
        src_active,
        target_coords,
        dst_active,
        geometry.size,
        geometry.stride,
        geometry.padding,
        geometry.dilation,
    )

    context = {
        'offsets': footprint,
        'in_capacity': int(coords.shape[0]),
        'source_coords': coords,
        'source_active_rows': src_active,
        'target_coords': target_coords,
        'target_active_rows': dst_active,
        'stride': geometry.stride,
        'padding': geometry.padding,
        'kind': 'target',
    }
    return _kernel_relation_from_native(raw_relation, **context)
def kernel_offsets(
    kernel_size: int | Sequence[int],
    dilation: int | Sequence[int] = 1,
) -> tuple[Triple, ...]:
    """List every spatial offset covered by a dense 3D kernel footprint.

    Odd extents are centered on zero (``-r .. r``); even extents start at
    zero (``0 .. size - 1``). Each offset is scaled by the per-axis dilation.
    The last axis varies fastest in the returned sequence.

    Args:
        kernel_size: Kernel extent, normalized through ``triple``.
        dilation: Dilation rate, normalized through ``triple``.

    Returns:
        A tuple of ``(x, y, z)`` integer offsets.

    Raises:
        ValueError: If any kernel extent or dilation value is not positive.
    """
    extents = triple(kernel_size, name='kernel_size')
    rates = triple(dilation, name='dilation')
    _require_positive(extents, 'kernel_size')
    _require_positive(rates, 'dilation')

    def _axis_span(extent: int) -> range:
        # Odd kernels are symmetric around the origin; even ones are not.
        if extent % 2 == 1:
            half = extent // 2
            return range(-half, half + 1)
        return range(extent)

    x_span, y_span, z_span = (_axis_span(extent) for extent in extents)
    rate_x, rate_y, rate_z = rates

    footprint = []
    for x in x_span:
        for y in y_span:
            for z in z_span:
                footprint.append(
                    (int(x * rate_x), int(y * rate_y), int(z * rate_z))
                )
    return tuple(footprint)
def build_kernel_relation(
    coords: mx.array,
    active_rows: mx.array | None = None,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Construct the forward relation used by sparse convolution/pooling.

    Args:
        coords: Source coordinate array.
        active_rows: Optional active-row count for ``coords``. When omitted,
            a one-element int32 array holding ``coords.shape[0]`` is used.
        kernel_size: Kernel extent forwarded to ``KernelSpec``.
        stride: Stride forwarded to ``KernelSpec``.
        padding: Padding forwarded to ``KernelSpec``.
        dilation: Dilation forwarded to ``KernelSpec``.

    Returns:
        A ``KernelRelation`` of kind ``'forward'`` whose output coordinates
        come from the native builder.
    """
    validate_coords(coords)

    geometry = KernelSpec(
        size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    footprint = kernel_offsets(geometry.size, geometry.dilation)
    active = _active_rows(active_rows, coords)

    raw_relation = ext.build_kernel_relation(
        coords,
        active,
        geometry.size,
        geometry.stride,
        geometry.padding,
        geometry.dilation,
    )

    context = {
        'offsets': footprint,
        'in_capacity': int(coords.shape[0]),
        'source_coords': coords,
        'source_active_rows': active,
        'stride': geometry.stride,
        'padding': geometry.padding,
        'kind': 'forward',
    }
    return _kernel_relation_from_native(raw_relation, **context)
def build_submanifold_kernel_relation(
    coords: mx.array,
    active_rows: mx.array | None = None,
    kernel_size: int | Sequence[int] = 3,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Construct a submanifold relation that keeps ``coords`` as its output support.

    The same coordinate array (and active-row count) is recorded as the
    source, the target and the output of the resulting relation.

    Args:
        coords: Coordinate array serving as both input and output support.
        active_rows: Optional active-row count for ``coords``. When omitted,
            a one-element int32 array holding ``coords.shape[0]`` is used.
        kernel_size: Kernel extent forwarded to ``KernelSpec``.
        dilation: Dilation forwarded to ``KernelSpec``.

    Returns:
        A ``KernelRelation`` of kind ``'submanifold'``.

    Raises:
        ValueError: If the resulting spec is not a centered submanifold spec.
    """
    validate_coords(coords)

    geometry = KernelSpec(
        size=kernel_size,
        stride=1,
        padding=0,
        dilation=dilation,
    )
    if geometry.is_centered_submanifold:
        footprint = kernel_offsets(geometry.size, geometry.dilation)
    else:
        raise ValueError(
            'submanifold relations require odd kernels, stride=1, '
            'padding=0, and positive dilation.'
        )

    active = _active_rows(active_rows, coords)
    raw_relation = ext.build_submanifold_kernel_relation(
        coords,
        active,
        geometry.size,
        geometry.dilation,
    )

    context = {
        'offsets': footprint,
        'in_capacity': int(coords.shape[0]),
        'out_coords': coords,
        'source_coords': coords,
        'source_active_rows': active,
        'target_coords': coords,
        'target_active_rows': active,
        'stride': geometry.stride,
        'padding': geometry.padding,
        'kind': 'submanifold',
    }
    return _kernel_relation_from_native(raw_relation, **context)
def build_generative_relation(
    coords: mx.array,
    active_rows: mx.array | None = None,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
) -> KernelRelation:
    """Construct the relation behind a generative transpose convolution.

    Args:
        coords: Source coordinate array.
        active_rows: Optional active-row count for ``coords``. When omitted,
            a one-element int32 array holding ``coords.shape[0]`` is used.
        kernel_size: Kernel extent, normalized through ``triple``.
        stride: Stride, normalized through ``triple``.

    Returns:
        A ``KernelRelation`` of kind ``'generative'`` (padding is left at the
        ``_kernel_relation_from_native`` default).

    Raises:
        ValueError: If any kernel extent or stride value is not positive.
    """
    validate_coords(coords)

    extents = triple(kernel_size, name='kernel_size')
    steps = triple(stride, name='stride')
    for values, label in ((extents, 'kernel_size'), (steps, 'stride')):
        _require_positive(values, label)

    footprint = kernel_offsets(extents)
    active = _active_rows(active_rows, coords)
    raw_relation = ext.build_generative_relation(
        coords,
        active,
        extents,
        steps,
    )

    context = {
        'offsets': footprint,
        'in_capacity': int(coords.shape[0]),
        'source_coords': coords,
        'source_active_rows': active,
        'stride': steps,
        'kind': 'generative',
    }
    return _kernel_relation_from_native(raw_relation, **context)
def build_transposed_kernel_relation(
    coords: mx.array,
    active_rows: mx.array | None = None,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Construct the relation behind a sparse transpose convolution.

    Args:
        coords: Source coordinate array.
        active_rows: Optional active-row count for ``coords``. When omitted,
            a one-element int32 array holding ``coords.shape[0]`` is used.
        kernel_size: Kernel extent, normalized through ``triple``.
        stride: Stride, normalized through ``triple``.
        padding: Padding, normalized through ``triple``.
        dilation: Dilation, normalized through ``triple``.

    Returns:
        A ``KernelRelation`` of kind ``'transposed'``.

    Raises:
        ValueError: If a kernel extent, stride or dilation value is not
            positive, or a padding value is negative.
    """
    validate_coords(coords)

    # Normalize every argument first, then range-check in the same order.
    extents = triple(kernel_size, name='kernel_size')
    steps = triple(stride, name='stride')
    pads = triple(padding, name='padding')
    rates = triple(dilation, name='dilation')
    range_checks = (
        (_require_positive, extents, 'kernel_size'),
        (_require_positive, steps, 'stride'),
        (_require_nonnegative, pads, 'padding'),
        (_require_positive, rates, 'dilation'),
    )
    for check, values, label in range_checks:
        check(values, label)

    footprint = kernel_offsets(extents, rates)
    active = _active_rows(active_rows, coords)
    raw_relation = ext.build_transposed_kernel_relation(
        coords,
        active,
        extents,
        steps,
        pads,
        rates,
    )

    context = {
        'offsets': footprint,
        'in_capacity': int(coords.shape[0]),
        'source_coords': coords,
        'source_active_rows': active,
        'stride': steps,
        'padding': pads,
        'kind': 'transposed',
    }
    return _kernel_relation_from_native(raw_relation, **context)
def build_knn_relation(
    source_coords: mx.array,
    query_coords: mx.array | None = None,
    *,
    source_active_rows: mx.array | None = None,
    query_active_rows: mx.array | None = None,
    k: int,
) -> NeighborRelation:
    """Build a k-nearest-neighbor relation between source and query coords.

    Args:
        source_coords: Coordinates that neighbors are drawn from.
        query_coords: Coordinates to find neighbors for. Defaults to
            ``source_coords`` (self-query).
        source_active_rows: Optional active-row count for ``source_coords``.
            When omitted, a one-element int32 array holding
            ``source_coords.shape[0]`` is used.
        query_active_rows: Optional active-row count for ``query_coords``.
            In the self-query case with no explicit value, the source
            active-row array is reused.
        k: Number of neighbors; must convert to a positive ``int``.

    Returns:
        A ``NeighborRelation`` whose ``max_neighbors`` equals ``k``.

    Raises:
        ValueError: If the coordinate dtypes differ, an active-row array
            does not have shape ``(1,)`` and int32 dtype, or ``k`` is not
            positive.
    """
    query_coords = source_coords if query_coords is None else query_coords

    # Validate the coordinate arrays before anything reads their shape, so
    # malformed input surfaces as a validation error (as in the other
    # builders) rather than an incidental failure inside ``_active_rows``.
    validate_coords(source_coords)
    validate_coords(query_coords)
    _require_matching_coord_dtype(source_coords, query_coords)

    source_active_rows = _active_rows(source_active_rows, source_coords)
    if query_active_rows is None and query_coords is source_coords:
        # Self-query: share the source activity instead of rebuilding it.
        query_active_rows = source_active_rows
    else:
        query_active_rows = _active_rows(query_active_rows, query_coords)

    neighbor_count = _positive_int(k, 'k')
    native = ext.build_knn_relation(
        source_coords,
        source_active_rows,
        query_coords,
        query_active_rows,
        neighbor_count,
    )
    return _neighbor_relation_from_native(
        native,
        query_capacity=int(query_coords.shape[0]),
        source_capacity=int(source_coords.shape[0]),
        max_neighbors=neighbor_count,
    )
def build_radius_relation(
    source_coords: mx.array,
    query_coords: mx.array | None = None,
    *,
    source_active_rows: mx.array | None = None,
    query_active_rows: mx.array | None = None,
    radius: float,
    max_neighbors: int | None = None,
) -> NeighborRelation:
    """Build a radius-neighborhood relation between source and query coords.

    Args:
        source_coords: Coordinates that neighbors are drawn from.
        query_coords: Coordinates to find neighbors for. Defaults to
            ``source_coords`` (self-query).
        source_active_rows: Optional active-row count for ``source_coords``.
            When omitted, a one-element int32 array holding
            ``source_coords.shape[0]`` is used.
        query_active_rows: Optional active-row count for ``query_coords``.
            In the self-query case with no explicit value, the source
            active-row array is reused.
        radius: Search radius; must convert to a non-negative ``float``.
        max_neighbors: Optional positive cap on neighbors per query. When
            ``None``, ``0`` is passed to the native builder and the
            relation's capacity is the number of integer lattice points
            within ``radius`` of the origin.

    Returns:
        A ``NeighborRelation`` for the requested neighborhood.

    Raises:
        ValueError: If the coordinate dtypes differ, an active-row array
            does not have shape ``(1,)`` and int32 dtype, ``radius`` is
            negative, or ``max_neighbors`` is given but not positive.
    """
    query_coords = source_coords if query_coords is None else query_coords

    # Validate the coordinate arrays before anything reads their shape, so
    # malformed input surfaces as a validation error (as in the other
    # builders) rather than an incidental failure inside ``_active_rows``.
    validate_coords(source_coords)
    validate_coords(query_coords)
    _require_matching_coord_dtype(source_coords, query_coords)

    source_active_rows = _active_rows(source_active_rows, source_coords)
    if query_active_rows is None and query_coords is source_coords:
        # Self-query: share the source activity instead of rebuilding it.
        query_active_rows = source_active_rows
    else:
        query_active_rows = _active_rows(query_active_rows, query_coords)

    radius_value = _nonnegative_float(radius, 'radius')
    neighbor_count = (
        0
        if max_neighbors is None
        else _positive_int(max_neighbors, 'max_neighbors')
    )
    native = ext.build_radius_relation(
        source_coords,
        source_active_rows,
        query_coords,
        query_active_rows,
        radius_value,
        neighbor_count,
    )
    return _neighbor_relation_from_native(
        native,
        query_capacity=int(query_coords.shape[0]),
        source_capacity=int(source_coords.shape[0]),
        max_neighbors=(
            _radius_neighbor_capacity(radius_value)
            if max_neighbors is None
            else neighbor_count
        ),
    )
# MARK: - views


def _kernel_relation_from_native(
    native: NativeKernelRelation,
    *,
    offsets: tuple[Triple, ...],
    in_capacity: int,
    out_coords: mx.array | None = None,
    source_coords: mx.array,
    source_active_rows: mx.array,
    target_coords: mx.array | None = None,
    target_active_rows: mx.array | None = None,
    stride: Triple = (1, 1, 1),
    padding: Triple = (0, 0, 0),
    kind: RelationKind = 'forward',
) -> KernelRelation:
    """Wrap a native ten-array result tuple in a ``KernelRelation``.

    ``out_coords`` overrides the output coordinates reported by the native
    builder when given; the output capacity is taken from whichever array
    ends up being used.
    """
    (
        in_rows,
        out_rows,
        kernel_ids,
        row_offsets,
        native_out_coords,
        counts,
        in_row_offsets,
        in_edge_ids,
        kernel_row_offsets,
        kernel_edge_ids,
    ) = native
    relation_out_coords = (
        native_out_coords if out_coords is None else out_coords
    )
    return KernelRelation(
        in_rows,
        out_rows,
        kernel_ids,
        row_offsets=row_offsets,
        counts=counts,
        in_row_offsets=in_row_offsets,
        in_edge_ids=in_edge_ids,
        kernel_row_offsets=kernel_row_offsets,
        kernel_edge_ids=kernel_edge_ids,
        kernel_offsets=offsets,
        out_coords=relation_out_coords,
        n_in_capacity=in_capacity,
        n_out_capacity=int(relation_out_coords.shape[0]),
        n_kernels=len(offsets),
        source_coords=source_coords,
        source_active_rows=source_active_rows,
        target_coords=target_coords,
        target_active_rows=target_active_rows,
        stride=stride,
        padding=padding,
        kind=kind,
    )


def _neighbor_relation_from_native(
    native: NativeNeighborRelation,
    *,
    query_capacity: int,
    source_capacity: int,
    max_neighbors: int,
) -> NeighborRelation:
    """Wrap a native six-array result tuple in a ``NeighborRelation``."""
    (
        query_rows,
        source_rows,
        neighbor_ids,
        distances,
        row_offsets,
        counts,
    ) = native
    return NeighborRelation(
        query_rows,
        source_rows,
        neighbor_ids,
        distances,
        row_offsets=row_offsets,
        counts=counts,
        n_query_capacity=query_capacity,
        n_source_capacity=source_capacity,
        max_neighbors=max_neighbors,
    )


# MARK: - helpers


def _offset_array(offsets: tuple[Triple, ...]) -> mx.array:
    """Convert a tuple of offset triples into an int32 array."""
    return mx.array(offsets, dtype=mx.int32)


def _require_positive(values: Triple, name: str) -> None:
    """Raise ``ValueError`` unless every entry of ``values`` is > 0."""
    if any(value <= 0 for value in values):
        raise ValueError(f'{name} values must be positive.')


def _require_nonnegative(values: Triple, name: str) -> None:
    """Raise ``ValueError`` unless every entry of ``values`` is >= 0."""
    if any(value < 0 for value in values):
        raise ValueError(f'{name} values must be non-negative.')


def _require_matching_coord_dtype(lhs: mx.array, rhs: mx.array) -> None:
    """Raise ``ValueError`` when the two coordinate arrays differ in dtype."""
    if lhs.dtype != rhs.dtype:
        raise ValueError('coordinate arrays must have matching dtype.')


def _positive_int(value: int, name: str) -> int:
    """Coerce ``value`` with ``int()`` and require the result to be > 0."""
    out = int(value)
    if out <= 0:
        raise ValueError(f'{name} must be positive.')
    return out


def _nonnegative_float(value: float, name: str) -> float:
    """Coerce ``value`` with ``float()`` and require the result to be >= 0.

    NaN is rejected as well: every ordered comparison with NaN is false, so
    a plain ``out < 0`` test would silently let it through.
    """
    out = float(value)
    if math.isnan(out) or out < 0:
        raise ValueError(f'{name} must be non-negative.')
    return out


def _active_rows(value: mx.array | None, coords: mx.array) -> mx.array:
    """Return the active-row array, defaulting to all rows of ``coords``.

    An explicit ``value`` must have shape ``(1,)`` and int32 dtype; when it
    is ``None`` a one-element int32 array holding ``coords.shape[0]`` is
    created.
    """
    if value is not None:
        if value.shape != (1,) or value.dtype != mx.int32:
            raise ValueError(
                'active_rows must have shape (1,) and int32 dtype.'
            )
        return value
    return mx.array([coords.shape[0]], dtype=mx.int32)


def _radius_neighbor_capacity(radius: float) -> int:
    """Count integer lattice points within ``radius`` of the origin.

    Only points inside the cube ``[-ceil(radius), ceil(radius)]**3`` are
    considered, and the result is never less than 1.
    """
    limit = math.ceil(radius)
    # Squared lattice distances are integers, so ``d2 <= radius**2`` holds
    # exactly when ``d2 <= floor(radius**2)``.
    budget = math.floor(radius * radius)
    count = 0
    for dz in range(-limit, limit + 1):
        for dy in range(-limit, limit + 1):
            remaining = budget - dz * dz - dy * dy
            if remaining < 0:
                continue
            # Closed-form count of the dx values that fit in the remaining
            # budget, replacing a third nested loop; clamped to the cube.
            reach = min(math.isqrt(remaining), limit)
            count += 2 * reach + 1
    return max(count, 1)