from __future__ import annotations
from collections.abc import Sequence
import mlx.core as mx
from mlx_lattice.core import (
CoordinateMapKey,
KernelSpec,
QuantizedWeight,
SparseTensor,
)
from mlx_lattice.core.types import Triple
from mlx_lattice.ops._quantized import quantized_matmul
from mlx_lattice.ops._relation_exec import (
sparse_conv_features_from_relation,
sparse_quantized_conv_features_from_relation,
)
# Public sparse-convolution entry points exported by ``from ... import *``;
# every helper below is private (underscore-prefixed).
__all__ = [
    'conv3d',
    'conv_transpose3d',
    'generative_conv_transpose3d',
    'subm_conv3d',
]
[docs]
def conv3d(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None = None,
    *,
    kernel_size: int | Sequence[int] = 3,
    stride: int | Sequence[int] = 1,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
    coordinates: SparseTensor | CoordinateMapKey | mx.array | None = None,
) -> SparseTensor:
    """Apply sparse 3D convolution to a ``SparseTensor``.

    One of three execution paths is chosen from the kernel geometry and
    ``coordinates``:

    * pointwise kernel and no ``coordinates``: the features are multiplied
      directly by the weight matrix and the input coordinates are kept;
    * ``coordinates`` supplied: the output is placed on that explicit target
      support (a forward relation when the target is the input's own
      coordinate key, a target relation otherwise);
    * anything else: a forward kernel relation generates the output support.

    Relation-backed paths dispatch to the selected native CPU or Metal
    backend.

    Args:
        x: Input sparse tensor with feature shape ``(N_in, C_in)``.
        weight: Floating weight or packed ``QuantizedWeight``. Floating
            weights accept ``(C_out, C_in)``, ``(K, C_in, C_out)``, or
            ``(C_out, Kx, Ky, Kz, C_in)`` depending on kernel geometry.
        bias: Optional ``(C_out,)`` bias with the output feature dtype.
        kernel_size: 3D kernel size.
        stride: 3D convolution stride; the output sparse stride is
            ``x.stride * stride``.
        padding: 3D padding used by the relation builder.
        dilation: 3D dilation used by the relation builder.
        coordinates: Optional explicit target support given as a sparse
            tensor, a coordinate key, or a coordinate array.

    Returns:
        Sparse tensor with output features ``(N_out, C_out)``. With
        ``coordinates`` the output coordinates are exactly that target
        support; otherwise they come from the forward relation (or are the
        input coordinates for the pointwise path).

    Raises:
        ValueError: On feature/weight/bias dtype or shape mismatches, or when
            the target coordinates do not use the convolution output stride.
    """
    spec = KernelSpec(
        size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    output_stride = _mul_stride(x.stride, spec.stride)

    if coordinates is None:
        if spec.is_pointwise:
            # A 1x1x1 kernel is a per-row linear map: no relation is needed.
            projected = _pointwise_features(x, weight)
            return x.replace(feats=_with_bias(projected, bias))
        return _relation_conv(
            x,
            weight,
            bias,
            spec,
            map_kind='forward',
            output_stride=output_stride,
        )

    target_key = _target_key(x, coordinates, output_stride)
    # When the target is the input's own support, the output keeps x's
    # coordinates and a forward relation is used instead of a target one.
    targets_input = target_key == x.coord_key
    return _relation_conv(
        x,
        _target_weight(weight, spec),
        bias,
        spec,
        map_kind='forward' if targets_input else 'target',
        output_stride=target_key.stride,
        reuse_input_coords=targets_input,
        target_key=target_key,
    )
[docs]
def subm_conv3d(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None = None,
    *,
    kernel_size: int | Sequence[int] = 3,
    dilation: int | Sequence[int] = 1,
) -> SparseTensor:
    """Apply submanifold sparse convolution without changing coordinates.

    Submanifold convolution uses an odd kernel, stride ``1``, and no padding
    expansion. The result reuses the input coordinate identity, so downstream
    feature-only operations and relation caches observe the same support.

    Args:
        x: Input sparse tensor with feature shape ``(N, C_in)``.
        weight: Floating weight or packed ``QuantizedWeight``.
        bias: Optional ``(C_out,)`` bias with the output feature dtype.
        kernel_size: 3D kernel size; every component must be odd.
        dilation: 3D dilation passed to the submanifold relation builder.

    Returns:
        Sparse tensor on ``x``'s coordinates with new features.

    Raises:
        ValueError: If any kernel size component is even, or on
            feature/weight/bias dtype or shape mismatches.
    """
    spec = KernelSpec(
        size=kernel_size,
        stride=1,
        padding=0,
        dilation=dilation,
    )
    _require_odd_kernel(spec.size, 'subm_conv3d')

    unit = (1, 1, 1)
    if spec.size != unit or spec.dilation != unit:
        return _relation_conv(
            x,
            weight,
            bias,
            spec,
            map_kind='submanifold',
            output_stride=x.stride,
            reuse_input_coords=True,
        )
    # Single-tap kernel: a plain matmul over the features is sufficient.
    projected = _pointwise_features(x, weight)
    return x.replace(feats=_with_bias(projected, bias))
[docs]
def conv_transpose3d(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None = None,
    *,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
    padding: int | Sequence[int] = 0,
    dilation: int | Sequence[int] = 1,
) -> SparseTensor:
    """Apply sparse transpose convolution using a transposed kernel relation.

    The output sparse stride is ``x.stride / stride``. Every component of
    ``stride`` must evenly divide the matching component of ``x.stride``.
    Output support is produced by the transposed relation builder.

    Args:
        x: Input sparse tensor with feature shape ``(N_in, C_in)``.
        weight: Floating weight ``(K, C_in, C_out)`` or
            ``(C_out, Kx, Ky, Kz, C_in)``, or a packed ``QuantizedWeight``.
        bias: Optional ``(C_out,)`` bias with the output feature dtype.
        kernel_size: 3D kernel size.
        stride: 3D transpose stride.
        padding: 3D padding passed to the transposed relation builder.
        dilation: 3D dilation passed to the transposed relation builder.

    Returns:
        Sparse tensor with features ``(N_out, C_out)`` at stride
        ``x.stride // stride``.

    Raises:
        ValueError: If a stride component does not divide ``x.stride``, or on
            feature/weight/bias dtype or shape mismatches.
    """
    spec = KernelSpec(
        size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    out_stride = _div_stride(x.stride, spec.stride)
    return _relation_conv(
        x,
        weight,
        bias,
        spec,
        map_kind='transposed',
        output_stride=out_stride,
    )
[docs]
def generative_conv_transpose3d(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None = None,
    *,
    kernel_size: int | Sequence[int] = 2,
    stride: int | Sequence[int] = 2,
) -> SparseTensor:
    """Generate output coordinates from a sparse transpose-convolution rule.

    Support is expanded directly from the input rows and the transpose
    stride. This is useful for sparse decoders that need generated
    coordinates before subsequent target-aligned operations.

    Args:
        x: Input sparse tensor with feature shape ``(N_in, C_in)``.
        weight: Floating weight ``(K, C_in, C_out)`` or
            ``(C_out, Kx, Ky, Kz, C_in)``, or a packed ``QuantizedWeight``.
        bias: Optional ``(C_out,)`` bias with the output feature dtype.
        kernel_size: 3D kernel size passed to the generative relation.
        stride: 3D transpose stride; each component must divide
            ``x.stride``.

    Returns:
        Sparse tensor on the generated support at stride
        ``x.stride // stride``.

    Raises:
        ValueError: If a stride component does not divide ``x.stride``, or on
            feature/weight/bias dtype or shape mismatches.
    """
    spec = KernelSpec(size=kernel_size, stride=stride)
    out_stride = _div_stride(x.stride, spec.stride)
    return _relation_conv(
        x,
        weight,
        bias,
        spec,
        map_kind='generative',
        output_stride=out_stride,
    )
# MARK: - execution policy
def _relation_conv(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None,
    spec: KernelSpec,
    *,
    map_kind: str,
    output_stride: Triple,
    reuse_input_coords: bool = False,
    target_key: CoordinateMapKey | None = None,
) -> SparseTensor:
    """Run a relation-backed sparse convolution and build the output tensor.

    Args:
        x: Input sparse tensor.
        weight: Floating or quantized convolution weight.
        bias: Optional per-output-channel bias.
        spec: Normalized kernel geometry.
        map_kind: Relation kind forwarded to ``_kernel_relation``.
        output_stride: Sparse stride recorded on the new tensor.
        reuse_input_coords: When true, only ``x``'s features are replaced and
            its coordinates are kept.
        target_key: Explicit target support (required for ``'target'``).

    Returns:
        ``x.replace(...)`` when reusing coordinates; a tensor on the target
        key's coordinates for ``'target'`` relations; otherwise a tensor on
        the relation's generated ``out_coords``.

    Raises:
        ValueError: On validation failure, or when the relation lacks static
            shape metadata or output coordinates.
    """
    _validate_feature_dtype(x.feats, weight)
    _validate_metal_coord_dtype(x)
    _validate_weight_for_kernel(x, weight, spec.volume)

    relation = _kernel_relation(x, spec, map_kind, target_key=target_key)
    missing_shape = (
        relation.n_out_capacity is None or relation.n_kernels is None
    )
    if missing_shape:
        raise ValueError(
            'kernel relation is missing static shape metadata.'
        )

    if isinstance(weight, QuantizedWeight):
        feats = sparse_quantized_conv_features_from_relation(
            x.feats, weight, relation
        )
    else:
        feats = sparse_conv_features_from_relation(x.feats, weight, relation)

    if reuse_input_coords:
        # Callers set this when the output support is x's own support.
        return x.replace(feats=_with_bias(feats, bias))
    if relation.out_coords is None:
        raise ValueError('kernel relation is missing output coordinates.')

    if map_kind == 'target' and target_key is not None:
        manager = x.coord_manager
        # Evaluated in the same order as the SparseTensor arguments.
        target_coords = manager.coords(target_key)
        target_feats = _with_bias(feats, bias)
        return SparseTensor(
            target_coords,
            target_feats,
            stride=output_stride,
            coord_key=target_key,
            coord_manager=manager,
            active_rows=manager.active_rows(target_key),
        )

    return SparseTensor(
        relation.out_coords,
        _with_bias(feats, bias),
        stride=output_stride,
        coord_manager=x.coord_manager,
        active_rows=relation.out_count,
    )
def _kernel_relation(
x: SparseTensor,
spec: KernelSpec,
map_kind: str,
*,
target_key: CoordinateMapKey | None = None,
):
if map_kind == 'forward':
return x.coord_manager.kernel_relation(
x.coord_key,
kernel_size=spec.size,
stride=spec.stride,
padding=spec.padding,
dilation=spec.dilation,
)
if map_kind == 'submanifold':
return x.coord_manager.submanifold_kernel_relation(
x.coord_key,
kernel_size=spec.size,
dilation=spec.dilation,
)
if map_kind == 'transposed':
return x.coord_manager.transposed_kernel_relation(
x.coord_key,
kernel_size=spec.size,
stride=spec.stride,
padding=spec.padding,
dilation=spec.dilation,
)
if map_kind == 'generative':
return x.coord_manager.generative_relation(
x.coord_key,
kernel_size=spec.size,
stride=spec.stride,
)
if map_kind == 'target':
if target_key is None:
raise ValueError('target_key is required for target relations.')
return x.coord_manager.target_kernel_relation(
x.coord_key,
target_key,
kernel_size=spec.size,
stride=spec.stride,
padding=spec.padding,
dilation=spec.dilation,
)
raise ValueError(
"map_kind must be 'forward', 'submanifold', 'transposed', "
"'generative', or 'target'."
)
def _target_key(
x: SparseTensor,
coordinates: SparseTensor | CoordinateMapKey | mx.array,
output_stride: Triple,
) -> CoordinateMapKey:
if isinstance(coordinates, SparseTensor):
if coordinates.stride != output_stride:
raise ValueError(
'target coordinates must use the convolution output stride.'
)
if coordinates.coord_manager is x.coord_manager:
return coordinates.coord_key
return x.coord_manager.insert_coords(
coordinates.coords,
coordinates.stride,
coordinates.active_rows,
)
if isinstance(coordinates, CoordinateMapKey):
if not x.coord_manager.owns(coordinates):
raise ValueError(
'target coordinate key must belong to x.coord_manager.'
)
if coordinates.stride != output_stride:
raise ValueError(
'target coordinate key must use the convolution output stride.'
)
return coordinates
return x.coord_manager.insert_coords(coordinates, output_stride)
def _pointwise_features(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
) -> mx.array:
    """Compute ``1x1x1`` convolution features with a dense matrix multiply.

    No kernel relation is built; each row of ``x.feats`` is projected on its
    own. Bias is not applied here.

    Args:
        x: Input sparse tensor with features ``(N, C_in)``.
        weight: Floating pointwise weight (any layout accepted by
            ``_pointwise_weight_matrix``) or a packed ``QuantizedWeight``.

    Returns:
        Projected features, ``x.feats @ W.T`` for a ``(C_out, C_in)``
        matrix ``W`` (or the ``quantized_matmul`` result).

    Raises:
        ValueError: On dtype mismatch, input-channel mismatch, or an
            unsupported floating weight layout.
    """
    _validate_feature_dtype(x.feats, weight)
    if isinstance(weight, QuantizedWeight):
        # Same check (and message) as the relation path, so a channel
        # mismatch is rejected here instead of relying on quantized_matmul.
        if weight.in_channels != x.channels:
            raise ValueError(
                'quantized weight input channels must match x.channels.'
            )
        return quantized_matmul(x.feats, weight)
    matrix = _pointwise_weight_matrix(x, weight)
    return x.feats @ matrix.T
def _target_weight(
    weight: mx.array | QuantizedWeight,
    spec: KernelSpec,
) -> mx.array | QuantizedWeight:
    """Adapt a 2D pointwise weight to the mapped ``(K, C_in, C_out)`` layout.

    Only a floating ``(C_out, C_in)`` weight paired with a pointwise kernel is
    rewritten, to ``(1, C_in, C_out)``. Every other weight, including any
    ``QuantizedWeight``, is returned unchanged.
    """
    if isinstance(weight, QuantizedWeight):
        return weight
    if spec.is_pointwise and weight.ndim == 2:
        # (C_out, C_in) -> (C_in, C_out) -> (1, C_in, C_out)
        return mx.expand_dims(weight.T, axis=0)
    return weight
# MARK: - validation
def _validate_feature_dtype(
    feats: mx.array,
    weight: mx.array | QuantizedWeight,
) -> None:
    """Reject unsupported feature dtypes and feature/weight dtype mismatches.

    A ``QuantizedWeight`` is compared through the dtype of its ``scales``;
    a floating weight is compared through its own dtype.

    Raises:
        ValueError: If ``feats`` is not float32/float16, or if the weight's
            dtype differs from ``feats.dtype``.
    """
    supported = (mx.float32, mx.float16)
    if feats.dtype not in supported:
        raise ValueError(
            'convolution supports float32 and float16 tensors.'
        )
    if isinstance(weight, QuantizedWeight):
        weight_dtype = weight.scales.dtype
    else:
        weight_dtype = weight.dtype
    if weight_dtype != feats.dtype:
        raise ValueError('convolution weights must match feature dtype.')
def _validate_metal_coord_dtype(x: SparseTensor) -> None:
    """Require ``int32`` coordinates when MLX's default device is the GPU.

    Coordinates are not checked on other devices.

    Raises:
        ValueError: On the GPU default device with non-int32 coordinates.
    """
    on_gpu = mx.default_device() == mx.gpu
    if not on_gpu:
        return
    if x.coords.dtype != mx.int32:
        raise ValueError(
            'Metal sparse convolution requires int32 coordinates.'
        )
def _pointwise_weight_matrix(x: SparseTensor, weight: mx.array) -> mx.array:
    """Return a pointwise floating weight as a ``(C_out, C_in)`` matrix.

    Accepted layouts are ``(C_out, C_in)``, ``(C_out, 1, 1, 1, C_in)`` and
    ``(1, C_in, C_out)``.

    Raises:
        ValueError: If the layout is unsupported, or the weight's input
            channels differ from ``x.channels``.
    """
    dims = weight.shape
    if weight.ndim == 2:
        in_channels, to_matrix = dims[1], lambda w: w
    elif weight.ndim == 5 and all(dims[axis] == 1 for axis in (1, 2, 3)):
        in_channels, to_matrix = dims[4], lambda w: w[:, 0, 0, 0, :]
    elif weight.ndim == 3 and dims[0] == 1:
        # Mapped layout (1, C_in, C_out): drop K and transpose.
        in_channels, to_matrix = dims[1], lambda w: w[0].T
    else:
        raise ValueError(
            'pointwise weight must have shape (C_out, C_in), '
            '(C_out, 1, 1, 1, C_in), or (1, C_in, C_out).'
        )
    if in_channels != x.channels:
        raise ValueError('weight input channels must match x.channels.')
    return to_matrix(weight)
def _validate_weight_for_kernel(
x: SparseTensor,
weight: mx.array | QuantizedWeight,
kernel_rows: int,
) -> None:
if isinstance(weight, QuantizedWeight):
if weight.in_channels != x.channels:
raise ValueError(
'quantized weight input channels must match x.channels.'
)
if _volume(weight.kernel_size) != kernel_rows:
raise ValueError(
'quantized weight kernel rows must match the convolution kernel.'
)
return
if weight.ndim == 3:
if weight.shape[1] != x.channels:
raise ValueError('weight input channels must match x.channels.')
if weight.shape[0] != kernel_rows:
raise ValueError(
'weight kernel rows must match the convolution kernel.'
)
return
if weight.ndim != 5:
raise ValueError(
'mapped convolution weight must have shape (K, C_in, C_out) '
'or (C_out, Kx, Ky, Kz, C_in).'
)
if weight.shape[4] != x.channels:
raise ValueError('weight input channels must match x.channels.')
if (
int(weight.shape[1] * weight.shape[2] * weight.shape[3])
!= kernel_rows
):
raise ValueError(
'weight kernel rows must match the convolution kernel.'
)
def _with_bias(feats: mx.array, bias: mx.array | None) -> mx.array:
    """Add an optional per-output-channel bias to ``feats``.

    ``feats`` is returned unchanged (the same object) when ``bias`` is None.

    Raises:
        ValueError: If ``bias`` is not 1-D, its length differs from
            ``feats.shape[1]``, or its dtype differs from ``feats.dtype``.
    """
    if bias is not None:
        if bias.ndim != 1:
            raise ValueError('bias must have shape (C_out,).')
        if bias.shape[0] != feats.shape[1]:
            raise ValueError('bias channels must match output channels.')
        if bias.dtype != feats.dtype:
            raise ValueError('bias dtype must match output features.')
        feats = feats + bias
    return feats
def _require_odd_kernel(values: Triple, op_name: str) -> None:
    """Raise ``ValueError`` naming ``op_name`` if any kernel size is even."""
    if not all(value % 2 != 0 for value in values):
        raise ValueError(f'{op_name} requires odd kernel_size values.')
def _mul_stride(lhs: Triple, rhs: Triple) -> Triple:
    """Return the component-wise product of two 3D strides."""
    return tuple(lhs[axis] * rhs[axis] for axis in range(3))
def _div_stride(lhs: Triple, rhs: Triple) -> Triple:
out = []
for left, right in zip(lhs, rhs, strict=True):
if left % right != 0:
raise ValueError(
'transpose stride must divide the input tensor stride.'
)
out.append(left // right)
return (out[0], out[1], out[2])
def _volume(size: Triple) -> int:
    """Return the number of kernel taps, ``Kx * Ky * Kz``."""
    kx, ky, kz = size[0], size[1], size[2]
    return kx * ky * kz