Source code for mlx_lattice.core.coords.occupancy

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx

from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coords


@dataclass(frozen=True, slots=True)
class SparseOccupancy:
    """Downsampled coordinates plus an 8-bit child occupancy mask per row.

    All fields are validated on construction. Malformed ``active_rows`` or
    ``occupancy`` raise ``ValueError``; whatever ``validate_coords`` raises
    for bad ``coords`` is defined elsewhere.

    Attributes:
        coords: Downsampled (parent) coordinates, checked by
            ``validate_coords``. Its leading dimension ``N`` is the row
            capacity.
        active_rows: Shape ``(1,)`` int32 array holding an active-row
            count; ``active_count`` returns it unchanged.
        occupancy: Shape ``(N,)`` int32 array, one mask per row of
            ``coords``. The mask is described as 8-bit but stored as int32;
            presumably one bit per child -- value range is not checked here.
    """

    coords: mx.array
    active_rows: mx.array
    occupancy: mx.array

    def __post_init__(self) -> None:
        validate_coords(self.coords)
        _validate_active_rows(self.active_rows)
        # Rank/dtype use the shared per-row validator. The row-count check is
        # separate so a length mismatch is reported as a shape problem rather
        # than being folded into a misleading "int32 dtype" message.
        _validate_row_array(self.occupancy, 'occupancy')
        if self.occupancy.shape[0] != self.coords.shape[0]:
            raise ValueError('occupancy must have shape (N,).')

    @property
    def capacity(self) -> int:
        """Total number of rows in ``coords``, independent of ``active_rows``."""
        return int(self.coords.shape[0])

    @property
    def active_count(self) -> mx.array:
        """The ``active_rows`` array itself (an array, not a Python int)."""
        return self.active_rows
@dataclass(frozen=True, slots=True)
class OccupancyExpansion:
    """Expanded child coordinates with parent row and child-index metadata.

    Every field is checked in ``__post_init__``. Malformed ``active_rows``,
    ``parent_rows`` or ``child_indices`` raise ``ValueError``; errors for
    bad ``coords`` come from ``validate_coords``.

    Attributes:
        coords: Expanded (child) coordinates, checked by ``validate_coords``.
        active_rows: Shape ``(1,)`` int32 active-row count array.
        parent_rows: Shape ``(N,)`` int32, one entry per row of ``coords``.
        child_indices: Shape ``(N,)`` int32, one entry per row of ``coords``.
    """

    coords: mx.array
    active_rows: mx.array
    parent_rows: mx.array
    child_indices: mx.array

    def __post_init__(self) -> None:
        validate_coords(self.coords)
        _validate_active_rows(self.active_rows)
        per_row = (
            ('parent_rows', self.parent_rows),
            ('child_indices', self.child_indices),
        )
        # Rank/dtype of both per-row arrays is checked before either length,
        # so a badly typed array is reported before any row-count mismatch.
        for label, values in per_row:
            _validate_row_array(values, label)
        expected_rows = self.coords.shape[0]
        for label, values in per_row:
            if values.shape[0] != expected_rows:
                raise ValueError(f'{label} must have shape (N,).')

    @property
    def capacity(self) -> int:
        """Row count of ``coords`` as a Python int."""
        return int(self.coords.shape[0])

    @property
    def active_count(self) -> mx.array:
        """Alias for ``active_rows``; returned as an array, not an int."""
        return self.active_rows
def occupancy_downsample(
    coords: mx.array,
    active_rows: mx.array | None = None,
) -> SparseOccupancy:
    """Downsample coordinates and record occupied children for each parent.

    Args:
        coords: Input coordinates. They are validated with
            ``validate_coords`` and cast to int32 before the native call.
        active_rows: Optional shape ``(1,)`` int32 count. When omitted, it
            defaults to the full row count of ``coords``.

    Returns:
        A ``SparseOccupancy`` whose fields (coords, active_rows, occupancy)
        are unpacked positionally from the native kernel's result.

    Raises:
        ValueError: If a supplied ``active_rows`` has the wrong shape or
            dtype, or if the native result fails ``SparseOccupancy``
            validation. Bad ``coords`` raise whatever ``validate_coords``
            raises.
    """
    native_coords = _coords_for_native(coords)
    row_count = _active_rows_for(native_coords, active_rows)
    downsampled = ext.occupancy_downsample(native_coords, row_count)
    return SparseOccupancy(*downsampled)
def occupancy_expand(
    coords: mx.array,
    occupancy: mx.array,
    active_rows: mx.array | None = None,
) -> OccupancyExpansion:
    """Expand occupied child coordinates from a parent occupancy mask.

    Args:
        coords: Parent coordinates. They are validated with
            ``validate_coords`` and cast to int32 before the native call.
        occupancy: Shape ``(N,)`` int32 mask array, one entry per row of
            ``coords``.
        active_rows: Optional shape ``(1,)`` int32 count. When omitted, it
            defaults to the full row count of ``coords``.

    Returns:
        An ``OccupancyExpansion`` built positionally from the native
        kernel's result.

    Raises:
        ValueError: If ``active_rows`` or ``occupancy`` is malformed, or if
            ``occupancy`` does not have one entry per coordinate row.
    """
    parents = _coords_for_native(coords)
    count = _active_rows_for(parents, active_rows)
    _validate_row_array(occupancy, 'occupancy')
    if parents.shape[0] != occupancy.shape[0]:
        raise ValueError('occupancy must have shape (N,).')
    # Note the native argument order: active_rows comes before occupancy.
    expanded = ext.occupancy_expand(parents, count, occupancy)
    return OccupancyExpansion(*expanded)
def child_coords_from_indices(
    parent_coords: mx.array,
    child_indices: mx.array,
) -> mx.array:
    """Compute child coordinates from parent coordinates and child indices.

    Args:
        parent_coords: Parent coordinates. They are validated with
            ``validate_coords`` and cast to int32 before the native call.
        child_indices: Shape ``(N,)`` int32 array, one index per parent row.

    Returns:
        The array produced by the native ``child_coords_from_indices``.
        Its shape is not checked here; presumably it has one coordinate
        row per parent row.

    Raises:
        ValueError: If ``child_indices`` is not 1-D int32, or its length
            differs from the number of parent rows.
    """
    parents = _coords_for_native(parent_coords)
    _validate_row_array(child_indices, 'child_indices')
    row_mismatch = child_indices.shape[0] != parents.shape[0]
    if row_mismatch:
        raise ValueError('child_indices must have shape (N,).')
    return ext.child_coords_from_indices(parents, child_indices)
def _coords_for_native(coords: mx.array) -> mx.array:
    """Validate ``coords`` and return them with int32 dtype for native calls.

    An array that is already int32 is returned as the same object.
    """
    validate_coords(coords)
    return coords if coords.dtype == mx.int32 else coords.astype(mx.int32)


def _active_rows_for(
    coords: mx.array,
    active_rows: mx.array | None,
) -> mx.array:
    """Return a validated active-row count, defaulting to every row of ``coords``."""
    if active_rows is not None:
        _validate_active_rows(active_rows)
        return active_rows
    # No explicit count: treat all rows as active.
    return mx.array([coords.shape[0]], dtype=mx.int32)


def _validate_active_rows(active_rows: mx.array) -> None:
    """Raise ``ValueError`` unless ``active_rows`` is a shape-``(1,)`` int32 array."""
    well_formed = active_rows.shape == (1,) and active_rows.dtype == mx.int32
    if not well_formed:
        raise ValueError(
            'active_rows must have shape (1,) and int32 dtype.'
        )


def _validate_row_array(value: mx.array, name: str) -> None:
    """Raise ``ValueError`` (mentioning ``name``) unless ``value`` is 1-D int32."""
    if not (value.ndim == 1 and value.dtype == mx.int32):
        raise ValueError(f'{name} must have shape (N,) and int32 dtype.')