Source code for mlx_lattice.nn.feature

from __future__ import annotations

from typing import TYPE_CHECKING, cast

import mlx.core as mx
import mlx.nn as mxnn

from mlx_lattice.core import SparseTensor
from mlx_lattice.ops import feature as F
from mlx_lattice.ops.feature import GeluApprox

# Public API of this module: the sparse-feature layer classes defined below.
__all__ = [
    'GELU',
    'BatchNorm',
    'Dropout',
    'LayerNorm',
    'LeakyReLU',
    'Linear',
    'RMSNorm',
    'ReLU',
    'SiLU',
    'Sigmoid',
    'Softplus',
    'Tanh',
]

# Imported for annotations only; ``Linear.to_quantized`` imports the class
# again at call time, so nothing is loaded from that module at import time here.
if TYPE_CHECKING:
    from mlx_lattice.nn.quantized_feature import QuantizedLinear


class Linear(mxnn.Linear):
    """Linear projection over the features of a sparse tensor.

    Sparse counterpart of ``mlx.nn.Linear``: the dense projection is applied
    to ``x.feats`` while the sparse coordinate identity is preserved.
    """

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Project ``x`` with this module's weight and optional bias.

        ``None`` is passed as the bias when the module holds no ``bias``
        entry.
        """
        if 'bias' in self:
            bias = self.bias
        else:
            bias = None
        return F.linear(x, self.weight, bias)

    def to_quantized(
        self,
        group_size: int | None = None,
        bits: int | None = None,
        mode: str = 'affine',
        quantize_input: bool = False,
    ) -> QuantizedLinear:
        """Build a ``QuantizedLinear`` from this layer.

        Args:
            group_size: Forwarded unchanged to ``QuantizedLinear.from_linear``.
            bits: Bit width to forward; ``4`` is used when ``None``.
            mode: Quantization mode, forwarded unchanged.
            quantize_input: Must be false; input quantization is rejected.

        Returns:
            Whatever ``QuantizedLinear.from_linear`` returns for this layer.

        Raises:
            ValueError: If ``quantize_input`` is true.
        """
        # Imported at call time; the module-level import of this name is
        # guarded by ``TYPE_CHECKING`` and is not available at runtime.
        from mlx_lattice.nn.quantized_feature import QuantizedLinear

        if quantize_input:
            raise ValueError(
                'affine sparse QuantizedLinear uses floating-point activations.'
            )

        resolved_bits = bits if bits is not None else 4
        return QuantizedLinear.from_linear(
            self,
            group_size=group_size,
            bits=resolved_bits,
            mode=mode,
        )
class ReLU(mxnn.ReLU):
    """ReLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.relu`` applied to ``x``."""
        activated = F.relu(x)
        return activated
class Sigmoid(mxnn.Sigmoid):
    """Sigmoid activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.sigmoid`` applied to ``x``."""
        activated = F.sigmoid(x)
        return activated
class GELU(mxnn.GELU):
    """GELU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.gelu`` using the approximation held in ``self._approx``."""
        # ``cast`` only narrows the static type; the runtime value is untouched.
        approximation = cast('GeluApprox', self._approx)
        return F.gelu(x, approximate=approximation)
class SiLU(mxnn.SiLU):
    """SiLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.silu`` applied to ``x``."""
        activated = F.silu(x)
        return activated
class LeakyReLU(mxnn.LeakyReLU):
    """Leaky ReLU activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.leaky_relu`` with the slope held in ``self._negative_slope``."""
        slope = self._negative_slope
        return F.leaky_relu(x, negative_slope=slope)
class Tanh(mxnn.Tanh):
    """Tanh activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.tanh`` applied to ``x``."""
        activated = F.tanh(x)
        return activated
class Softplus(mxnn.Softplus):
    """Softplus activation for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Return the result of ``F.softplus`` applied to ``x``."""
        activated = F.softplus(x)
        return activated
class Dropout(mxnn.Dropout):
    """Dropout for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.dropout`` using this module's probability and training flag."""
        # NOTE(review): ``_p_1`` is presumably the keep probability stored by
        # ``mlx.nn.Dropout``, making this the drop probability; confirm upstream.
        drop_probability = 1 - self._p_1
        return F.dropout(x, p=drop_probability, training=self.training)
class BatchNorm(mxnn.BatchNorm):
    """Batch normalization over the features of a sparse tensor.

    Statistics are computed over sparse feature rows. Running-stat behavior
    follows the underlying ``mlx.nn.BatchNorm`` fields.
    """

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Normalize ``x`` with batch or running statistics.

        Mean and variance are taken over axis 0 of ``x.feats``. When
        ``track_running_stats`` is set, training mode blends those batch
        statistics into ``running_mean``/``running_var`` using ``momentum``
        and still normalizes with the batch values; non-training mode
        normalizes with the stored running values instead.
        """
        feats = x.feats
        norm_mean = mx.mean(feats, axis=0)
        norm_var = mx.var(feats, axis=0)

        if self.track_running_stats:
            if self.training:
                momentum = self.momentum
                self.running_mean = (
                    (1 - momentum) * self.running_mean + momentum * norm_mean
                )
                self.running_var = (
                    (1 - momentum) * self.running_var + momentum * norm_var
                )
            else:
                norm_mean = self.running_mean
                norm_var = self.running_var

        # Scale and shift are optional; pass ``None`` for whichever is absent.
        scale = self.weight if 'weight' in self else None
        shift = self.bias if 'bias' in self else None
        return F.batch_norm(
            x,
            weight=scale,
            bias=shift,
            mean=norm_mean,
            var=norm_var,
            eps=self.eps,
        )
class LayerNorm(mxnn.LayerNorm):
    """Layer normalization for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.layer_norm`` with this module's optional weight and bias.

        ``None`` is passed for ``weight`` or ``bias`` when the module holds
        no entry of that name.
        """
        scale = None
        if 'weight' in self:
            scale = self.weight

        shift = None
        if 'bias' in self:
            shift = self.bias

        return F.layer_norm(x, weight=scale, bias=shift, eps=self.eps)
class RMSNorm(mxnn.RMSNorm):
    """RMS normalization for sparse tensors; sparse coordinates are preserved."""

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply ``F.rms_norm`` with this module's ``weight`` and ``eps``."""
        scale = self.weight
        epsilon = self.eps
        return F.rms_norm(x, weight=scale, eps=epsilon)