Source code for mlx_lattice.core.coords.ordering

from __future__ import annotations

from dataclasses import dataclass

import mlx.core as mx

from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coords


@dataclass(frozen=True, slots=True)
class CoordinateOrdering:
    """Sorted coordinate rows bundled with the row permutation behind them.

    All three fields are checked on construction: ``coords`` is passed to
    ``validate_coords``, and ``order`` and ``inverse_rows`` must each be a
    one-dimensional ``int32`` array with one entry per row of ``coords``.

    Raises:
        ValueError: If ``order`` or ``inverse_rows`` has the wrong rank,
            dtype, or length. ``validate_coords`` may raise on its own
            terms for ``coords``.
    """

    coords: mx.array
    order: mx.array
    inverse_rows: mx.array

    def __post_init__(self) -> None:
        validate_coords(self.coords)

        # Both index arrays obey the same contract; check them in field
        # order so the first offending field is the one reported.
        checks = (
            (
                self.order,
                'order must have shape (N,) and int32 dtype.',
            ),
            (
                self.inverse_rows,
                'inverse_rows must have shape (N,) and int32 dtype.',
            ),
        )
        for index_array, message in checks:
            # The rank test comes first so shape[0] is only read on 1-D
            # arrays.
            malformed = (
                index_array.ndim != 1
                or index_array.dtype != mx.int32
                or index_array.shape[0] != self.coords.shape[0]
            )
            if malformed:
                raise ValueError(message)
def morton_codes(coords: mx.array) -> mx.array:
    """Return Morton/Z-order codes for batched sparse coordinates.

    The coordinates are validated with ``validate_coords`` and handed to
    the native ``ext.morton_codes`` kernel as ``int32``; any other dtype
    is cast to ``int32`` first.
    """
    validate_coords(coords)
    # Skip the cast when the input is already int32.
    int32_coords = (
        coords if coords.dtype == mx.int32 else coords.astype(mx.int32)
    )
    return ext.morton_codes(int32_coords)
def morton_order(coords: mx.array) -> mx.array:
    """Return the ``int32`` row indices that sort coordinates by Morton code.

    The codes come from ``morton_codes``, so the same validation applies
    to ``coords``.
    """
    codes = morton_codes(coords)
    ranking = mx.argsort(codes)
    # Cast explicitly so the result dtype is always int32.
    return ranking.astype(mx.int32)
def morton_sort_coords(coords: mx.array) -> CoordinateOrdering:
    """Sort coordinate rows by Morton order and return the full ordering.

    Returns:
        A ``CoordinateOrdering`` with the rows of ``coords`` gathered in
        Morton order, the ``int32`` permutation from ``morton_order``,
        and the ``int32`` argsort of that permutation.
    """
    permutation = morton_order(coords)
    # Argsorting a permutation yields its inverse.
    inverse_permutation = mx.argsort(permutation).astype(mx.int32)
    reordered = mx.take(coords, permutation, axis=0)
    return CoordinateOrdering(
        coords=reordered,
        order=permutation,
        inverse_rows=inverse_permutation,
    )