from __future__ import annotations
import math
from collections.abc import Sequence
from typing import TYPE_CHECKING
import mlx.core as mx
import mlx.nn as mxnn
from mlx_lattice.core import (
CoordinateMapKey,
KernelSpec,
SparseTensor,
)
from mlx_lattice.ops import (
conv3d,
conv_transpose3d,
generative_conv_transpose3d,
subm_conv3d,
)
__all__ = [
    'Conv3d',
    'ConvTranspose3d',
    'GenerativeConvTranspose3d',
    'SubmConv3d',
]

# The quantized module classes are only referenced here as return-type
# annotations of the ``to_quantized`` methods; at runtime each method imports
# its class locally (presumably to avoid an import cycle -- confirm).
if TYPE_CHECKING:
    from mlx_lattice.nn.quantized_conv import (
        QuantizedConv3d,
        QuantizedConvTranspose3d,
        QuantizedGenerativeConvTranspose3d,
        QuantizedSubmConv3d,
    )
class Conv3d(mxnn.Module):
    """Sparse 3D convolution module.

    Holds a dense 5D ``weight`` laid out as ``(C_out, Kx, Ky, Kz, C_in)`` and,
    optionally, a ``bias``. Invoking the module forwards to
    :func:`mlx_lattice.ops.conv3d`, so output coordinate support follows the
    same forward or explicit-target semantics as the functional API.

    Args:
        in_channels: Number of input feature channels (``C_in``).
        out_channels: Number of output feature channels (``C_out``).
        kernel_size: Kernel extent; an ``int`` or a sequence of ``int``
            interpreted by :class:`KernelSpec`.
        stride: Convolution stride, interpreted by :class:`KernelSpec`.
        padding: Padding, interpreted by :class:`KernelSpec`.
        dilation: Kernel dilation, interpreted by :class:`KernelSpec`.
        bias: When true, a ``bias`` parameter is created.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int | Sequence[int] = 3,
        stride: int | Sequence[int] = 1,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Geometry is captured once here and read back in __call__.
        self.spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        _init_kernel_params(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            spec=self.spec,
            bias=bias,
        )

    def __call__(
        self,
        x: SparseTensor,
        *,
        coordinates: SparseTensor | CoordinateMapKey | mx.array | None = None,
    ) -> SparseTensor:
        """Convolve ``x`` with this module's weight and optional bias.

        Args:
            x: Input sparse tensor.
            coordinates: Optional explicit output target, forwarded unchanged
                to :func:`mlx_lattice.ops.conv3d`.

        Returns:
            The sparse tensor produced by :func:`mlx_lattice.ops.conv3d`.
        """
        spec = self.spec
        return conv3d(
            x,
            self.weight,
            _optional_bias(self),
            kernel_size=spec.size,
            stride=spec.stride,
            padding=spec.padding,
            dilation=spec.dilation,
            coordinates=coordinates,
        )

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        *,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedConv3d:
        """Build a :class:`QuantizedConv3d` from this module.

        Args:
            group_size: Forwarded as-is to ``QuantizedConv3d.from_conv``.
            bits: Bits per weight; ``None`` selects 4.
            mode: Only ``'affine'`` is accepted.
            quantize_input: Only ``False`` is accepted.

        Raises:
            ValueError: If ``mode`` or ``quantize_input`` is unsupported.
        """
        from mlx_lattice.nn.quantized_conv import QuantizedConv3d

        _validate_quantize_request(mode, quantize_input)
        if bits is None:
            bits = 4
        return QuantizedConv3d.from_conv(self, group_size=group_size, bits=bits)
class SubmConv3d(mxnn.Module):
    """Submanifold sparse 3D convolution module.

    The output sparse tensor reuses the input coordinate identity. Because the
    relation is centered on existing active rows, every kernel size must be
    odd. Stride is fixed to 1 and padding to 0.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel extent; an ``int`` or a sequence of ``int``
            interpreted by :class:`KernelSpec`. Every entry must be odd.
        dilation: Kernel dilation, interpreted by :class:`KernelSpec`.
        bias: When true, a ``bias`` parameter is created.

    Raises:
        ValueError: If any kernel size is even.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int | Sequence[int] = 3,
        dilation: int | Sequence[int] = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.spec = KernelSpec(
            size=kernel_size,
            stride=1,
            padding=0,
            dilation=dilation,
        )
        # Enforce the documented odd-kernel requirement here, so a misconfigured
        # layer is rejected before any parameters are allocated.
        sizes = tuple(self.spec.size)
        if any(k % 2 == 0 for k in sizes):
            raise ValueError(
                f'SubmConv3d requires odd kernel sizes; got kernel_size={sizes}.'
            )
        _init_kernel_params(
            self, in_channels, out_channels, self.spec, bias
        )

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply the submanifold convolution; output rows match ``x``'s rows."""
        return subm_conv3d(
            x,
            self.weight,
            _optional_bias(self),
            kernel_size=self.spec.size,
            dilation=self.spec.dilation,
        )

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        *,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedSubmConv3d:
        """Build a :class:`QuantizedSubmConv3d` from this module.

        Args:
            group_size: Forwarded as-is to ``QuantizedSubmConv3d.from_conv``.
            bits: Bits per weight; ``None`` selects 4.
            mode: Only ``'affine'`` is accepted.
            quantize_input: Only ``False`` is accepted.

        Raises:
            ValueError: If ``mode`` or ``quantize_input`` is unsupported.
        """
        from mlx_lattice.nn.quantized_conv import QuantizedSubmConv3d

        _validate_quantize_request(mode, quantize_input)
        return QuantizedSubmConv3d.from_conv(
            self, group_size=group_size, bits=4 if bits is None else bits
        )
class ConvTranspose3d(mxnn.Module):
    """Sparse 3D transpose-convolution module.

    Invoking the module forwards to :func:`mlx_lattice.ops.conv_transpose3d`.
    The returned tensor's support is the one generated by the transposed
    relation.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel extent, interpreted by :class:`KernelSpec`.
        stride: Stride, interpreted by :class:`KernelSpec`.
        padding: Padding, interpreted by :class:`KernelSpec`.
        dilation: Kernel dilation, interpreted by :class:`KernelSpec`.
        bias: When true, a ``bias`` parameter is created.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        _init_kernel_params(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            spec=self.spec,
            bias=bias,
        )

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Run the transposed convolution on ``x``."""
        spec = self.spec
        return conv_transpose3d(
            x,
            self.weight,
            _optional_bias(self),
            kernel_size=spec.size,
            stride=spec.stride,
            padding=spec.padding,
            dilation=spec.dilation,
        )

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        *,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedConvTranspose3d:
        """Build a :class:`QuantizedConvTranspose3d` from this module.

        Args:
            group_size: Forwarded as-is to
                ``QuantizedConvTranspose3d.from_conv``.
            bits: Bits per weight; ``None`` selects 4.
            mode: Only ``'affine'`` is accepted.
            quantize_input: Only ``False`` is accepted.

        Raises:
            ValueError: If ``mode`` or ``quantize_input`` is unsupported.
        """
        from mlx_lattice.nn.quantized_conv import QuantizedConvTranspose3d

        _validate_quantize_request(mode, quantize_input)
        if bits is None:
            bits = 4
        return QuantizedConvTranspose3d.from_conv(
            self, group_size=group_size, bits=bits
        )
class GenerativeConvTranspose3d(mxnn.Module):
    """Generative sparse 3D transpose-convolution module.

    Invoking the module forwards to
    :func:`mlx_lattice.ops.generative_conv_transpose3d`, which generates output
    coordinates from the input rows and the stride. Only ``size`` and
    ``stride`` are passed to :class:`KernelSpec`; any other geometry uses its
    defaults.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Kernel extent, interpreted by :class:`KernelSpec`.
        stride: Stride, interpreted by :class:`KernelSpec`.
        bias: When true, a ``bias`` parameter is created.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.spec = KernelSpec(size=kernel_size, stride=stride)
        _init_kernel_params(
            self,
            in_channels=in_channels,
            out_channels=out_channels,
            spec=self.spec,
            bias=bias,
        )

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Run the generative transposed convolution on ``x``."""
        spec = self.spec
        return generative_conv_transpose3d(
            x,
            self.weight,
            _optional_bias(self),
            kernel_size=spec.size,
            stride=spec.stride,
        )

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        *,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedGenerativeConvTranspose3d:
        """Build a :class:`QuantizedGenerativeConvTranspose3d` from this module.

        Args:
            group_size: Forwarded as-is to the quantized class's ``from_conv``.
            bits: Bits per weight; ``None`` selects 4.
            mode: Only ``'affine'`` is accepted.
            quantize_input: Only ``False`` is accepted.

        Raises:
            ValueError: If ``mode`` or ``quantize_input`` is unsupported.
        """
        from mlx_lattice.nn.quantized_conv import (
            QuantizedGenerativeConvTranspose3d,
        )

        _validate_quantize_request(mode, quantize_input)
        if bits is None:
            bits = 4
        return QuantizedGenerativeConvTranspose3d.from_conv(
            self, group_size=group_size, bits=bits
        )
def _init_kernel_params(
module: mxnn.Module,
in_channels: int,
out_channels: int,
spec: KernelSpec,
bias: bool,
) -> None:
scale = math.sqrt(1.0 / (in_channels * spec.volume))
module.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(out_channels, *spec.size, in_channels),
)
if bias:
module.bias = mx.zeros((out_channels,))
def _optional_bias(module: mxnn.Module) -> mx.array | None:
return module.bias if 'bias' in module else None
def _validate_quantize_request(mode: str, quantize_input: bool) -> None:
if mode != 'affine':
raise ValueError(
'quantized sparse convolution supports affine mode.'
)
if quantize_input:
raise ValueError(
'quantized sparse convolution uses floating-point activations.'
)