Source code for mlx_lattice.core.coords.ordering
from __future__ import annotations
from dataclasses import dataclass
import mlx.core as mx
from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coords
[docs]
@dataclass(frozen=True, slots=True)
class CoordinateOrdering:
    """Coordinate rows in sorted order plus the row permutations around them.

    Construction validates ``coords`` with ``validate_coords`` and checks that
    ``order`` and ``inverse_rows`` are 1-D int32 arrays with one entry per
    coordinate row. Only shape and dtype are checked; whether the two index
    vectors really are mutually inverse permutations is not verified here.

    Raises:
        ValueError: If ``order`` or ``inverse_rows`` is not a 1-D int32 array
            whose length matches ``coords.shape[0]`` (plus whatever
            ``validate_coords`` raises for malformed ``coords``).
    """

    # Sorted coordinate rows.
    coords: mx.array
    # 1-D int32 row permutation that produced ``coords`` from the unsorted rows.
    order: mx.array
    # 1-D int32 index vector, one entry per row; presumably the inverse
    # permutation of ``order`` (not checked here).
    inverse_rows: mx.array

    def __post_init__(self) -> None:
        validate_coords(self.coords)
        # Each index vector is paired with its exact error message; checked in order.
        index_vectors = (
            (self.order, 'order must have shape (N,) and int32 dtype.'),
            (
                self.inverse_rows,
                'inverse_rows must have shape (N,) and int32 dtype.',
            ),
        )
        for vector, message in index_vectors:
            # Short-circuits so ``shape[0]`` is only read on a 1-D array.
            well_formed = (
                vector.ndim == 1
                and vector.dtype == mx.int32
                and vector.shape[0] == self.coords.shape[0]
            )
            if not well_formed:
                raise ValueError(message)
[docs]
def morton_codes(coords: mx.array) -> mx.array:
    """Compute Morton (Z-order) codes for batched sparse coordinates.

    The input is validated with ``validate_coords`` and then passed to the
    native ``ext.morton_codes`` kernel, which always receives int32 input.

    Args:
        coords: Coordinate array accepted by ``validate_coords``.

    Returns:
        The result of ``ext.morton_codes`` on the int32 coordinates
        (presumably one code per row; TODO confirm against the native kernel).
    """
    validate_coords(coords)
    # Skip the astype when the dtype already matches.
    int32_coords = (
        coords if coords.dtype == mx.int32 else coords.astype(mx.int32)
    )
    return ext.morton_codes(int32_coords)
[docs]
def morton_order(coords: mx.array) -> mx.array:
    """Compute the row permutation that sorts ``coords`` by Morton code.

    Args:
        coords: Coordinate array accepted by ``morton_codes``.

    Returns:
        int32 row indices such that ``mx.take(coords, result, axis=0)``
        lists the rows in ascending Morton-code order.
    """
    codes = morton_codes(coords)
    permutation = mx.argsort(codes)
    # Normalize to int32 regardless of the index dtype argsort produces.
    return permutation.astype(mx.int32)
[docs]
def morton_sort_coords(coords: mx.array) -> CoordinateOrdering:
    """Sort coordinate rows by Morton order.

    Args:
        coords: Coordinate array accepted by ``morton_order``.

    Returns:
        A ``CoordinateOrdering`` whose ``coords`` are the rows reordered as
        ``mx.take(coords, order, axis=0)``, whose ``order`` is that int32
        permutation, and whose ``inverse_rows`` is the int32 inverse
        permutation (original row ``r`` lands at sorted position
        ``inverse_rows[r]``).
    """
    order = morton_order(coords)
    # ``order`` is a permutation of 0..N-1, so argsorting it yields its
    # inverse: inverse_rows[order[i]] == i.
    inverse_rows = mx.argsort(order).astype(mx.int32)
    return CoordinateOrdering(
        coords=mx.take(coords, order, axis=0),
        order=order,
        inverse_rows=inverse_rows,
    )