# Source code for mlx_lattice.ops.tensor

from __future__ import annotations

from collections.abc import Sequence
from typing import Literal, cast

import mlx.core as mx

from mlx_lattice.core.coords import SparseAlignment, build_sparse_alignment
from mlx_lattice.core.tensor import SparseTensor
from mlx_lattice.core.types import triple

# Row-support policy used when pairing two sparse tensors by coordinate
# value. ``align_sparse`` forwards it to ``build_sparse_alignment``; rows
# missing from one side of the resulting alignment are encoded as ``-1`` and
# filled by ``gather_aligned_features``.
type SparseJoin = Literal['inner', 'left', 'right', 'outer']


def sparse_collate(
    coords: Sequence[mx.array],
    feats: Sequence[mx.array],
    *,
    stride: int | Sequence[int] = 1,
) -> SparseTensor:
    """Collate unbatched sparse coordinates/features into one tensor.

    Each ``coords[i]`` must have shape ``(N_i, 3)`` and dtype ``int32`` or
    ``int64``, and all coordinate arrays must share one dtype. Each
    ``feats[i]`` must have shape ``(N_i, C)`` with the same channel count
    ``C`` for every batch. The batch index ``i`` is prepended as a leading
    coordinate column, so the result has ``(sum N_i, 4)`` coordinates, and
    the per-batch row counts are recorded as ``batch_counts``.

    Args:
        coords: Per-batch spatial coordinates.
        feats: Per-batch feature rows, aligned row-for-row with ``coords``.
        stride: Sparse stride, either a scalar or a 3-sequence; it is
            normalized with ``triple``.

    Returns:
        A single batched ``SparseTensor``.

    Raises:
        ValueError: on empty input, mismatched batch sizes, or any shape,
            dtype, or channel-count violation described above.
    """
    if len(coords) != len(feats):
        raise ValueError('coords and feats batch sizes must match.')
    if not coords:
        raise ValueError('expected at least one sparse tensor batch.')
    coord_dtype = coords[0].dtype
    batched_coords = []
    batched_feats = []
    for batch, (coord_rows, feat_rows) in enumerate(
        zip(coords, feats, strict=True)
    ):
        if coord_rows.ndim != 2 or coord_rows.shape[1] != 3:
            raise ValueError('collated coords must have shape (N, 3).')
        if coord_rows.dtype not in (mx.int32, mx.int64):
            raise ValueError('collated coords must be int32 or int64.')
        if coord_rows.dtype != coord_dtype:
            raise ValueError('collated coords must share a dtype.')
        if feat_rows.ndim != 2:
            raise ValueError('collated feats must have shape (N, C).')
        if coord_rows.shape[0] != feat_rows.shape[0]:
            raise ValueError(
                'collated coords and feats must have matching rows.'
            )
        # feats[0] was validated as 2-D on the first iteration. Catching a
        # channel mismatch here names the problem instead of leaving it to
        # an opaque concatenate failure after all batches are processed.
        if feat_rows.shape[1] != feats[0].shape[1]:
            raise ValueError('collated feats must share a channel count.')
        batch_col = mx.full(
            (coord_rows.shape[0], 1), batch, dtype=coord_rows.dtype
        )
        batched_coords.append(
            mx.concatenate([batch_col, coord_rows], axis=1)
        )
        batched_feats.append(feat_rows)
    return SparseTensor(
        mx.concatenate(batched_coords, axis=0),
        mx.concatenate(batched_feats, axis=0),
        stride=triple(stride, name='stride'),
        batch_counts=tuple(int(values.shape[0]) for values in coords),
    )
def align_sparse(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'inner',
) -> SparseAlignment:
    """Compute coordinate-value alignment metadata for two sparse tensors.

    The operands must agree on sparse stride and coordinate dtype. The
    alignment's row maps let callers gather feature rows under ``inner``,
    ``left``, ``right``, or ``outer`` support semantics.

    Raises:
        ValueError: if the strides or coordinate dtypes differ.
    """
    _require_compatible_sparse_tensors(lhs, rhs)
    lhs_coords, lhs_active = lhs.coords, lhs.active_rows
    rhs_coords, rhs_active = rhs.coords, rhs.active_rows
    return build_sparse_alignment(
        lhs_coords, lhs_active, rhs_coords, rhs_active, join=join
    )
def gather_aligned_features(
    x: SparseTensor,
    rows: mx.array,
    *,
    fill: float = 0.0,
) -> mx.array:
    """Gather sparse features with ``-1`` rows filled by ``fill``.

    ``rows`` must be an ``int32`` vector. Non-negative entries gather the
    corresponding row of ``x.feats``. Negative entries produce ``fill``,
    cast to the feature dtype, in every channel.

    Returns:
        An array of shape ``(rows.shape[0], C)`` in ``x.feats``'s dtype.

    Raises:
        ValueError: if ``rows`` is not a 1-D ``int32`` array.
    """
    if rows.ndim != 1 or rows.dtype != mx.int32:
        raise ValueError('rows must have shape (N,) and int32 dtype.')
    # Clamp misses to row 0 so the gather stays in bounds; those rows are
    # replaced by the fill value below.
    gathered = mx.take(x.feats, mx.maximum(rows, 0), axis=0)
    valid = (rows >= 0)[:, None]
    # Select instead of multiplying by a 0/1 mask: ``inf * 0`` and
    # ``nan * 0`` are NaN, so a multiplicative mask would leak non-finite
    # values from row 0 into every missing row.
    fill_value = mx.array(float(fill), dtype=x.feats.dtype)
    return mx.where(valid, gathered, fill_value)
def sparse_binary_op(
    lhs: SparseTensor,
    rhs: SparseTensor,
    op: Literal['add', 'sub', 'mul', 'maximum', 'minimum'],
    *,
    join: SparseJoin = 'outer',
    lhs_fill: float = 0.0,
    rhs_fill: float = 0.0,
) -> SparseTensor:
    """Apply an elementwise binary ``op`` to coordinate-aligned operands.

    Both operands must share stride, coordinate dtype, and channel count.
    When they share coordinate identity (``lhs.same_coords(rhs)``), ``op``
    runs directly on the feature matrices and the result reuses ``lhs``'s
    coordinates. Otherwise rows are paired with ``align_sparse(...,
    join=join)``. Rows missing from one side are filled with ``lhs_fill`` /
    ``rhs_fill`` before ``op`` is applied, and the result takes the
    alignment's coordinates with ``lhs``'s stride and coordinate manager.

    Raises:
        ValueError: on incompatible operands, mismatched channels, or an
            unknown ``op``.
    """
    _require_compatible_sparse_tensors(lhs, rhs)
    if lhs.channels != rhs.channels:
        raise ValueError(
            'sparse binary operands must have matching channels.'
        )
    if lhs.same_coords(rhs):
        # Rows already correspond one-to-one; no alignment needed.
        combined = _apply_binary_op(lhs.feats, rhs.feats, op)
        return lhs.replace(feats=combined)
    alignment = align_sparse(lhs, rhs, join=join)
    aligned_lhs, aligned_rhs = (
        gather_aligned_features(lhs, alignment.lhs_rows, fill=lhs_fill),
        gather_aligned_features(rhs, alignment.rhs_rows, fill=rhs_fill),
    )
    return SparseTensor(
        alignment.coords,
        _apply_binary_op(aligned_lhs, aligned_rhs, op),
        lhs.stride,
        coord_manager=lhs.coord_manager,
        active_rows=alignment.active_rows,
    )
def sparse_add(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'outer',
) -> SparseTensor:
    """Return ``lhs + rhs`` after coordinate alignment.

    Rows absent from one operand under ``join`` are filled with ``0.0``
    (``sparse_binary_op``'s default fill).
    """
    return sparse_binary_op(lhs=lhs, rhs=rhs, op='add', join=join)
def sparse_sub(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'outer',
) -> SparseTensor:
    """Return ``lhs - rhs`` after coordinate alignment.

    Rows absent from one operand under ``join`` are filled with ``0.0``
    (``sparse_binary_op``'s default fill).
    """
    return sparse_binary_op(lhs=lhs, rhs=rhs, op='sub', join=join)
def sparse_mul(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'inner',
) -> SparseTensor:
    """Return ``lhs * rhs`` after coordinate alignment.

    Defaults to an ``inner`` join. Under other joins, rows absent from one
    operand are filled with ``0.0`` (``sparse_binary_op``'s default fill).
    """
    return sparse_binary_op(lhs=lhs, rhs=rhs, op='mul', join=join)
def sparse_maximum(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'inner',
    lhs_fill: float = 0.0,
    rhs_fill: float = 0.0,
) -> SparseTensor:
    """Take the elementwise maximum after coordinate alignment.

    ``lhs_fill`` / ``rhs_fill`` replace rows missing from the respective
    operand before the maximum is taken. They only matter for joins that
    keep rows present on one side only. The default ``0.0`` acts as a floor
    at zero for such rows. For float features, pass ``float('-inf')`` to
    keep the present operand's value unchanged.
    """
    return sparse_binary_op(
        lhs,
        rhs,
        'maximum',
        join=join,
        lhs_fill=lhs_fill,
        rhs_fill=rhs_fill,
    )
def sparse_minimum(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'inner',
    lhs_fill: float = 0.0,
    rhs_fill: float = 0.0,
) -> SparseTensor:
    """Take the elementwise minimum after coordinate alignment.

    ``lhs_fill`` / ``rhs_fill`` replace rows missing from the respective
    operand before the minimum is taken. They only matter for joins that
    keep rows present on one side only. The default ``0.0`` acts as a
    ceiling at zero for such rows. For float features, pass
    ``float('inf')`` to keep the present operand's value unchanged.
    """
    return sparse_binary_op(
        lhs,
        rhs,
        'minimum',
        join=join,
        lhs_fill=lhs_fill,
        rhs_fill=rhs_fill,
    )
def cat(
    tensors: Sequence[SparseTensor],
    *,
    join: SparseJoin = 'inner',
) -> SparseTensor:
    """Concatenate sparse features along channels, aligning when needed.

    Any number of tensors is accepted when all of them share the first
    tensor's coordinate identity. If they do not, exactly two tensors are
    required and they are joined through ``sparse_cat_aligned``, because
    the join support and fill behavior must be unambiguous.

    Raises:
        ValueError: if ``tensors`` is empty, or if value alignment is
            needed for anything other than two tensors.
    """
    if not tensors:
        raise ValueError('expected at least one sparse tensor.')
    head = tensors[0]
    others = tensors[1:]
    if all(head.same_coords(other) for other in others):
        stacked = mx.concatenate([item.feats for item in tensors], axis=1)
        return head.replace(feats=stacked)
    if len(tensors) != 2:
        raise ValueError(
            'value-aligned cat currently supports exactly two sparse tensors.'
        )
    return sparse_cat_aligned(head, tensors[1], join=join)
def sparse_cat_aligned(
    lhs: SparseTensor,
    rhs: SparseTensor,
    *,
    join: SparseJoin = 'inner',
) -> SparseTensor:
    """Concatenate two sparse tensors' features after value alignment.

    With shared coordinate identity the feature matrices are joined
    directly. Otherwise rows are paired by ``align_sparse(..., join=join)``,
    and rows missing from one side contribute zeros (the default fill of
    ``gather_aligned_features``).

    Raises:
        ValueError: if the strides or coordinate dtypes differ.
    """
    _require_compatible_sparse_tensors(lhs, rhs)
    if lhs.same_coords(rhs):
        joined = mx.concatenate([lhs.feats, rhs.feats], axis=1)
        return lhs.replace(feats=joined)
    alignment = align_sparse(lhs, rhs, join=join)
    pieces = [
        gather_aligned_features(side, side_rows)
        for side, side_rows in (
            (lhs, alignment.lhs_rows),
            (rhs, alignment.rhs_rows),
        )
    ]
    return SparseTensor(
        alignment.coords,
        mx.concatenate(pieces, axis=1),
        lhs.stride,
        coord_manager=lhs.coord_manager,
        active_rows=alignment.active_rows,
    )
def crop(
    x: SparseTensor,
    *,
    min_coord: Sequence[int],
    max_coord: Sequence[int],
) -> SparseTensor:
    """Crop a sparse tensor to an inclusive spatial coordinate box.

    ``min_coord`` and ``max_coord`` are ``(x, y, z)`` triples compared
    against coordinate columns 1-3. The batch column (column 0) is not
    filtered. Only the first ``x.active_rows[0]`` buffer rows are
    candidates, and survivors are compacted with ``prune_mask``.

    Raises:
        ValueError: if a bound does not have three values, or if any
            lower bound exceeds its upper bound.
    """
    lower = _spatial_bound(min_coord, 'min_coord')
    upper = _spatial_bound(max_coord, 'max_coord')
    bounds = list(zip(lower, upper, strict=True))
    if any(lo > hi for lo, hi in bounds):
        raise ValueError('min_coord must be <= max_coord.')
    row_ids = mx.arange(x.capacity, dtype=mx.int32)
    keep = row_ids < x.active_rows[0]
    # Column 0 is the batch index, so spatial axes start at column 1.
    for column_index, (lo, hi) in enumerate(bounds, start=1):
        column = x.coords[:, column_index]
        keep = keep & (column >= lo) & (column <= hi)
    return prune_mask(x, keep)
def replace_feature(x: SparseTensor, feats: mx.array) -> SparseTensor:
    """Swap in a new feature matrix while keeping ``x``'s coordinates.

    Delegates to ``SparseTensor.replace``. Under the ``SparseTensor`` row
    contract, ``feats`` must be two-dimensional ``(N, C_new)`` with ``N``
    equal to the row count of ``x.coords``.
    """
    updated = x.replace(feats=feats)
    return updated
def _apply_binary_op(lhs: mx.array, rhs: mx.array, op: str) -> mx.array: if op == 'add': return lhs + rhs if op == 'sub': return lhs - rhs if op == 'mul': return lhs * rhs if op == 'maximum': return mx.maximum(lhs, rhs) if op == 'minimum': return mx.minimum(lhs, rhs) raise ValueError(f'unknown sparse binary op: {op}.') def _require_compatible_sparse_tensors( lhs: SparseTensor, rhs: SparseTensor, ) -> None: if lhs.stride != rhs.stride: raise ValueError('sparse tensor strides must match.') if lhs.coords.dtype != rhs.coords.dtype: raise ValueError('sparse tensor coordinate dtypes must match.') def _spatial_bound(value: Sequence[int], name: str) -> tuple[int, int, int]: raw = tuple(int(item) for item in value) if len(raw) != 3: raise ValueError(f'{name} must contain exactly 3 values.') return (raw[0], raw[1], raw[2])
def prune(x: SparseTensor, rows: mx.array) -> SparseTensor:
    """Keep the specified sparse rows.

    ``rows`` is a one-dimensional integer index vector. Any integer dtype
    is accepted and cast to ``int32``. The result contains the gathered
    coordinates and features, shares ``x``'s stride and coordinate manager,
    and is built as a fresh ``SparseTensor``, so it receives a new
    coordinate key in the same manager. ``batch_counts`` and
    ``active_rows`` are not forwarded.

    Raises:
        ValueError: if ``rows`` is not 1-D or is not an integer dtype.
            Boolean masks belong in ``prune_mask``.
    """
    if rows.ndim != 1:
        raise ValueError('rows must have shape (M,).')
    # A boolean mask (or float array) would otherwise be cast to int32 and
    # silently reinterpreted as row indices such as 0/1.
    if not mx.issubdtype(rows.dtype, mx.integer):
        raise ValueError(
            'rows must be an integer index vector; '
            'use prune_mask for boolean masks.'
        )
    index = rows.astype(mx.int32)
    return SparseTensor(
        mx.take(x.coords, index, axis=0),
        mx.take(x.feats, index, axis=0),
        x.stride,
        coord_manager=x.coord_manager,
    )
def prune_mask(x: SparseTensor, mask: mx.array) -> SparseTensor:
    """Keep sparse rows selected by a boolean mask.

    ``mask`` must be a 1-D boolean array whose length equals the sparse
    buffer capacity. The selected rows are gathered, in ascending row
    order, into a compact sparse tensor via ``prune``.

    Raises:
        ValueError: if ``mask`` is not 1-D, does not match
            ``x.capacity``, or is not boolean.
    """
    if mask.ndim != 1:
        raise ValueError('mask must have shape (N,).')
    if mask.shape[0] != x.capacity:
        raise ValueError('mask must match the sparse tensor capacity.')
    if mask.dtype != mx.bool_:
        raise ValueError('mask must be boolean.')
    # Rank selected rows before unselected ones using unique sort keys, so
    # the kept rows come out in ascending index order whether or not
    # ``mx.argsort`` is stable.
    index = mx.arange(x.capacity, dtype=mx.int32)
    keys = mx.where(mask, index, index + x.capacity)
    ordering = mx.argsort(keys).astype(mx.int32)
    count = int(cast('int', mx.sum(mask).tolist()))
    return prune(x, ordering[:count])
def topk_rows(
    x: SparseTensor,
    counts: Sequence[int],
    *,
    rho: float = 1.0,
) -> mx.array:
    """Select the top-scoring rows of each batch, scored by channel 0.

    Batch ``b`` occupies the contiguous row range given by the cumulative
    ``x.batch_counts``. From that range, up to ``int(int(counts[b]) * rho)``
    rows with the largest ``x.feats[:, 0]`` are kept, clamped to the batch
    size. Batches whose budget is ``<= 0`` contribute nothing.

    Returns:
        An ``int32`` vector of selected row indices, grouped batch by
        batch, with each batch's rows in ascending score order.

    Raises:
        ValueError: if ``rho <= 0``, ``batch_counts`` is missing, ``counts``
            has the wrong length, or the batch counts do not exactly span
            ``x.capacity``.
    """
    if rho <= 0:
        raise ValueError('rho must be positive.')
    batch_sizes = x.batch_counts
    if batch_sizes is None:
        raise ValueError('batch_counts metadata is required for topk_rows.')
    if len(counts) != len(batch_sizes):
        raise ValueError('counts must match the batch count.')
    picks = []
    offset = 0
    for requested, size in zip(counts, batch_sizes, strict=True):
        end = offset + int(size)
        if end > x.capacity:
            raise ValueError(
                'batch row counts exceed sparse tensor row count.'
            )
        window = mx.arange(offset, end, dtype=mx.int32)
        offset = end
        budget = min(int(int(requested) * rho), int(window.shape[0]))
        if budget > 0:
            window_scores = mx.take(x.feats[:, 0], window, axis=0)
            ranking = mx.argsort(window_scores)
            picks.append(mx.take(window, ranking[-budget:], axis=0))
    if offset != x.capacity:
        raise ValueError('counts must cover all sparse tensor rows.')
    if picks:
        return mx.concatenate(picks, axis=0)
    return mx.array([], dtype=mx.int32)