"""Tensor-level relation ops for sparse tensors (``mlx_lattice.ops.relations``)."""

from __future__ import annotations

from collections.abc import Sequence

import mlx.core as mx

from mlx_lattice.core.coords.builders import (
    build_generative_relation,
    build_kernel_relation,
    build_knn_relation,
    build_radius_relation,
    build_submanifold_kernel_relation,
    build_target_kernel_relation,
    build_transposed_kernel_relation,
    kernel_offsets,
)
from mlx_lattice.core.relations import KernelRelation, NeighborRelation
from mlx_lattice.core.tensor import SparseTensor

# Public API of this module, kept in alphabetical order.
# The ``build_*`` names and ``kernel_offsets`` are re-exported unchanged from
# ``mlx_lattice.core.coords.builders``; the remaining names are the
# SparseTensor-level wrappers defined in this module.
__all__ = [
    'build_generative_relation',
    'build_kernel_relation',
    'build_knn_relation',
    'build_radius_relation',
    'build_submanifold_kernel_relation',
    'build_target_kernel_relation',
    'build_transposed_kernel_relation',
    'gather_neighbor_features',
    'generative_kernel_relation',
    'kernel_offsets',
    'kernel_relation',
    'knn_relation',
    'radius_relation',
    'submanifold_kernel_relation',
    'target_kernel_relation',
    'transposed_kernel_relation',
]


def kernel_relation(
    x: SparseTensor,
    *,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Return the forward kernel relation for ``x``.

    Functional front-end over ``x.coord_manager.kernel_relation``. The
    tensor's own coordinate key is forwarded together with the kernel
    geometry. Whatever relation the manager produces is returned unchanged,
    so it can be inspected or handed to lower-level execution helpers.
    """
    manager = x.coord_manager
    geometry = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    return manager.kernel_relation(x.coord_key, **geometry)
def generative_kernel_relation(
    x: SparseTensor,
    *,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
) -> KernelRelation:
    """Return the generative transpose-convolution relation for ``x``.

    Delegates to ``x.coord_manager.generative_relation``, passing the
    tensor's coordinate key plus the requested kernel size and stride.
    """
    build = x.coord_manager.generative_relation
    key = x.coord_key
    return build(key, kernel_size=kernel_size, stride=stride)
def submanifold_kernel_relation(
    x: SparseTensor,
    *,
    kernel_size: int | Sequence[int] = 3,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Return the submanifold kernel relation for ``x``.

    Output support is pinned to the coordinate identity of ``x``, which is
    why no stride or padding can be passed. This is the relation semantic
    used by :func:`subm_conv3d`.
    """
    manager = x.coord_manager
    return manager.submanifold_kernel_relation(
        x.coord_key, kernel_size=kernel_size, dilation=dilation
    )
def transposed_kernel_relation(
    x: SparseTensor,
    *,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Return the transpose-convolution relation for ``x``.

    The kernel geometry is forwarded verbatim, together with the tensor's
    coordinate key, to ``x.coord_manager.transposed_kernel_relation``.
    """
    options = {
        'kernel_size': kernel_size,
        'stride': stride,
        'padding': padding,
        'dilation': dilation,
    }
    manager = x.coord_manager
    return manager.transposed_kernel_relation(x.coord_key, **options)
def target_kernel_relation(
    x: SparseTensor,
    target: SparseTensor,
    *,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> KernelRelation:
    """Return a relation mapping ``x`` onto an explicit ``target`` support.

    The coordinates of ``target`` define the output rows.

    If ``target`` shares ``x``'s coordinate manager, ``target.coord_key`` is
    used directly. Otherwise ``target``'s coordinates, stride and active rows
    are first inserted into ``x``'s manager. The relation is then built (and
    cached) by ``x``'s manager.
    """
    manager = x.coord_manager
    if target.coord_manager is manager:
        target_key = target.coord_key
    else:
        # Foreign manager: register target's coordinates with x's manager so
        # both keys passed below belong to the same manager.
        target_key = manager.insert_coords(
            target.coords, target.stride, target.active_rows
        )
    return manager.target_kernel_relation(
        x.coord_key,
        target_key,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
def knn_relation(
    source: SparseTensor,
    query: SparseTensor | None = None,
    *,
    k: int,
) -> NeighborRelation:
    """Return a k-nearest-neighbor relation between sparse tensors.

    Candidate neighbors are the rows of ``source``. The rows being queried
    come from ``query``, which defaults to ``source`` itself (a
    self-neighborhood). Both tensors must share the same sparse stride;
    otherwise :class:`ValueError` is raised.
    """
    queried = query if query is not None else source
    _require_matching_stride(source, queried)
    return build_knn_relation(
        source.coords,
        queried.coords,
        k=k,
        source_active_rows=source.active_rows,
        query_active_rows=queried.active_rows,
    )
def radius_relation(
    source: SparseTensor,
    query: SparseTensor | None = None,
    *,
    radius: float,
    max_neighbors: int | None = None,
) -> NeighborRelation:
    """Return a fixed-radius neighborhood relation between sparse tensors.

    ``source`` provides candidate rows. ``query`` provides the rows whose
    neighborhoods are collected and defaults to ``source``.

    When ``max_neighbors`` is given, it caps how many source rows may be
    linked to a single query row. Edge distances are stored on the returned
    relation in edge order. A stride mismatch between the two tensors raises
    :class:`ValueError`.
    """
    queried = query if query is not None else source
    _require_matching_stride(source, queried)
    return build_radius_relation(
        source.coords,
        queried.coords,
        radius=radius,
        max_neighbors=max_neighbors,
        source_active_rows=source.active_rows,
        query_active_rows=queried.active_rows,
    )
def gather_neighbor_features(
    source: SparseTensor,
    relation: NeighborRelation,
) -> mx.array:
    """Collect ``source`` feature rows in the order of ``relation``'s edges.

    Each edge contributes the feature row selected by its source-row index.
    The output has shape ``(E, C)``, where ``E`` is the neighbor edge
    capacity and ``C`` equals ``source.channels``.
    """
    # Normalize the edge indices to int32 before indexing the feature matrix.
    indices = relation.edges.source_rows.astype(mx.int32)
    feats = source.feats
    return mx.take(feats, indices, axis=0)
def _require_matching_stride( source: SparseTensor, query: SparseTensor ) -> None: if source.stride != query.stride: raise ValueError( 'source and query tensors must use the same coordinate stride.' )