Feature modules

Feature modules are sparse-aware wrappers around row-local or channel-local mlx.nn-style operations. Each module replaces x.feats with the transformed features and preserves sparse coordinate identity.
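
The contract described above can be sketched in plain Python. This is an illustration only, not the mlx_lattice implementation: the SparseRows container and the apply_row_local helper are hypothetical names, and the sketch assumes nothing more than that a sparse tensor pairs one feature row with each coordinate.

```python
from dataclasses import dataclass, replace
from typing import Callable, Sequence, Tuple


@dataclass(frozen=True)
class SparseRows:
    """Hypothetical stand-in for a sparse tensor: one feature row per coordinate."""

    coords: Tuple[Tuple[int, ...], ...]
    feats: Tuple[Tuple[float, ...], ...]


def apply_row_local(x: SparseRows, fn: Callable[[Sequence[float]], Sequence[float]]) -> SparseRows:
    """Apply fn to every feature row independently and leave coords untouched."""
    new_feats = tuple(tuple(fn(row)) for row in x.feats)
    # Only feats is swapped out; the coords object is carried over as-is.
    return replace(x, feats=new_feats)


def relu_row(row: Sequence[float]) -> list:
    """Row-local ReLU used as the example operation."""
    return [max(0.0, v) for v in row]


x = SparseRows(coords=((0, 1, 2), (0, 4, 4)), feats=((-1.0, 2.0), (3.0, -4.0)))
y = apply_row_local(x, relu_row)
```

Because the operation never looks across rows, the row count and row order of the output match the input, which is what lets the coordinates be reused unchanged.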

Module summary

Module family    Examples                           Coordinate effect
-------------    -------------------------------    ------------------------------
Projection       Linear                             Preserves coordinate identity.
Activations      ReLU, GELU, SiLU, Sigmoid, Tanh    Preserves coordinate identity.
Normalization    BatchNorm, LayerNorm, RMSNorm      Preserves coordinate identity.
Regularization   Dropout                            Preserves coordinate identity.

class mlx_lattice.nn.feature.GELU(approx='none')

Bases: GELU

Sparse-feature GELU module preserving sparse coordinates.
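
For reference, the element-wise function involved can be written out in plain Python. This is a sketch of the exact GELU form, x * Phi(x), which is what approx='none' selects in mlx.nn.GELU; the helper names gelu_exact and gelu_rows are hypothetical and the nested lists stand in for x.feats.

```python
import math


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_rows(feats):
    """Apply GELU element-wise; the number and order of rows are unchanged."""
    return [[gelu_exact(v) for v in row] for row in feats]


feats = [[-1.0, 0.0], [1.0, 2.0]]
out = gelu_rows(feats)
```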

class mlx_lattice.nn.feature.BatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Bases: BatchNorm

Sparse-feature batch normalization module.

Statistics are computed over sparse feature rows. Running-stat behavior follows the underlying mlx.nn.BatchNorm fields.
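
To make "statistics over sparse feature rows" concrete, here is a minimal plain-Python sketch of one training-mode step. It assumes the per-channel mean and biased variance are taken over the N active rows only, and that running statistics follow the usual exponential moving average with momentum; the function name is hypothetical and the affine scale and shift are omitted.

```python
def batch_norm_train_step(feats, running_mean, running_var, eps=1e-5, momentum=0.1):
    """Normalize each channel over the N active rows and update running stats."""
    n, channels = len(feats), len(feats[0])
    # Per-channel statistics over active rows only; empty sites never contribute.
    mean = [sum(row[c] for row in feats) / n for c in range(channels)]
    var = [sum((row[c] - mean[c]) ** 2 for row in feats) / n for c in range(channels)]
    out = [
        [(row[c] - mean[c]) / (var[c] + eps) ** 0.5 for c in range(channels)]
        for row in feats
    ]
    # Assumed update rule: new = (1 - momentum) * old + momentum * batch_stat.
    new_mean = [(1 - momentum) * r + momentum * m for r, m in zip(running_mean, mean)]
    new_var = [(1 - momentum) * r + momentum * v for r, v in zip(running_var, var)]
    return out, new_mean, new_var


feats = [[1.0, 10.0], [3.0, 10.0]]
out, run_mean, run_var = batch_norm_train_step(feats, [0.0, 0.0], [1.0, 1.0])
```

Unlike the row-local modules, this operation is channel-local: each output row depends on every active row in the batch, but the rows themselves are neither added, removed, nor reordered.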

Parameters:

num_features (int)

eps (float)

momentum (float)

affine (bool)

track_running_stats (bool)

class mlx_lattice.nn.feature.Dropout(p=0.5)

Bases: Dropout

Sparse-feature dropout module preserving sparse coordinates.

Parameters:

p (float)
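
A plain-Python sketch of what dropout does to the feature rows, assuming the standard inverted-dropout convention (each element is zeroed with probability p and survivors are scaled by 1 / (1 - p)). The dropout_rows helper and the explicit rng argument are hypothetical and exist only to make the example reproducible.

```python
import random


def dropout_rows(feats, p=0.5, rng=None):
    """Zero each feature with probability p and scale survivors by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if rng is None:
        rng = random.Random()
    keep = 1.0 - p
    return [[v / keep if rng.random() < keep else 0.0 for v in row] for row in feats]


coords = [(0, 0, 0), (0, 1, 2), (1, 3, 3)]
feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
dropped = dropout_rows(feats, p=0.5, rng=random.Random(0))
```

Individual feature values are zeroed, never whole rows removed, so a row whose features all happen to drop still keeps its coordinate.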

class mlx_lattice.nn.feature.LayerNorm(dims, eps=1e-05, affine=True, bias=True)

Bases: LayerNorm

Sparse-feature layer normalization module preserving sparse coordinates.
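
Layer normalization is row-local: each feature row is normalized using only its own mean and variance. The following plain-Python sketch shows that computation under the usual definition; the helper name is hypothetical and the learnable scale and bias are left out.

```python
def layer_norm_row(row, eps=1e-5):
    """Normalize one feature row to zero mean and (approximately) unit variance."""
    mean = sum(row) / len(row)
    var = sum((v - mean) ** 2 for v in row) / len(row)
    return [(v - mean) / (var + eps) ** 0.5 for v in row]


feats = [[1.0, 3.0], [10.0, 10.0]]
normed = [layer_norm_row(row) for row in feats]
```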

Parameters:

dims (int)

eps (float)

affine (bool)

bias (bool)

class mlx_lattice.nn.feature.LeakyReLU(negative_slope=0.01)

Bases: LeakyReLU

Sparse-feature leaky ReLU module preserving sparse coordinates.

class mlx_lattice.nn.feature.Linear(input_dims, output_dims, bias=True)

Bases: Linear

Sparse-feature linear projection module.

This is the sparse analogue of mlx.nn.Linear. It applies the dense projection to x.feats and preserves sparse coordinate identity.
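
The projection itself is the same row-wise computation as a dense linear layer, y = x W^T + b, applied to each feature row. A plain-Python sketch follows, assuming the mlx.nn.Linear weight layout of (output_dims, input_dims); linear_rows is a hypothetical helper, not part of the library.

```python
def linear_rows(feats, weight, bias=None):
    """Project every feature row: y = x @ W^T + b.

    weight has shape (output_dims, input_dims); bias has length output_dims.
    """
    out = []
    for row in feats:
        y = [sum(w * v for w, v in zip(w_row, row)) for w_row in weight]
        if bias is not None:
            y = [a + b for a, b in zip(y, bias)]
        out.append(y)
    return out


feats = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # N = 2 rows, input_dims = 3
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]  # output_dims = 2
bias = [0.5, -0.5]
projected = linear_rows(feats, weight, bias)
```

Only the channel width changes (here from 3 to 2); the number of rows, and therefore the set of coordinates, stays the same.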

Parameters:

input_dims (int)

output_dims (int)

bias (bool)

to_quantized(group_size=None, bits=None, mode='affine', quantize_input=False)

Return a quantized approximation of this layer.

If quantize_input is False, returns a QuantizedLinear (weights are quantized). If quantize_input is True, returns a QQLinear (weights and activations are quantized).

Parameters:
  • group_size (int | None) – The quantization group size (see mlx.core.quantize()). Default: None.

  • bits (int | None) – The number of bits per parameter (see mlx.core.quantize()). Default: None.

  • mode (str) – The quantization method to use (see mlx.core.quantize()). Default: "affine".

  • quantize_input (bool) – Whether to quantize input. Default: False.

Returns:

A quantized version of this layer.

Return type:

QuantizedLinear or QQLinear

Notes

Quantized input is only supported for "nvfp4" and "mxfp8" modes.
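
To give a feel for what group_size and bits control, here is a simplified plain-Python sketch of group-wise affine quantization. It is an assumption-laden illustration rather than the mlx.core.quantize() algorithm: all helper names are hypothetical, and the real routine differs in details such as how scales and biases are chosen and how codes are packed.

```python
def quantize_group(values, bits):
    """Map one group of weights to integer codes in [0, 2**bits - 1]."""
    levels = (1 << bits) - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_group(codes, scale, offset):
    """Reconstruct approximate weights from codes, scale, and offset."""
    return [c * scale + offset for c in codes]


def quantize_row(row, group_size, bits):
    """Quantize a weight row in consecutive groups of group_size elements."""
    if len(row) % group_size:
        raise ValueError("row length must be a multiple of group_size")
    return [
        quantize_group(row[i:i + group_size], bits)
        for i in range(0, len(row), group_size)
    ]


row = [0.0, 1.0, 2.0, 3.0, -1.0, -0.5, 0.5, 1.0]
groups = quantize_row(row, group_size=4, bits=2)
restored = [v for codes, scale, offset in groups for v in dequantize_group(codes, scale, offset)]
```

Smaller groups track local weight ranges more closely at the cost of storing more scales and offsets; more bits reduce the rounding error within each group.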

class mlx_lattice.nn.feature.RMSNorm(dims, eps=1e-05)

Bases: RMSNorm

Sparse-feature RMS normalization module preserving sparse coordinates.
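
RMS normalization is also row-local: each feature row is divided by its own root mean square, with no mean subtraction. A plain-Python sketch under the usual definition, with a hypothetical helper name and the learnable weight omitted:

```python
def rms_norm_row(row, eps=1e-5):
    """Scale one feature row so that its mean square is approximately 1."""
    mean_square = sum(v * v for v in row) / len(row)
    return [v / (mean_square + eps) ** 0.5 for v in row]


feats = [[3.0, 4.0], [0.0, 0.0]]
normed = [rms_norm_row(row) for row in feats]
```

The eps term keeps the all-zero row well defined; it simply maps to zeros.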

Parameters:

dims (int)

eps (float)

class mlx_lattice.nn.feature.ReLU

Bases: ReLU

Sparse-feature ReLU module preserving sparse coordinates.

class mlx_lattice.nn.feature.SiLU

Bases: SiLU

Sparse-feature SiLU module preserving sparse coordinates.

class mlx_lattice.nn.feature.Sigmoid

Bases: Sigmoid

Sparse-feature sigmoid module preserving sparse coordinates.

class mlx_lattice.nn.feature.Softplus

Bases: Softplus

Sparse-feature softplus module preserving sparse coordinates.

class mlx_lattice.nn.feature.Tanh

Bases: Tanh

Sparse-feature tanh module preserving sparse coordinates.