Feature modules
Feature modules are sparse-aware wrappers around row-local or channel-local
mlx.nn-style operations. They replace x.feats with the transformed features and
preserve sparse coordinate identity.
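The wrapping pattern can be sketched in plain Python. The SparseTensor class and apply_row_op helper below are hypothetical stand-ins, not the mlx_lattice API. A tensor is modeled as a tuple of coordinates plus one feature row per coordinate. A feature module transforms only the rows and hands the same coordinates object to the result. The example op is a bias-free 2-to-1 linear projection, the kind of row-local operation behind Linear.

```python
from dataclasses import dataclass, replace
from typing import Callable, Tuple

Coord = Tuple[int, ...]
Row = Tuple[float, ...]


@dataclass(frozen=True)
class SparseTensor:
    """Hypothetical stand-in for a sparse tensor: coords[i] owns feats[i]."""
    coords: Tuple[Coord, ...]
    feats: Tuple[Row, ...]


def apply_row_op(x: SparseTensor, op: Callable[[Row], Row]) -> SparseTensor:
    """Hypothetical helper: transform each feature row independently.

    Only `feats` is replaced; `coords` is passed through as the same object,
    so the output shares the input's coordinate identity.
    """
    return replace(x, feats=tuple(op(row) for row in x.feats))


def project(row: Row, weight=((0.5, -1.0),)) -> Row:
    # Row-local linear projection y = W x (no bias), W of shape (out=1, in=2).
    return tuple(sum(w * v for w, v in zip(w_row, row)) for w_row in weight)


x = SparseTensor(coords=((0, 0), (2, 3)), feats=((2.0, 1.0), (4.0, -2.0)))
y = apply_row_op(x, project)
print(y.feats)               # ((0.0,), (4.0,))
print(y.coords is x.coords)  # True
```

Because the op only ever sees one row at a time, the number of rows and their pairing with coordinates cannot change.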
Module summary
| Module family | Examples | Coordinate effect |
|---|---|---|
| Projection | Linear | Preserves coordinate identity. |
| Activations | GELU, LeakyReLU, ReLU, SiLU, Sigmoid | Preserves coordinate identity. |
| Normalization | BatchNorm, LayerNorm, RMSNorm | Preserves coordinate identity. |
| Regularization | Dropout | Preserves coordinate identity. |
- class mlx_lattice.nn.feature.GELU(approx='none')
Bases: GELU
Sparse-feature GELU module preserving sparse coordinates.
- class mlx_lattice.nn.feature.BatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Bases: BatchNorm
Sparse-feature batch normalization module.
Statistics are computed per channel over the rows of x.feats, so only active
sparse coordinates contribute. Running-statistics behavior follows the fields of
the underlying mlx.nn.BatchNorm.
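A minimal sketch of row-based batch statistics in plain Python, assuming training-mode normalization with the biased variance and omitting the affine parameters and running-stat updates. The name batch_norm_rows is hypothetical. Each channel's mean and variance are taken over the N active rows only. A dense layout of the same data would also average in the empty sites and give different statistics.

```python
import math
from typing import List, Sequence


def batch_norm_rows(feats: Sequence[Sequence[float]], eps: float = 1e-5) -> List[List[float]]:
    """Normalize each channel using statistics over the active rows only."""
    n = len(feats)
    c = len(feats[0])
    # Per-channel mean and biased variance across the N sparse rows.
    means = [sum(row[j] for row in feats) / n for j in range(c)]
    variances = [sum((row[j] - means[j]) ** 2 for row in feats) / n for j in range(c)]
    return [
        [(row[j] - means[j]) / math.sqrt(variances[j] + eps) for j in range(c)]
        for row in feats
    ]


# Three active sites with two channels; empty sites never enter the statistics.
feats = [[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]]
out = batch_norm_rows(feats)
print(out)
```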
- class mlx_lattice.nn.feature.Dropout(p=0.5)
Bases: Dropout
Sparse-feature dropout module preserving sparse coordinates.
- Parameters:
p (float) – Dropout probability. Default: 0.5.
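Assuming the elementwise, inverted-dropout behavior of the mlx.nn.Dropout base class, the sketch below zeros individual feature entries rather than whole sites, so every active coordinate keeps a (possibly all-zero) row. The name dropout_feats is hypothetical, and the training flag stands in for the module's train/eval mode.

```python
import random
from typing import List, Optional, Sequence


def dropout_feats(
    feats: Sequence[Sequence[float]],
    p: float = 0.5,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """Inverted dropout on feature entries; the number of rows never changes."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return [list(row) for row in feats]  # identity outside training
    rng = rng or random.Random()
    keep = 1.0 - p
    # Zero each entry with probability p and rescale survivors by 1 / (1 - p).
    return [[v / keep if rng.random() < keep else 0.0 for v in row] for row in feats]


feats = [[1.0, 2.0], [3.0, 4.0]]
out = dropout_feats(feats, p=0.5, rng=random.Random(0))
print(len(out) == len(feats))  # True: one output row per active coordinate
```

Rescaling the kept entries keeps the expected value of each feature unchanged, which is why evaluation mode can simply pass the features through.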
- class mlx_lattice.nn.feature.LayerNorm(dims, eps=1e-05, affine=True, bias=True)
Bases: LayerNorm
Sparse-feature layer normalization module preserving sparse coordinates.
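Layer normalization is row-local: each active site's feature vector is normalized by its own mean and variance over the channel dimension, so the result for one coordinate never depends on other rows. A plain-Python sketch follows. The name layer_norm_row is hypothetical, and the learned weight and bias appear as optional arguments (identity when omitted).

```python
import math
from typing import List, Optional, Sequence


def layer_norm_row(
    row: Sequence[float],
    eps: float = 1e-5,
    weight: Optional[Sequence[float]] = None,
    bias: Optional[Sequence[float]] = None,
) -> List[float]:
    """Normalize one feature row over its channels, then apply the optional affine."""
    c = len(row)
    mean = sum(row) / c
    var = sum((v - mean) ** 2 for v in row) / c
    out = [(v - mean) / math.sqrt(var + eps) for v in row]
    if weight is not None:
        out = [o * w for o, w in zip(out, weight)]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


# Each sparse row is normalized independently of the others.
feats = [[1.0, 2.0, 3.0], [10.0, 10.0, 40.0]]
normed = [layer_norm_row(r) for r in feats]
print(normed)
```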
- class mlx_lattice.nn.feature.LeakyReLU(negative_slope=0.01)
Bases: LeakyReLU
Sparse-feature leaky ReLU module preserving sparse coordinates.
- class mlx_lattice.nn.feature.Linear(input_dims, output_dims, bias=True)
Bases: Linear
Sparse-feature linear projection module.
This is the sparse analogue of mlx.nn.Linear. It applies the dense projection
to x.feats and preserves sparse coordinate identity.
- to_quantized(group_size=None, bits=None, mode='affine', quantize_input=False)
Return a quantized approximation of this layer.
If quantize_input is False, returns a QuantizedLinear (weights are quantized).
If quantize_input is True, returns a QQLinear (weights and activations are quantized).
- Parameters:
group_size (int | None) – The quantization group size (see mlx.core.quantize()). Default: None.
bits (int | None) – The number of bits per parameter (see mlx.core.quantize()). Default: None.
mode (str) – The quantization method to use (see mlx.core.quantize()). Default: "affine".
quantize_input (bool) – Whether to also quantize the input activations. Default: False.
- Returns:
A quantized version of this layer.
- Return type:
Notes
Quantized input is only supported for "nvfp4" and "mxfp8" modes.
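To make group_size and bits concrete, here is a generic min-max affine group quantizer in plain Python. It is an illustrative sketch, not the algorithm implemented by mlx.core.quantize(), whose exact scale, bias, and rounding rules may differ. The function names are hypothetical. Each group of group_size consecutive weights gets its own scale and bias, each weight is stored as an integer code in [0, 2**bits - 1], and dequantization reconstructs code * scale + bias.

```python
from typing import List, Sequence, Tuple


def affine_quantize(
    w: Sequence[float], group_size: int, bits: int
) -> Tuple[List[int], List[Tuple[float, float]]]:
    """Quantize a flat weight vector group by group; returns codes and (scale, bias) pairs."""
    if len(w) % group_size != 0:
        raise ValueError("len(w) must be a multiple of group_size")
    levels = 2 ** bits - 1
    codes: List[int] = []
    params: List[Tuple[float, float]] = []
    for start in range(0, len(w), group_size):
        group = w[start:start + group_size]
        lo, hi = min(group), max(group)
        scale = (hi - lo) / levels if hi > lo else 1.0  # constant group: any scale works
        params.append((scale, lo))
        # Round to the nearest code and clamp into the representable range.
        codes.extend(min(levels, max(0, round((v - lo) / scale))) for v in group)
    return codes, params


def affine_dequantize(
    codes: Sequence[int], params: Sequence[Tuple[float, float]], group_size: int
) -> List[float]:
    """Reconstruct approximate weights as code * scale + bias."""
    return [q * params[i // group_size][0] + params[i // group_size][1] for i, q in enumerate(codes)]


w = [0.0, 0.1, 0.2, 0.3, -1.0, 1.0, 0.5, 0.0]
codes, params = affine_quantize(w, group_size=4, bits=4)
w_hat = affine_dequantize(codes, params, group_size=4)
print(max(abs(a - b) for a, b in zip(w, w_hat)))
```

In this scheme the rounding error per weight is at most half of its group's scale, so smaller groups or more bits tighten the approximation at the cost of storing more (scale, bias) pairs or wider codes.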
- class mlx_lattice.nn.feature.RMSNorm(dims, eps=1e-05)
Bases: RMSNorm
Sparse-feature RMS normalization module preserving sparse coordinates.
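RMS normalization is also row-local but skips mean subtraction: each row is divided by the root mean square of its channels. A minimal sketch, assuming the usual formulation y = x / sqrt(mean(x**2) + eps) * weight with a per-channel weight. The name rms_norm_row is hypothetical.

```python
import math
from typing import List, Optional, Sequence


def rms_norm_row(
    row: Sequence[float], eps: float = 1e-5, weight: Optional[Sequence[float]] = None
) -> List[float]:
    """Scale one feature row by the reciprocal of its root mean square."""
    rms = math.sqrt(sum(v * v for v in row) / len(row) + eps)
    w = weight if weight is not None else [1.0] * len(row)
    return [v / rms * wi for v, wi in zip(row, w)]


out = rms_norm_row([3.0, 4.0])
print(out)
```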
- class mlx_lattice.nn.feature.ReLU
Bases: ReLU
Sparse-feature ReLU module preserving sparse coordinates.
- class mlx_lattice.nn.feature.SiLU
Bases: SiLU
Sparse-feature SiLU module preserving sparse coordinates.
- class mlx_lattice.nn.feature.Sigmoid
Bases: Sigmoid
Sparse-feature sigmoid module preserving sparse coordinates.
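The activation modules act elementwise on x.feats, so each reduces to a scalar function applied to every feature entry while the coordinates pass through unchanged. The plain-Python sketch below uses the standard definitions. For GELU it shows only the exact approx='none' form, based on the standard normal CDF via math.erf. The function names are hypothetical.

```python
import math


def relu(v: float) -> float:
    return max(v, 0.0)


def leaky_relu(v: float, negative_slope: float = 0.01) -> float:
    # Negative inputs are scaled by negative_slope instead of zeroed.
    return v if v >= 0.0 else negative_slope * v


def sigmoid(v: float) -> float:
    # Branching on the sign avoids overflow in math.exp for large |v|.
    if v >= 0.0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def silu(v: float) -> float:
    # SiLU (swish): v * sigmoid(v).
    return v * sigmoid(v)


def gelu(v: float) -> float:
    # Exact GELU: v * Phi(v), with Phi the standard normal CDF.
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


feats = [[-2.0, 0.5], [1.0, -0.1]]
for f in (relu, leaky_relu, sigmoid, silu, gelu):
    print(f.__name__, [[f(v) for v in row] for row in feats])
```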