from __future__ import annotations
from dataclasses import dataclass
from typing import Literal
import mlx.core as mx
from mlx_lattice.core.types import Triple
RelationKind = Literal[
'forward',
'target',
'submanifold',
'transposed',
'generative',
]
[docs]
@dataclass(frozen=True, slots=True)
class RelationEdges:
"""Diagnostic edge arrays for a logical sparse kernel relation.
Each edge stores ``(in_row, out_row, kernel_id)`` at the same index in the
three arrays. Edge arrays are useful for debugging and for generic sparse
traversal routes; CSR views provide grouped execution order.
"""
in_rows: mx.array
out_rows: mx.array
kernel_ids: mx.array
def __post_init__(self) -> None:
_validate_row_array(self.in_rows, name='in_rows')
_validate_row_array(self.out_rows, name='out_rows')
_validate_row_array(self.kernel_ids, name='kernel_ids')
_require_same_rows(
self.in_rows,
self.out_rows,
self.kernel_ids,
names=('in_rows', 'out_rows', 'kernel_ids'),
)
@property
def capacity(self) -> int:
return int(self.in_rows.shape[0])
[docs]
@dataclass(frozen=True, slots=True)
class RelationCSRView:
"""CSR execution view over relation edge arrays.
``row_offsets`` has shape ``(rows + 1,)`` and ``int32`` dtype. Optional
``edge_ids`` remaps CSR order back to the canonical edge arrays.
"""
row_offsets: mx.array
edge_ids: mx.array | None = None
def __post_init__(self) -> None:
_validate_row_offsets(self.row_offsets)
if self.edge_ids is not None:
_validate_row_array(self.edge_ids, name='edge_ids')
RelationView = RelationCSRView
[docs]
@dataclass(frozen=True, slots=True)
class RelationImplicitGemmView:
"""Dense output-row by kernel-offset map for implicit-GEMM execution.
``out_in_map[o, k]`` stores the input row that contributes to output row
``o`` at kernel offset ``k``. Missing contributors are represented by the
native builder's sentinel value. ``row_masks`` stores bit masks describing
populated kernel offsets for each output row.
"""
out_in_map: mx.array
row_masks: mx.array
def __post_init__(self) -> None:
if self.out_in_map.ndim != 2:
raise ValueError('out_in_map must have shape (Nout, K).')
if self.out_in_map.dtype not in (mx.int32, mx.int64):
raise ValueError('out_in_map must be int32 or int64.')
if self.row_masks.ndim != 2:
raise ValueError('row_masks must have shape (Nout, M).')
if self.row_masks.dtype not in (mx.int32, mx.int64):
raise ValueError('row_masks must be int32 or int64.')
if int(self.row_masks.shape[0]) != int(self.out_in_map.shape[0]):
raise ValueError(
'row_masks must have one row per out_in_map row.'
)
expected_mask_words = (int(self.out_in_map.shape[1]) + 31) // 32
if int(self.row_masks.shape[1]) != expected_mask_words:
raise ValueError(
'row_masks must have ceil(K / 32) words per output row.'
)
[docs]
@dataclass(frozen=True, slots=True)
class RelationSortedImplicitGemmView:
"""Tile-sorted implicit-GEMM execution view for TensorOps kernels."""
sorted_out_in_map: mx.array
sorted_kv_out_in_map: mx.array
reorder_rows: mx.array
tile_masks: mx.array
def __post_init__(self) -> None:
if self.sorted_out_in_map.ndim != 2:
raise ValueError('sorted_out_in_map must have shape (Nout, K).')
if self.sorted_out_in_map.dtype not in (mx.int32, mx.int64):
raise ValueError('sorted_out_in_map must be int32 or int64.')
if self.sorted_kv_out_in_map.ndim != 2:
raise ValueError(
'sorted_kv_out_in_map must have shape (K, Nout).'
)
if self.sorted_kv_out_in_map.dtype not in (mx.int32, mx.int64):
raise ValueError('sorted_kv_out_in_map must be int32 or int64.')
if int(self.sorted_kv_out_in_map.shape[0]) != int(
self.sorted_out_in_map.shape[1]
) or int(self.sorted_kv_out_in_map.shape[1]) != int(
self.sorted_out_in_map.shape[0]
):
raise ValueError(
'sorted_kv_out_in_map must be the K-major view of '
'sorted_out_in_map.'
)
_validate_row_array(self.reorder_rows, name='reorder_rows')
_validate_row_array(self.tile_masks, name='tile_masks')
if int(self.reorder_rows.shape[0]) != int(
self.sorted_out_in_map.shape[0]
):
raise ValueError(
'reorder_rows must have one row per sorted_out_in_map row.'
)
[docs]
@dataclass(frozen=True, slots=True)
class SparseRelationContract:
"""Logical sparse relation contract shared by all execution views.
The contract records counts, capacities, optional source/target coordinate
buffers, kernel offsets, stride, padding, and relation kind. Backend routes
use this metadata to validate specialized kernels before consuming CSR or
implicit-GEMM views.
"""
counts: mx.array
kernel_offsets: tuple[Triple, ...]
out_coords: mx.array | None = None
n_in_capacity: int | None = None
n_out_capacity: int | None = None
n_kernels: int | None = None
source_coords: mx.array | None = None
source_active_rows: mx.array | None = None
target_coords: mx.array | None = None
target_active_rows: mx.array | None = None
stride: Triple = (1, 1, 1)
padding: Triple = (0, 0, 0)
kind: RelationKind = 'forward'
def __post_init__(self) -> None:
_validate_counts(self.counts)
if self.out_coords is not None:
_validate_coords(self.out_coords, name='out_coords')
if self.source_coords is not None:
_validate_coords(self.source_coords, name='source_coords')
if self.target_coords is not None:
_validate_coords(self.target_coords, name='target_coords')
if self.source_active_rows is not None:
_validate_active_rows(
self.source_active_rows, name='source_active_rows'
)
if self.target_active_rows is not None:
_validate_active_rows(
self.target_active_rows, name='target_active_rows'
)
if self.kind not in (
'forward',
'target',
'submanifold',
'transposed',
'generative',
):
raise ValueError(
"relation kind must be 'forward', 'target', "
"'submanifold', 'transposed', or 'generative'."
)
normalized_offsets = tuple(
(int(x), int(y), int(z)) for x, y, z in self.kernel_offsets
)
object.__setattr__(self, 'kernel_offsets', normalized_offsets)
object.__setattr__(
self, 'stride', tuple(int(value) for value in self.stride)
)
object.__setattr__(
self, 'padding', tuple(int(value) for value in self.padding)
)
n_in = _optional_count(self.n_in_capacity, 'n_in_capacity')
n_out = _optional_count(self.n_out_capacity, 'n_out_capacity')
n_kernels = _optional_count(self.n_kernels, 'n_kernels')
if normalized_offsets:
if n_kernels is not None and n_kernels != len(
normalized_offsets
):
raise ValueError('n_kernels must match kernel_offsets.')
n_kernels = len(normalized_offsets)
if self.out_coords is not None:
out_capacity = int(self.out_coords.shape[0])
if n_out is not None and n_out != out_capacity:
raise ValueError(
'n_out_capacity must match out_coords capacity.'
)
n_out = out_capacity
object.__setattr__(self, 'n_in_capacity', n_in)
object.__setattr__(self, 'n_out_capacity', n_out)
object.__setattr__(self, 'n_kernels', n_kernels)
@property
def edge_count(self) -> mx.array:
return self.counts[:1]
@property
def out_count(self) -> mx.array:
return self.counts[1:2]
[docs]
@dataclass(frozen=True, slots=True)
class NeighborEdges:
"""Semantic neighbor edge arrays for query/source relations.
Each edge stores the query row, source row, and neighbor rank or identifier
for geometric neighbor queries such as kNN and radius search.
"""
query_rows: mx.array
source_rows: mx.array
neighbor_ids: mx.array
def __post_init__(self) -> None:
_validate_row_array(self.query_rows, name='query_rows')
_validate_row_array(self.source_rows, name='source_rows')
_validate_row_array(self.neighbor_ids, name='neighbor_ids')
_require_same_rows(
self.query_rows,
self.source_rows,
self.neighbor_ids,
names=('query_rows', 'source_rows', 'neighbor_ids'),
)
@property
def capacity(self) -> int:
return int(self.query_rows.shape[0])
[docs]
@dataclass(frozen=True, slots=True, init=False)
class KernelRelation:
"""Sparse kernel relation with multiple execution-oriented views.
A relation stores semantic edges ``(in_row, out_row, kernel_id)`` plus CSR
views grouped by output row, input row, and kernel id. Forward convolution,
pooling, and backend-specific fast paths consume these views.
"""
contract: SparseRelationContract
edges: RelationEdges
output_csr: RelationCSRView
input_csr: RelationCSRView
kernel_csr: RelationCSRView
implicit_gemm: RelationImplicitGemmView | None = None
sorted_implicit_gemm: RelationSortedImplicitGemmView | None = None
def __init__(
self,
in_rows: mx.array,
out_rows: mx.array,
kernel_ids: mx.array,
*,
row_offsets: mx.array | None = None,
counts: mx.array | None = None,
in_row_offsets: mx.array | None = None,
in_edge_ids: mx.array | None = None,
kernel_row_offsets: mx.array | None = None,
kernel_edge_ids: mx.array | None = None,
kernel_offsets: tuple[Triple, ...] = (),
out_coords: mx.array | None = None,
n_in_capacity: int | None = None,
n_out_capacity: int | None = None,
n_kernels: int | None = None,
source_coords: mx.array | None = None,
source_active_rows: mx.array | None = None,
target_coords: mx.array | None = None,
target_active_rows: mx.array | None = None,
stride: Triple = (1, 1, 1),
padding: Triple = (0, 0, 0),
kind: RelationKind = 'forward',
implicit_gemm: RelationImplicitGemmView | None = None,
sorted_implicit_gemm: RelationSortedImplicitGemmView | None = None,
) -> None:
if counts is None:
counts = mx.array(
[
in_rows.shape[0],
0 if out_coords is None else out_coords.shape[0],
],
dtype=mx.int32,
)
_validate_counts(counts)
contract = SparseRelationContract(
counts=counts,
kernel_offsets=kernel_offsets,
out_coords=out_coords,
n_in_capacity=n_in_capacity,
n_out_capacity=n_out_capacity,
n_kernels=n_kernels,
source_coords=source_coords,
source_active_rows=source_active_rows,
target_coords=target_coords,
target_active_rows=target_active_rows,
stride=stride,
padding=padding,
kind=kind,
)
if row_offsets is None:
out_capacity = (
0
if contract.n_out_capacity is None
else int(contract.n_out_capacity)
)
row_offsets = mx.array([0] * (out_capacity + 1), dtype=mx.int32)
_validate_row_offsets(row_offsets)
edges = RelationEdges(in_rows, out_rows, kernel_ids)
if in_row_offsets is None:
in_capacity = (
0
if contract.n_in_capacity is None
else int(contract.n_in_capacity)
)
in_row_offsets = mx.array(
[0] * (in_capacity + 1), dtype=mx.int32
)
if in_edge_ids is None:
in_edge_ids = mx.array([0] * edges.capacity, dtype=mx.int32)
if kernel_row_offsets is None:
kernel_capacity = (
0 if contract.n_kernels is None else int(contract.n_kernels)
)
kernel_row_offsets = mx.array(
[0] * (kernel_capacity + 1), dtype=mx.int32
)
if kernel_edge_ids is None:
kernel_edge_ids = mx.array([0] * edges.capacity, dtype=mx.int32)
output_csr = RelationCSRView(row_offsets)
input_csr = RelationCSRView(in_row_offsets, in_edge_ids)
kernel_csr = RelationCSRView(kernel_row_offsets, kernel_edge_ids)
if (
contract.n_out_capacity is not None
and int(row_offsets.shape[0])
!= int(contract.n_out_capacity) + 1
):
raise ValueError(
'row_offsets must have length n_out_capacity + 1.'
)
object.__setattr__(self, 'contract', contract)
object.__setattr__(self, 'edges', edges)
object.__setattr__(self, 'output_csr', output_csr)
object.__setattr__(self, 'input_csr', input_csr)
object.__setattr__(self, 'kernel_csr', kernel_csr)
object.__setattr__(self, 'implicit_gemm', implicit_gemm)
object.__setattr__(
self, 'sorted_implicit_gemm', sorted_implicit_gemm
)
@property
def edge_capacity(self) -> int:
"""Static edge-buffer capacity."""
return self.edges.capacity
@property
def counts(self) -> mx.array:
"""Native counts array ``[edge_count, out_count]``."""
return self.contract.counts
@property
def edge_count(self) -> mx.array:
return self.contract.edge_count
@property
def out_count(self) -> mx.array:
return self.contract.out_count
@property
def kernel_offsets(self) -> tuple[Triple, ...]:
return self.contract.kernel_offsets
@property
def out_coords(self) -> mx.array | None:
return self.contract.out_coords
@property
def n_in_capacity(self) -> int | None:
return self.contract.n_in_capacity
@property
def n_out_capacity(self) -> int | None:
return self.contract.n_out_capacity
@property
def n_kernels(self) -> int | None:
return self.contract.n_kernels
@property
def row_offsets(self) -> mx.array:
return self.output_csr.row_offsets
@property
def in_row_offsets(self) -> mx.array:
return self.input_csr.row_offsets
@property
def in_edge_ids(self) -> mx.array:
edge_ids = self.input_csr.edge_ids
if edge_ids is None:
raise ValueError('input CSR view is missing edge ids.')
return edge_ids
@property
def kernel_row_offsets(self) -> mx.array:
return self.kernel_csr.row_offsets
@property
def kernel_edge_ids(self) -> mx.array:
edge_ids = self.kernel_csr.edge_ids
if edge_ids is None:
raise ValueError('kernel CSR view is missing edge ids.')
return edge_ids
@property
def out_view(self) -> RelationCSRView:
return self.output_csr
@property
def in_view(self) -> RelationCSRView:
return self.input_csr
@property
def kernel_view(self) -> RelationCSRView:
return self.kernel_csr
[docs]
def require_implicit_gemm(self) -> RelationImplicitGemmView:
"""Return or lazily build the implicit-GEMM execution view."""
view = self.implicit_gemm
if view is not None:
return view
from mlx_lattice.core.coords.builders import (
build_relation_implicit_gemm_view,
)
view = build_relation_implicit_gemm_view(self)
object.__setattr__(self, 'implicit_gemm', view)
return view
[docs]
def require_sorted_implicit_gemm(
self,
) -> RelationSortedImplicitGemmView:
"""Return or lazily build the tile-sorted implicit-GEMM view."""
sorted_view = self.sorted_implicit_gemm
if sorted_view is not None:
return sorted_view
view = self.require_implicit_gemm()
if view.row_masks.shape[1] != 1:
raise ValueError(
'sorted implicit GEMM view currently supports K <= 32.'
)
row_masks = view.row_masks[:, 0]
sorted_rows = mx.argsort(row_masks).astype(mx.int32)
sorted_out_in_map = view.out_in_map[sorted_rows]
sorted_kv_out_in_map = mx.contiguous(
mx.transpose(sorted_out_in_map)
)
tile_count = (int(view.out_in_map.shape[0]) + 63) // 64
words = []
for word in range(4):
word_masks = mx.zeros((tile_count,), dtype=mx.int32)
for row_offset in range(word * 16, (word + 1) * 16):
rows = (
mx.arange(tile_count, dtype=mx.int32) * 64 + row_offset
)
valid = rows < int(view.out_in_map.shape[0])
clipped = mx.minimum(
rows, int(view.out_in_map.shape[0]) - 1
)
masks = mx.where(valid, row_masks[sorted_rows[clipped]], 0)
word_masks = mx.bitwise_or(word_masks, masks)
words.append(word_masks)
tile_masks = mx.reshape(mx.stack(words, axis=1), (-1,))
sorted_view = RelationSortedImplicitGemmView(
sorted_out_in_map=sorted_out_in_map,
sorted_kv_out_in_map=sorted_kv_out_in_map,
reorder_rows=sorted_rows,
tile_masks=tile_masks,
)
object.__setattr__(self, 'sorted_implicit_gemm', sorted_view)
return sorted_view
[docs]
@dataclass(frozen=True, slots=True, init=False)
class NeighborRelation:
"""Sparse query/source neighbor relation with row-offset metadata.
Neighbor relations are not convolution kernel relations. They describe
geometric connectivity between query rows and source rows, carry distances
in edge order, and expose row offsets grouped by query row.
"""
edges: NeighborEdges
distances: mx.array
row_offsets: mx.array
counts: mx.array
n_query_capacity: int | None = None
n_source_capacity: int | None = None
max_neighbors: int | None = None
def __init__(
self,
query_rows: mx.array,
source_rows: mx.array,
neighbor_ids: mx.array,
distances: mx.array,
*,
row_offsets: mx.array | None = None,
counts: mx.array | None = None,
n_query_capacity: int | None = None,
n_source_capacity: int | None = None,
max_neighbors: int | None = None,
) -> None:
edges = NeighborEdges(query_rows, source_rows, neighbor_ids)
_validate_distance_array(distances)
_require_same_rows(
query_rows,
distances,
names=('query_rows', 'distances'),
)
if counts is None:
counts = mx.array([query_rows.shape[0], 0], dtype=mx.int32)
_validate_counts(counts)
if row_offsets is None:
query_capacity = (
0 if n_query_capacity is None else int(n_query_capacity)
)
row_offsets = mx.array(
[0] * (query_capacity + 1), dtype=mx.int32
)
_validate_row_offsets(row_offsets)
if n_query_capacity is not None and int(row_offsets.shape[0]) != (
int(n_query_capacity) + 1
):
raise ValueError(
'row_offsets must have length n_query_capacity + 1.'
)
object.__setattr__(self, 'edges', edges)
object.__setattr__(self, 'distances', distances)
object.__setattr__(self, 'row_offsets', row_offsets)
object.__setattr__(self, 'counts', counts)
object.__setattr__(
self,
'n_query_capacity',
_optional_count(n_query_capacity, 'n_query_capacity'),
)
object.__setattr__(
self,
'n_source_capacity',
_optional_count(n_source_capacity, 'n_source_capacity'),
)
object.__setattr__(
self,
'max_neighbors',
_optional_count(max_neighbors, 'max_neighbors'),
)
@property
def edge_capacity(self) -> int:
return self.edges.capacity
@property
def edge_count(self) -> mx.array:
return self.counts[:1]
@property
def query_count(self) -> mx.array:
return self.counts[1:2]
# MARK: - helpers
def _validate_row_array(value: mx.array, *, name: str) -> None:
if value.ndim != 1:
raise ValueError(f'{name} must have shape (E,).')
if value.dtype not in (mx.int32, mx.int64):
raise ValueError(f'{name} must be int32 or int64.')
def _validate_coords(value: mx.array, *, name: str) -> None:
if value.ndim != 2 or value.shape[1] != 4:
raise ValueError(f'{name} must have shape (N, 4).')
if value.dtype not in (mx.int32, mx.int64):
raise ValueError(f'{name} must be int32 or int64.')
def _validate_counts(value: mx.array) -> None:
if value.shape != (2,) or value.dtype != mx.int32:
raise ValueError(
'relation counts must have shape (2,) and int32 dtype.'
)
def _validate_active_rows(value: mx.array, *, name: str) -> None:
if value.shape != (1,) or value.dtype != mx.int32:
raise ValueError(f'{name} must have shape (1,) and int32 dtype.')
def _validate_row_offsets(value: mx.array) -> None:
if value.ndim != 1 or value.dtype != mx.int32:
raise ValueError(
'row_offsets must have shape (N + 1,) and int32 dtype.'
)
def _validate_distance_array(value: mx.array) -> None:
if value.ndim != 1:
raise ValueError('distances must have shape (E,).')
if value.dtype not in (mx.float32, mx.float64):
raise ValueError('distances must be float32 or float64.')
def _require_same_rows(
first: mx.array,
*rest: mx.array,
names: tuple[str, ...],
) -> None:
rows = int(first.shape[0])
for name, value in zip(names[1:], rest, strict=True):
if int(value.shape[0]) != rows:
raise ValueError(
f'{names[0]} and {name} must have the same row count.'
)
def _optional_count(value: int | None, name: str) -> int | None:
if value is None:
return None
normalized = int(value)
if normalized < 0:
raise ValueError(f'{name} must be non-negative.')
return normalized