Source code for mlx_lattice.core.coords.quantization

from __future__ import annotations

import math
import numbers
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal

import mlx.core as mx

from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coords

# Reduction modes accepted by ``voxelize_features`` (``reduction=``).
type VoxelReduction = Literal['sum', 'mean']
# Interpolation modes accepted by ``build_point_voxel_map``
# (``interpolation=``).
type PointVoxelInterpolation = Literal['nearest', 'linear']


@dataclass(frozen=True, slots=True)
class SparseQuantization:
    """Sparse voxel grid together with its point-to-voxel bookkeeping.

    All four arrays are ``int32``. ``coords`` holds unique voxel
    coordinates in ``(batch, x, y, z)`` order; its row count is the
    ``capacity``. ``active_rows`` is a one-element array carrying the
    active row count. ``inverse_rows`` gives, for every input point row,
    the voxel row that point belongs to, and ``counts`` gives how many
    active point rows were accumulated into each voxel. Both of those
    have ``capacity`` entries, so ``capacity`` also equals the number of
    input point rows.

    Raises:
        ValueError: From ``__post_init__`` when an array has the wrong
            shape or dtype (``validate_coords`` may raise as well).
    """

    coords: mx.array
    active_rows: mx.array
    inverse_rows: mx.array
    counts: mx.array

    def __post_init__(self) -> None:
        validate_coords(self.coords)
        if self.coords.dtype != mx.int32:
            raise ValueError('quantized coords must be int32.')

        if not (
            self.active_rows.shape == (1,)
            and self.active_rows.dtype == mx.int32
        ):
            raise ValueError(
                'active_rows must have shape (1,) and int32 dtype.'
            )

        # Per-row metadata must line up one-to-one with ``coords``.
        expected = (self.coords.shape[0],)
        per_row_checks = (
            (
                self.inverse_rows,
                'inverse_rows must have shape (N,) and int32 dtype.',
            ),
            (
                self.counts,
                'counts must have shape (N,) and int32 dtype.',
            ),
        )
        for array, message in per_row_checks:
            if array.shape != expected or array.dtype != mx.int32:
                raise ValueError(message)

    @property
    def capacity(self) -> int:
        """Number of rows allocated in ``coords``."""
        return int(self.coords.shape[0])

    @property
    def active_count(self) -> mx.array:
        """Alias for the one-element ``active_rows`` array."""
        return self.active_rows


@dataclass(frozen=True, slots=True)
class PointVoxelMap:
    """Point-to-voxel gather table with per-slot interpolation weights.

    ``rows`` (``int32``) and ``weights`` (``float32``) share the shape
    ``(N, 8)``: one row per point, eight voxel slots per row. Linear
    interpolation may draw on up to eight neighboring voxel rows for a
    point, while nearest interpolation leaves a single non-zero
    contribution.

    Raises:
        ValueError: From ``__post_init__`` when shapes or dtypes are wrong.
    """

    rows: mx.array
    weights: mx.array

    def __post_init__(self) -> None:
        # Short-circuit keeps ``shape[1]`` from being read on a non-2D
        # array.
        rows_are_valid = (
            self.rows.ndim == 2
            and self.rows.shape[1] == 8
            and self.rows.dtype == mx.int32
        )
        if not rows_are_valid:
            raise ValueError('rows must have shape (N, 8) and int32 dtype.')
        if self.weights.shape != self.rows.shape:
            raise ValueError('weights must have the same shape as rows.')
        if self.weights.dtype != mx.float32:
            raise ValueError('weights must be float32.')

    @property
    def point_count(self) -> int:
        """Number of point rows described by the map."""
        return int(self.rows.shape[0])


def sparse_quantize(
    points: mx.array,
    voxel_size: float | Sequence[float] = 1.0,
    *,
    batch_indices: mx.array | None = None,
    origin: float | Sequence[float] = 0.0,
    active_rows: mx.array | None = None,
) -> SparseQuantization:
    """Snap floating-point points onto a sparse integer voxel grid.

    Args:
        points: ``float32`` point positions of shape ``(N, 3)``.
        voxel_size: Voxel edge length, either a single value shared by all
            three axes or one value per axis; every value must be positive.
        batch_indices: Optional ``int32`` array of shape ``(N,)`` assigning
            each point to a batch. When omitted, every point is placed in
            batch ``0``.
        origin: Grid origin, a single value or one value per axis.
        active_rows: Optional one-element ``int32`` array giving the number
            of active point rows; defaults to all ``N`` rows.

    Returns:
        A :class:`SparseQuantization` holding the voxel coordinates, the
        active row count, the inverse point-to-voxel rows, and per-voxel
        counts.

    Raises:
        ValueError: If any argument fails validation.
    """
    _validate_points(points)
    num_points = points.shape[0]
    batches = _batch_indices(batch_indices, num_points)
    point_rows = _active_rows(active_rows, num_points)
    voxel_triple = _float_triple(voxel_size, 'voxel_size')
    origin_triple = _float_triple(origin, 'origin')

    native_result = ext.sparse_quantize(
        points,
        batches,
        point_rows,
        voxel_triple,
        origin_triple,
    )
    return SparseQuantization(*native_result)


def voxelize_features(
    feats: mx.array,
    quantization: SparseQuantization,
    *,
    active_rows: mx.array | None = None,
    reduction: VoxelReduction = 'mean',
) -> mx.array:
    """Pool per-point features into the voxels described by ``quantization``.

    Args:
        feats: ``float32`` array of shape ``(N, C)`` with one row per
            original point; ``N`` must equal ``quantization.capacity``.
        quantization: Result of :func:`sparse_quantize` for the same points.
        active_rows: Optional one-element ``int32`` array with the number
            of active point rows; defaults to all ``N`` rows.
        reduction: ``'sum'`` adds point rows together; ``'mean'`` also
            divides by each voxel's count.

    Returns:
        The voxel feature array produced by the native kernel.

    Raises:
        ValueError: On a shape, dtype, row-count, or reduction mismatch.
    """
    if feats.ndim != 2:
        raise ValueError('feats must have shape (N, C).')
    if feats.dtype != mx.float32:
        raise ValueError('voxelize currently supports float32 features.')

    num_rows = feats.shape[0]
    if num_rows != quantization.capacity:
        raise ValueError('feats and quantization metadata must share rows.')

    point_rows = _active_rows(active_rows, num_rows)
    mode = _validate_reduction(reduction)
    return ext.voxelize_features(
        feats,
        quantization.inverse_rows,
        quantization.counts,
        point_rows,
        mode,
    )


def build_point_voxel_map(
    points: mx.array,
    voxel_coords: mx.array,
    voxel_active_rows: mx.array,
    voxel_size: float | Sequence[float] = 1.0,
    *,
    batch_indices: mx.array | None = None,
    point_active_rows: mx.array | None = None,
    origin: float | Sequence[float] = 0.0,
    interpolation: PointVoxelInterpolation = 'linear',
) -> PointVoxelMap:
    """Compute fixed-width interpolation rows linking points to voxel centers.

    The returned :class:`PointVoxelMap` can sample any number of voxel
    feature arrays, as long as the point geometry, batch indices, voxel
    coordinates, voxel size, and origin stay the same.

    Args:
        points: ``float32`` point positions of shape ``(N, 3)``.
        voxel_coords: ``int32`` voxel coordinates (checked by
            ``validate_coords``).
        voxel_active_rows: ``int32`` array of shape ``(1,)`` with the
            active voxel row count.
        voxel_size: Scalar or per-axis voxel edge length; must be positive.
        batch_indices: Optional ``int32`` per-point batch ids of shape
            ``(N,)``; every point uses batch ``0`` when omitted.
        point_active_rows: Optional one-element ``int32`` active point row
            count; defaults to all ``N`` rows.
        origin: Scalar or per-axis grid origin.
        interpolation: ``'nearest'`` or ``'linear'``.

    Raises:
        ValueError: If an argument fails validation.
    """
    _validate_points(points)
    validate_coords(voxel_coords)
    if voxel_coords.dtype != mx.int32:
        raise ValueError('voxel_coords must be int32.')

    num_points = points.shape[0]
    batches = _batch_indices(batch_indices, num_points)
    point_rows = _active_rows(point_active_rows, num_points)

    voxel_rows_ok = (
        voxel_active_rows.shape == (1,)
        and voxel_active_rows.dtype == mx.int32
    )
    if not voxel_rows_ok:
        raise ValueError(
            'voxel_active_rows must have shape (1,) and int32 dtype.'
        )

    voxel_triple = _float_triple(voxel_size, 'voxel_size')
    origin_triple = _float_triple(origin, 'origin')
    mode = _validate_interpolation(interpolation)
    native_result = ext.build_point_voxel_map(
        points,
        batches,
        point_rows,
        voxel_coords,
        voxel_active_rows,
        voxel_triple,
        origin_triple,
        mode,
    )
    return PointVoxelMap(*native_result)


def interpolate_point_features(
    voxel_feats: mx.array,
    point_voxel_map: PointVoxelMap,
) -> mx.array:
    """Gather voxel features back onto points through a precomputed map.

    Args:
        voxel_feats: ``float32`` voxel features of shape
            ``(N_voxels, C)``.
        point_voxel_map: Rows and weights from
            :func:`build_point_voxel_map`.

    Returns:
        A dense ``(N_points, C)`` point feature array.

    Raises:
        ValueError: If ``voxel_feats`` is not a 2-D ``float32`` array.
    """
    if voxel_feats.ndim != 2:
        raise ValueError('voxel_feats must have shape (N, C).')
    if voxel_feats.dtype != mx.float32:
        raise ValueError('point interpolation currently supports float32.')

    gather_rows = point_voxel_map.rows
    gather_weights = point_voxel_map.weights
    return ext.interpolate_point_features(
        voxel_feats, gather_rows, gather_weights
    )


def _validate_points(points: mx.array) -> None: if points.ndim != 2 or points.shape[1] != 3: raise ValueError('points must have shape (N, 3).') if points.dtype != mx.float32: raise ValueError('points must be float32.') def _batch_indices(value: mx.array | None, rows: int) -> mx.array: if value is None: return mx.zeros((rows,), dtype=mx.int32) if value.shape != (rows,) or value.dtype != mx.int32: raise ValueError( 'batch_indices must have shape (N,) and int32 dtype.' ) return value def _active_rows(value: mx.array | None, rows: int) -> mx.array: if value is None: return mx.array([rows], dtype=mx.int32) if value.shape != (1,) or value.dtype != mx.int32: raise ValueError( 'active_rows must have shape (1,) and int32 dtype.' ) return value def _float_triple( value: float | Sequence[float], name: str, ) -> tuple[float, float, float]: if isinstance(value, int | float): return (float(value), float(value), float(value)) else: raw = tuple(float(item) for item in value) if len(raw) != 3: raise ValueError(f'{name} must contain exactly 3 values.') values = (raw[0], raw[1], raw[2]) if name == 'voxel_size' and any(item <= 0.0 for item in values): raise ValueError('voxel_size values must be positive.') return values def _validate_reduction(value: str) -> VoxelReduction: if value == 'sum': return 'sum' if value == 'mean': return 'mean' raise ValueError("reduction must be 'sum' or 'mean'.") def _validate_interpolation(value: str) -> PointVoxelInterpolation: if value == 'nearest': return 'nearest' if value == 'linear': return 'linear' raise ValueError("interpolation must be 'nearest' or 'linear'.")