from __future__ import annotations
from typing import TYPE_CHECKING, cast
import mlx.core as mx
import mlx.nn as mxnn
from mlx_lattice.core import SparseTensor
from mlx_lattice.ops import feature as F
from mlx_lattice.ops.feature import GeluApprox
# Public names exported by this module (what ``from <module> import *``
# exposes). Every entry is a sparse-feature module class defined in this file.
__all__ = [
'GELU',
'BatchNorm',
'Dropout',
'LayerNorm',
'LeakyReLU',
'Linear',
'RMSNorm',
'ReLU',
'SiLU',
'Sigmoid',
'Softplus',
'Tanh',
]
if TYPE_CHECKING:
from mlx_lattice.nn.quantized_feature import QuantizedLinear
[docs]
class Linear(mxnn.Linear):
    """Linear projection over the features of a sparse tensor.

    Sparse counterpart of ``mlx.nn.Linear``: the dense projection is applied
    to ``x.feats`` and the sparse coordinate identity is preserved.
    """

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Project ``x`` with this layer's weight and, when present, its bias."""
        weight = self.weight
        # The bias is optional: pass ``None`` when the module has no 'bias' entry.
        bias = self.bias if 'bias' in self else None
        return F.linear(x, weight, bias)

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedLinear:
        """Build a ``QuantizedLinear`` from this layer.

        Args:
            group_size: Forwarded unchanged to ``QuantizedLinear.from_linear``.
            bits: Forwarded to ``QuantizedLinear.from_linear``; ``None`` is
                replaced by ``4``.
            mode: Forwarded unchanged to ``QuantizedLinear.from_linear``.
            quantize_input: Must be falsy; input quantization is rejected.

        Returns:
            The result of ``QuantizedLinear.from_linear(self, ...)``.

        Raises:
            ValueError: If ``quantize_input`` is truthy.
        """
        # Imported at call time; the module-level import of this name is
        # guarded by TYPE_CHECKING and is not available at runtime.
        from mlx_lattice.nn.quantized_feature import QuantizedLinear

        if quantize_input:
            raise ValueError(
                'affine sparse QuantizedLinear uses floating-point activations.'
            )

        resolved_bits = bits if bits is not None else 4
        return QuantizedLinear.from_linear(
            self,
            group_size=group_size,
            bits=resolved_bits,
            mode=mode,
        )
[docs]
class ReLU(mxnn.ReLU):
    """ReLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.relu`` applied to ``x``."""
        activated = F.relu(x)
        return activated
[docs]
class Sigmoid(mxnn.Sigmoid):
    """Sigmoid activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.sigmoid`` applied to ``x``."""
        activated = F.sigmoid(x)
        return activated
[docs]
class GELU(mxnn.GELU):
    """GELU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.gelu`` to ``x`` using this module's approximation setting."""
        # ``_approx`` is read from the parent module; the cast only narrows
        # its static type for the type checker and has no runtime effect.
        approximation = cast('GeluApprox', self._approx)
        return F.gelu(x, approximate=approximation)
[docs]
class SiLU(mxnn.SiLU):
    """SiLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.silu`` applied to ``x``."""
        activated = F.silu(x)
        return activated
[docs]
class LeakyReLU(mxnn.LeakyReLU):
    """Leaky ReLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.leaky_relu`` to ``x`` with this module's negative slope."""
        # ``_negative_slope`` is the value held by the parent module.
        slope = self._negative_slope
        return F.leaky_relu(x, negative_slope=slope)
[docs]
class Tanh(mxnn.Tanh):
    """Tanh activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.tanh`` applied to ``x``."""
        activated = F.tanh(x)
        return activated
[docs]
class Softplus(mxnn.Softplus):
    """Softplus activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.softplus`` applied to ``x``."""
        activated = F.softplus(x)
        return activated
[docs]
class Dropout(mxnn.Dropout):
    """Dropout for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.dropout`` to ``x``, honouring the module's training flag."""
        # The probability handed to F.dropout is ``1 - self._p_1``.
        # NOTE(review): ``_p_1`` is presumably the keep probability stored by
        # ``mlx.nn.Dropout`` -- confirm against the installed MLX version.
        drop_probability = 1 - self._p_1
        return F.dropout(x, p=drop_probability, training=self.training)
[docs]
class BatchNorm(mxnn.BatchNorm):
    """Batch normalization over the features of a sparse tensor.

    Statistics are taken across sparse feature rows (axis 0 of ``x.feats``).
    Running statistics are read from and written to the fields inherited from
    ``mlx.nn.BatchNorm``.
    """

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Normalize ``x`` with batch or running statistics.

        Batch statistics are used unless running statistics are tracked and
        the module is not training, in which case the stored running mean and
        variance are used. While training with tracking enabled, the running
        statistics are updated from the batch statistics using ``momentum``.
        """
        batch_mean = mx.mean(x.feats, axis=0)
        batch_var = mx.var(x.feats, axis=0)

        if not self.track_running_stats:
            # No running statistics: always normalize with the batch values.
            mean, var = batch_mean, batch_var
        elif self.training:
            # Blend the batch statistics into the running ones, then
            # normalize with the batch values.
            mu = self.momentum
            self.running_mean = (1 - mu) * self.running_mean + mu * batch_mean
            self.running_var = (1 - mu) * self.running_var + mu * batch_var
            mean, var = batch_mean, batch_var
        else:
            # Not training with tracking enabled: use the stored statistics.
            mean, var = self.running_mean, self.running_var

        # The affine parameters are optional entries of the module.
        scale = self.weight if 'weight' in self else None
        shift = self.bias if 'bias' in self else None
        return F.batch_norm(
            x,
            weight=scale,
            bias=shift,
            mean=mean,
            var=var,
            eps=self.eps,
        )
[docs]
class LayerNorm(mxnn.LayerNorm):
    """Layer normalization for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.layer_norm`` to ``x`` with the module's optional affine terms."""
        # Weight and bias are only passed when the module actually holds them.
        scale = self.weight if 'weight' in self else None
        shift = self.bias if 'bias' in self else None
        return F.layer_norm(x, weight=scale, bias=shift, eps=self.eps)
[docs]
class RMSNorm(mxnn.RMSNorm):
    """RMS normalization for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.rms_norm`` to ``x`` with this module's weight and epsilon."""
        scale = self.weight
        epsilon = self.eps
        return F.rms_norm(x, weight=scale, eps=epsilon)