Source code for mlx_lattice.core.coords.alignment
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
import mlx.core as mx
from mlx_lattice._native import ext
from mlx_lattice.core.coords.validation import validate_coord_pair
type SparseJoin = Literal['inner', 'left', 'right', 'outer']
[docs]
@dataclass(frozen=True, slots=True)
class SparseAlignment:
"""Coordinate value alignment for two sparse tensors.
``coords`` is the joined coordinate support. ``lhs_rows`` and ``rhs_rows``
gather features from the left and right tensors into that support; ``-1``
marks a missing value for the selected join mode.
"""
coords: mx.array
active_rows: mx.array
lhs_rows: mx.array
rhs_rows: mx.array
def __post_init__(self) -> None:
if self.coords.ndim != 2 or self.coords.shape[1] != 4:
raise ValueError('coords must have shape (N, 4).')
if self.coords.dtype not in (mx.int32, mx.int64):
raise ValueError('coords must be int32 or int64.')
if (
self.active_rows.shape != (1,)
or self.active_rows.dtype != mx.int32
):
raise ValueError(
'active_rows must have shape (1,) and int32 dtype.'
)
if self.lhs_rows.shape != (self.coords.shape[0],):
raise ValueError('lhs_rows must have shape (N,).')
if self.rhs_rows.shape != (self.coords.shape[0],):
raise ValueError('rhs_rows must have shape (N,).')
if (
self.lhs_rows.dtype != mx.int32
or self.rhs_rows.dtype != mx.int32
):
raise ValueError('alignment rows must be int32.')
@property
def capacity(self) -> int:
return int(self.coords.shape[0])
@property
def active_count(self) -> mx.array:
return self.active_rows
[docs]
def build_sparse_alignment(
lhs_coords: mx.array,
lhs_active_rows: mx.array,
rhs_coords: mx.array,
rhs_active_rows: mx.array,
*,
join: SparseJoin = 'inner',
) -> SparseAlignment:
"""Build value-aligned row maps between two coordinate arrays.
Coordinates must have shape ``(N, 4)`` and matching dtype. Active-row
scalars describe the valid prefix of each coordinate buffer. The native
builder returns joined coordinates and row maps for ``inner``, ``left``,
``right``, or ``outer`` support.
"""
validate_coord_pair(lhs_coords, rhs_coords)
if lhs_coords.dtype != mx.int32:
raise ValueError(
'sparse alignment currently requires int32 coords.'
)
_validate_active_rows(lhs_active_rows, 'lhs_active_rows')
_validate_active_rows(rhs_active_rows, 'rhs_active_rows')
return SparseAlignment(
*ext.build_sparse_alignment(
lhs_coords,
lhs_active_rows,
rhs_coords,
rhs_active_rows,
_validate_join(join),
)
)
def _validate_join(value: str) -> SparseJoin:
if value == 'inner':
return 'inner'
if value == 'left':
return 'left'
if value == 'right':
return 'right'
if value == 'outer':
return 'outer'
raise ValueError("join must be 'inner', 'left', 'right', or 'outer'.")
def _validate_active_rows(value: mx.array, name: str) -> None:
if value.shape != (1,) or value.dtype != mx.int32:
raise ValueError(f'{name} must have shape (1,) and int32 dtype.')