Source code for mlx_lattice.ops._relation_exec

from __future__ import annotations

import weakref

import mlx.core as mx

from mlx_lattice._native import ext
from mlx_lattice.core.quantized import QuantizedWeight
from mlx_lattice.core.relations import KernelRelation

# Cache of repacked 5-D convolution weights, filled by ``_mapped_weight``.
# Key: ``id()`` of the source weight array.
# Value: ``(weak reference to the source array, its shape as a tuple of
# ints, its dtype, the packed array)``. The weak reference, shape, and dtype
# are re-checked on lookup before the packed array is reused, and the weak
# reference's callback removes the entry once the source array is collected.
_PACKED_WEIGHT_CACHE: dict[
    int,
    tuple[
        weakref.ReferenceType[mx.array], tuple[int, ...], mx.Dtype, mx.array
    ],
] = {}


def sparse_quantized_conv_features_from_relation(
    feats: mx.array,
    weight: QuantizedWeight,
    relation: KernelRelation,
) -> mx.array:
    """Run the native quantized sparse convolution for a prebuilt relation.

    Args:
        feats: Input feature array, forwarded unchanged to the native op.
        weight: Quantized weight bundle. Its ``weight``, ``scales`` and
            ``biases`` arrays and its channel / group-size / bit-width
            metadata are forwarded to the native op.
        relation: Kernel relation supplying the edge lists, per-kernel
            counts and output CSR row offsets.

    Returns:
        The array returned by ``ext.sparse_quantized_conv_features``.

    Raises:
        ValueError: If ``relation`` has no ``n_out_capacity`` or
            ``n_kernels``, or if ``n_kernels`` differs from the leading
            dimension of ``weight.weight``.
    """
    out_capacity = relation.n_out_capacity
    kernel_count = relation.n_kernels
    if out_capacity is None or kernel_count is None:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )
    if kernel_count != weight.weight.shape[0]:
        raise ValueError(
            'quantized weight kernel rows must match the relation.'
        )

    # The sorted implicit-GEMM buffers are only supplied when the helper
    # predicate accepts this configuration; otherwise the native op receives
    # empty int32 arrays in their place.
    if _can_use_sorted_quantized_implicit_gemm(feats, weight, relation):
        sorted_view = relation.require_sorted_implicit_gemm()
        kv_out_in_map, row_order, masks = (
            sorted_view.sorted_kv_out_in_map,
            sorted_view.reorder_rows,
            sorted_view.tile_masks,
        )
    else:
        kv_out_in_map, row_order, masks = (
            _empty_i32(),
            _empty_i32(),
            _empty_i32(),
        )

    edges = relation.edges
    return ext.sparse_quantized_conv_features(
        feats,
        weight.weight,
        weight.scales,
        weight.biases,
        edges.in_rows,
        edges.out_rows,
        edges.kernel_ids,
        relation.counts,
        relation.output_csr.row_offsets,
        kv_out_in_map,
        row_order,
        masks,
        out_capacity,
        kernel_count,
        weight.in_channels,
        weight.out_channels,
        weight.storage_in_channels,
        weight.group_size,
        weight.bits,
    )
def _empty_i32() -> mx.array:
    """Return a new zero-length ``int32`` array."""
    no_values: list[int] = []
    return mx.array(no_values, dtype=mx.int32)


def _can_use_sorted_quantized_implicit_gemm(
    feats: mx.array,
    weight: QuantizedWeight,
    relation: KernelRelation,
) -> bool:
    """Report whether the quantized call may use the sorted GEMM buffers.

    Every condition below must hold; checks run in order and stop at the
    first failure:

    * the relation contract kind is 'forward', 'target' or 'submanifold';
    * ``feats`` is float16;
    * the relation has exactly 27 kernels;
    * the weight's storage input channels equal its input channels;
    * input and output channel counts are each 32 or 64;
    * the quantization group size does not exceed the input channels.
    """
    if relation.contract.kind not in ('forward', 'target', 'submanifold'):
        return False
    if feats.dtype != mx.float16:
        return False
    if relation.n_kernels != 27:
        return False
    if weight.storage_in_channels != weight.in_channels:
        return False
    if weight.in_channels not in (32, 64):
        return False
    if weight.out_channels not in (32, 64):
        return False
    return weight.group_size <= weight.in_channels
def sparse_conv_features_from_relation(
    feats: mx.array,
    weight: mx.array,
    relation: KernelRelation,
) -> mx.array:
    """Run sparse convolution for a prebuilt kernel relation.

    Delegates to ``sparse_conv_features_sorted_from_relation`` whenever
    ``_can_use_sorted_implicit_gemm`` accepts the inputs; otherwise calls
    the general native op with the relation's grouped CSR views.

    Args:
        feats: Input feature array.
        weight: Convolution weight array.
        relation: Kernel relation supplying edges, counts and CSR views.

    Returns:
        The array produced by whichever execution path was selected.

    Raises:
        ValueError: If ``relation`` has no ``n_out_capacity`` or
            ``n_kernels``, or (on the general path) if the input or kernel
            CSR view has no ``edge_ids``.
    """
    out_capacity = relation.n_out_capacity
    kernel_count = relation.n_kernels
    if out_capacity is None or kernel_count is None:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )

    if _can_use_sorted_implicit_gemm(feats, weight, relation):
        return sparse_conv_features_sorted_from_relation(
            feats, weight, relation
        )

    by_input = relation.input_csr
    by_kernel = relation.kernel_csr
    if by_input.edge_ids is None or by_kernel.edge_ids is None:
        raise ValueError('kernel relation is missing grouped CSR views.')

    edges = relation.edges
    return ext.sparse_conv_features(
        feats,
        weight,
        edges.in_rows,
        edges.out_rows,
        edges.kernel_ids,
        relation.counts,
        relation.output_csr.row_offsets,
        by_input.row_offsets,
        by_input.edge_ids,
        by_kernel.row_offsets,
        by_kernel.edge_ids,
        out_capacity,
        kernel_count,
    )
def sparse_conv_features_sorted_from_relation(
    feats: mx.array,
    weight: mx.array,
    relation: KernelRelation,
    *,
    store_sorted: bool = False,
) -> mx.array:
    """Run the sorted implicit-GEMM convolution path on request.

    Unlike ``sparse_conv_features_from_relation`` this never falls back:
    unsupported inputs are rejected with an error.

    Args:
        feats: Input feature array.
        weight: Convolution weight array; passed through
            ``_mapped_weight`` before reaching the native op.
        relation: Kernel relation supplying the sorted view, edges, counts
            and CSR offsets.
        store_sorted: Forwarded unchanged to the native op as the
            ``store_sorted`` keyword.

    Returns:
        The array returned by
        ``ext.sparse_conv_features_sorted_implicit_gemm``.

    Raises:
        ValueError: If ``relation`` has no ``n_out_capacity`` or
            ``n_kernels``, or if ``_can_use_sorted_implicit_gemm`` rejects
            the inputs.
    """
    out_capacity = relation.n_out_capacity
    kernel_count = relation.n_kernels
    if out_capacity is None or kernel_count is None:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )

    supported = _can_use_sorted_implicit_gemm(feats, weight, relation)
    if not supported:
        raise ValueError(
            'sorted implicit GEMM is not supported for this relation, '
            'feature tensor, or weight tensor.'
        )

    sorted_view = relation.require_sorted_implicit_gemm()
    packed_weight = _mapped_weight(weight)
    edges = relation.edges
    return ext.sparse_conv_features_sorted_implicit_gemm(
        feats,
        packed_weight,
        sorted_view.sorted_out_in_map,
        sorted_view.sorted_kv_out_in_map,
        sorted_view.reorder_rows,
        sorted_view.tile_masks,
        edges.in_rows,
        edges.out_rows,
        edges.kernel_ids,
        relation.counts,
        relation.output_csr.row_offsets,
        relation.input_csr.row_offsets,
        relation.in_edge_ids,
        relation.kernel_csr.row_offsets,
        relation.kernel_edge_ids,
        out_capacity,
        kernel_count,
        store_sorted=store_sorted,
    )
def sparse_conv_features_sorted_direct_reference_from_relation(
    feats: mx.array,
    weight: mx.array,
    relation: KernelRelation,
    *,
    store_sorted: bool = False,
) -> mx.array:
    """Run the sorted direct reference convolution (diagnostics / tests).

    Args:
        feats: Input feature array.
        weight: Convolution weight array; passed through
            ``_mapped_weight`` before reaching the native op.
        relation: Kernel relation supplying the sorted view and static
            shape metadata.
        store_sorted: Forwarded unchanged to the native op as the
            ``store_sorted`` keyword.

    Returns:
        The array returned by
        ``ext.sparse_conv_features_sorted_direct_reference``.

    Raises:
        ValueError: If ``relation`` has no ``n_out_capacity`` or
            ``n_kernels``, or if ``_can_use_sorted_implicit_gemm`` rejects
            the inputs.
    """
    out_capacity = relation.n_out_capacity
    kernel_count = relation.n_kernels
    if out_capacity is None or kernel_count is None:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )

    supported = _can_use_sorted_implicit_gemm(feats, weight, relation)
    if not supported:
        raise ValueError(
            'sorted direct convolution reference is not supported for this '
            'relation, feature tensor, or weight tensor.'
        )

    sorted_view = relation.require_sorted_implicit_gemm()
    packed_weight = _mapped_weight(weight)
    return ext.sparse_conv_features_sorted_direct_reference(
        feats,
        packed_weight,
        sorted_view.sorted_out_in_map,
        sorted_view.reorder_rows,
        sorted_view.tile_masks,
        out_capacity,
        kernel_count,
        store_sorted=store_sorted,
    )
def _can_use_sorted_implicit_gemm(
    feats: mx.array,
    weight: mx.array,
    relation: KernelRelation,
) -> bool:
    """Report whether the sorted implicit-GEMM path accepts these inputs.

    Requirements, checked in order with early exit:

    * relation contract kind is 'forward', 'target' or 'submanifold';
    * ``feats`` and ``weight`` are both float16;
    * the relation has exactly 27 kernels;
    * ``feats`` is 2-D with 32 or 64 columns (call that ``C``);
    * ``weight`` is either 3-D shaped ``(27, C, C)`` or 5-D shaped
      ``(C, 3, 3, 3, C)``.
    """
    if (
        relation.contract.kind not in ('forward', 'target', 'submanifold')
        or feats.dtype != mx.float16
        or weight.dtype != mx.float16
        or relation.n_kernels != 27
        or feats.ndim != 2
    ):
        return False

    width = int(feats.shape[1])
    if width not in (32, 64):
        return False

    square = (width, width)
    if weight.ndim == 3:
        return int(weight.shape[0]) == 27 and tuple(weight.shape[1:]) == square
    if weight.ndim != 5:
        return False
    return tuple(weight.shape) == (width, 3, 3, 3, width)


def _mapped_weight(weight: mx.array) -> mx.array:
    """Return ``weight`` in the 3-D layout, repacking 5-D weights once.

    A 3-D weight is returned untouched. Any other weight is transposed with
    axes ``(1, 2, 3, 4, 0)``, reshaped to ``(-1, C, C)`` where ``C`` is the
    size of its first axis, and made contiguous. The packed result is
    memoised in ``_PACKED_WEIGHT_CACHE`` under ``id(weight)``.
    """
    if weight.ndim == 3:
        return weight

    key = id(weight)
    dims = tuple(map(int, weight.shape))

    entry = _PACKED_WEIGHT_CACHE.get(key)
    if entry is not None:
        owner_ref, seen_dims, seen_dtype, packed_hit = entry
        # An id() can be reused by a different object once the original is
        # gone, so the entry is trusted only if its weak reference still
        # resolves to this exact array with the same shape and dtype.
        if (
            owner_ref() is weight
            and seen_dims == dims
            and seen_dtype == weight.dtype
        ):
            return packed_hit

    width = int(weight.shape[0])
    kernel_major = weight.transpose(1, 2, 3, 4, 0)
    packed = mx.contiguous(kernel_major.reshape((-1, width, width)))

    def _evict(dead_ref: weakref.ReferenceType[mx.array]) -> None:
        # Remove the entry only if it still belongs to the reference that
        # just died; a newer array may have claimed the same key since.
        current = _PACKED_WEIGHT_CACHE.get(key)
        if current is not None and current[0] is dead_ref:
            _PACKED_WEIGHT_CACHE.pop(key, None)

    _PACKED_WEIGHT_CACHE[key] = (
        weakref.ref(weight, _evict),
        dims,
        weight.dtype,
        packed,
    )
    return packed
def sparse_pool_features_from_relation(
    feats: mx.array,
    relation: KernelRelation,
    *,
    input_exclusive: bool,
    mode: str,
) -> mx.array:
    """Run the native sparse pooling op for a prebuilt kernel relation.

    Args:
        feats: Input feature array, forwarded unchanged.
        relation: Kernel relation supplying edges, output CSR row offsets
            and per-kernel counts.
        input_exclusive: Forwarded unchanged to the native op.
        mode: Pooling mode string, forwarded unchanged to the native op
            (accepted values are defined by the native extension).

    Returns:
        The array returned by ``ext.sparse_pool_features``.

    Raises:
        ValueError: If ``relation`` has no ``n_out_capacity`` or
            ``n_kernels``.
    """
    out_capacity = relation.n_out_capacity
    kernel_count = relation.n_kernels
    if out_capacity is None or kernel_count is None:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )

    edges = relation.edges
    return ext.sparse_pool_features(
        feats,
        edges.in_rows,
        edges.out_rows,
        edges.kernel_ids,
        relation.output_csr.row_offsets,
        relation.counts,
        input_exclusive,
        mode,
        out_capacity,
        kernel_count,
    )