Source code for mlx_lattice.core.coords.set_ops
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
import mlx.core as mx
from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import (
validate_coord_pair,
validate_coords,
)
from mlx_lattice.core.types import Triple, triple
[docs]
@dataclass(frozen=True, slots=True)
class CoordinateSet:
    """Fixed-capacity coordinate buffer paired with an unevaluated row count.

    ``coords`` is the whole allocated buffer and is checked with
    ``validate_coords``. ``active_rows`` carries the active row count. It is
    kept as an ``mx.array`` (the "lazy" count) instead of being converted to
    a Python int.
    """

    # Full coordinate buffer; its first dimension is the capacity.
    coords: mx.array
    # Active row count; must be a shape-(1,) int32 array (enforced below).
    active_rows: mx.array

    def __post_init__(self) -> None:
        """Validate ``coords`` and the shape/dtype of ``active_rows``.

        Raises:
            ValueError: If ``active_rows`` is not a shape ``(1,)`` int32
                array. ``validate_coords`` may raise for a bad ``coords``.
        """
        validate_coords(self.coords)
        counter = self.active_rows
        if counter.shape != (1,) or counter.dtype != mx.int32:
            raise ValueError(
                'active_rows must have shape (1,) and int32 dtype.'
            )

    @property
    def capacity(self) -> int:
        """Number of rows allocated in ``coords``, as a Python int."""
        return int(self.coords.shape[0])

    @property
    def active_count(self) -> mx.array:
        """The active row count, returned as the stored ``mx.array``."""
        return self.active_rows
[docs]
def downsample_coords(
    coords: mx.array,
    stride: int | Sequence[int] = 2,
) -> CoordinateSet:
    """Coarsen ``coords`` by an integer spatial stride.

    Args:
        coords: Coordinate array; checked with ``validate_coords``.
        stride: Either a single int or a sequence of ints. It is normalised
            to a ``Triple`` by ``triple`` before use.

    Returns:
        A ``CoordinateSet`` built from the result of the native
        ``ext.downsample_coords`` call.

    Raises:
        ValueError: If any normalised stride component is not positive.
    """
    validate_coords(coords)
    factors = triple(stride, name='stride')
    # Reject zero or negative strides before handing off to native code.
    _require_positive(factors, 'stride')
    downsampled = ext.downsample_coords(coords, factors)
    return CoordinateSet(*downsampled)
[docs]
def union_coords(lhs: mx.array, rhs: mx.array) -> CoordinateSet:
    """Merge two coordinate arrays into their set union.

    Both operands are checked together with ``validate_coord_pair`` before
    the native ``ext.union_coords`` routine runs.

    Returns:
        A ``CoordinateSet`` wrapping the native routine's result.
    """
    validate_coord_pair(lhs, rhs)
    merged = ext.union_coords(lhs, rhs)
    return CoordinateSet(*merged)
[docs]
def intersection_coords(lhs: mx.array, rhs: mx.array) -> CoordinateSet:
    """Keep only the coordinates that occur in both ``lhs`` and ``rhs``.

    Both operands are checked together with ``validate_coord_pair`` before
    the native ``ext.intersection_coords`` routine runs.

    Returns:
        A ``CoordinateSet`` wrapping the native routine's result.
    """
    validate_coord_pair(lhs, rhs)
    shared = ext.intersection_coords(lhs, rhs)
    return CoordinateSet(*shared)
[docs]
def lookup_coords(coords: mx.array, queries: mx.array) -> mx.array:
    """Find the row of ``coords`` that matches each query coordinate.

    Args:
        coords: Reference coordinate array to search.
        queries: Coordinates to locate. They are validated against
            ``coords`` with ``validate_coord_pair`` (``rhs_name='queries'`` is
            passed, presumably so errors name this argument).

    Returns:
        Row indices into ``coords``, one per query. Queries with no
        matching row are encoded as ``-1``.
    """
    validate_coord_pair(coords, queries, rhs_name='queries')
    row_indices = ext.lookup_coords(coords, queries)
    return row_indices
[docs]
def contains_coords(coords: mx.array, queries: mx.array) -> mx.array:
    """Report, per query, whether that coordinate is present in ``coords``.

    The mask is derived from ``lookup_coords``. It is ``True`` where a
    matching row index (``>= 0``) was found and ``False`` where the ``-1``
    missing sentinel was returned.
    """
    row_index = lookup_coords(coords, queries)
    return row_index >= 0
[docs]
def inverse_map(source: mx.array, target: mx.array) -> mx.array:
    """Map every ``target`` row to the ``source`` row holding that coordinate.

    This is exactly ``lookup_coords(source, target)``. The result is meant
    for gathering ``target`` rows out of ``source``.

    NOTE(review): ``target`` rows absent from ``source`` come back as ``-1``.
    Those entries presumably need masking (e.g. via ``contains_coords``)
    before the result is used as a gather index.
    """
    gather_index = lookup_coords(source, target)
    return gather_index
def _require_positive(values: Triple, name: str) -> None:
    """Raise ``ValueError`` unless every component of ``values`` is > 0.

    ``name`` is the parameter label interpolated into the error message.
    Checking stops at the first non-positive component.
    """
    for component in values:
        if component <= 0:
            raise ValueError(f'{name} values must be positive.')