from __future__ import annotations
from typing import Literal
import mlx.core as mx
from mlx_lattice.core import QuantizedWeight, SparseTensor
from mlx_lattice.ops._quantized import quantized_matmul
# Spellings accepted by the ``approximate`` keyword of ``gelu``:
# ``'none'`` and ``'precise'`` both select the erf formula, ``'tanh'`` the
# tanh approximation, and ``'fast'`` the sigmoid-based approximation.
GeluApprox = Literal['none', 'precise', 'tanh', 'fast']
# Public functional API of this module, listed alphabetically. Helpers whose
# names start with an underscore are intentionally left out.
__all__ = [
'batch_norm',
'dropout',
'gelu',
'layer_norm',
'leaky_relu',
'linear',
'relu',
'rms_norm',
'sigmoid',
'silu',
'softplus',
'tanh',
]
[docs]
def linear(
    x: SparseTensor,
    weight: mx.array | QuantizedWeight,
    bias: mx.array | None = None,
) -> SparseTensor:
    """Project sparse features through a dense or quantized weight.

    The coordinates of ``x`` are kept; only the feature matrix is replaced.

    Args:
        x: Sparse tensor whose ``feats`` are projected.
        weight: Either a dense array of shape ``(C_out, C_in)`` or a
            ``QuantizedWeight`` packed with ``linear`` layout.
        bias: Optional vector of shape ``(C_out,)`` added after projection.

    Returns:
        ``x.replace(feats=...)`` carrying the projected features.

    Raises:
        ValueError: If a dense ``weight`` is not 2-D, if its input channel
            count differs from ``x.channels``, or if ``bias`` does not match
            the output channel count.
    """
    if not isinstance(weight, QuantizedWeight):
        # Dense path: validate the shape before multiplying.
        _require_2d_weight(weight)
        in_channels = weight.shape[1]
        if in_channels != x.channels:
            raise ValueError('weight input channels must match x.channels.')
        projected = x.feats @ weight.T
    else:
        # Packed path: the quantized kernel consumes the weight as-is.
        projected = quantized_matmul(x.feats, weight)
    return x.replace(feats=_with_bias(projected, bias))
[docs]
def relu(x: SparseTensor) -> SparseTensor:
    """Clamp negative sparse features to zero, keeping coordinates as-is."""
    rectified = mx.maximum(x.feats, 0)
    return x.replace(feats=rectified)
[docs]
def sigmoid(x: SparseTensor) -> SparseTensor:
    """Squash sparse features with the logistic sigmoid, keeping coordinates."""
    squashed = mx.sigmoid(x.feats)
    return x.replace(feats=squashed)
[docs]
def gelu(
    x: SparseTensor,
    *,
    approximate: GeluApprox = 'none',
) -> SparseTensor:
    """Apply GELU to the features of ``x`` without touching coordinates.

    Args:
        x: Sparse tensor whose ``feats`` are activated.
        approximate: ``'none'`` or ``'precise'`` selects the erf formula,
            ``'tanh'`` the tanh approximation, and ``'fast'`` the
            sigmoid-based approximation.

    Returns:
        ``x.replace(feats=...)`` carrying the activated features.

    Raises:
        ValueError: If ``approximate`` is not one of the accepted values.
    """
    # Reject unknown modes up front so the branches below are exhaustive.
    if approximate not in ('none', 'precise', 'tanh', 'fast'):
        raise ValueError(
            "approximate must be 'none', 'precise', 'tanh', or 'fast'."
        )

    values = x.feats
    dtype = values.dtype

    if approximate == 'fast':
        activated = values * mx.sigmoid(1.702 * values)
    elif approximate == 'tanh':
        cubic_coeff = mx.array(0.044715, dtype=dtype)
        sqrt_two_over_pi = mx.array(0.7978845608028654, dtype=dtype)
        inner = sqrt_two_over_pi * (values + cubic_coeff * values**3)
        activated = 0.5 * values * (1 + mx.tanh(inner))
    else:
        # 'none' and 'precise' are aliases for the exact erf formulation.
        half = mx.array(0.5, dtype=dtype)
        inv_sqrt_two = mx.array(0.7071067811865476, dtype=dtype)
        activated = half * values * (1 + mx.erf(values * inv_sqrt_two))

    return x.replace(feats=activated)
[docs]
def silu(x: SparseTensor) -> SparseTensor:
    """Gate sparse features by their own sigmoid (SiLU/Swish), keeping coordinates."""
    values = x.feats
    gated = values * mx.sigmoid(values)
    return x.replace(feats=gated)
[docs]
def leaky_relu(
    x: SparseTensor,
    *,
    negative_slope: float = 0.01,
) -> SparseTensor:
    """Apply leaky ReLU to the features of ``x`` without touching coordinates.

    Args:
        x: Sparse tensor whose ``feats`` are activated.
        negative_slope: Multiplier applied to negative feature values.

    Returns:
        ``x.replace(feats=...)`` where non-negative values pass through and
        negative values are scaled by ``negative_slope``.
    """
    values = x.feats
    # Build the slope in the feature dtype so the product keeps that dtype.
    slope_arr = mx.array(float(negative_slope), dtype=values.dtype)
    non_negative = values >= 0
    activated = mx.where(non_negative, values, values * slope_arr)
    return x.replace(feats=activated)
[docs]
def tanh(x: SparseTensor) -> SparseTensor:
    """Map sparse features through hyperbolic tangent, keeping coordinates."""
    squashed = mx.tanh(x.feats)
    return x.replace(feats=squashed)
[docs]
def softplus(
    x: SparseTensor,
    *,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> SparseTensor:
    """Apply numerically thresholded softplus to sparse features.

    Each feature value ``v`` becomes ``log(1 + exp(beta * v)) / beta``. Where
    ``beta * v`` exceeds ``threshold`` the input value is returned directly.
    Coordinates are carried over through ``x.replace``.

    Args:
        x: Sparse tensor whose ``feats`` are transformed.
        beta: Positive sharpness factor applied before the exponential.
        threshold: Cut-off in the scaled domain above which the function is
            treated as the identity.

    Returns:
        ``x.replace(feats=...)`` carrying the transformed features.

    Raises:
        ValueError: If ``beta`` is not positive.
    """
    if beta <= 0:
        raise ValueError('beta must be positive.')
    scaled = x.feats * beta
    # ``mx.where`` receives both branches fully evaluated, so the exponent
    # input is clamped to ``threshold``: the branch that is discarded for
    # large inputs then stays finite instead of overflowing to inf.
    bounded = mx.minimum(scaled, threshold)
    # log1p keeps precision when exp(bounded) is tiny (very negative inputs),
    # where ``1 + exp(...)`` would round to exactly 1.
    smooth = mx.log1p(mx.exp(bounded)) / beta
    return x.replace(feats=mx.where(scaled > threshold, x.feats, smooth))
[docs]
def dropout(
    x: SparseTensor,
    *,
    p: float = 0.5,
    training: bool = True,
) -> SparseTensor:
    """Apply inverted dropout to the features of ``x`` during training.

    Coordinates are preserved. Kept entries are divided by ``1 - p``. With
    ``training`` false or ``p`` equal to zero the features pass through
    unchanged inside a new sparse wrapper.

    Args:
        x: Sparse tensor whose ``feats`` are masked.
        p: Drop probability; must satisfy ``0 <= p < 1``.
        training: Whether dropout is active.

    Returns:
        ``x.replace(feats=...)`` carrying the (possibly masked) features.

    Raises:
        ValueError: If ``p`` is negative or at least 1.
    """
    if p < 0 or p >= 1:
        raise ValueError('p must satisfy 0 <= p < 1.')
    feats = x.feats
    if training and p != 0:
        keep_prob = 1.0 - p
        keep_mask = mx.random.bernoulli(p=keep_prob, shape=feats.shape)
        # Rescale survivors by 1 / keep_prob (inverted dropout).
        feats = feats * keep_mask.astype(feats.dtype) / keep_prob
    return x.replace(feats=feats)
[docs]
def batch_norm(
    x: SparseTensor,
    *,
    weight: mx.array | None = None,
    bias: mx.array | None = None,
    mean: mx.array | None = None,
    var: mx.array | None = None,
    eps: float = 1e-5,
) -> SparseTensor:
    """Normalize sparse features per channel, with optional affine terms.

    Any of ``mean`` / ``var`` left as ``None`` is computed over axis 0 of
    ``x.feats`` (the active feature rows).

    Args:
        x: Sparse tensor whose ``feats`` are normalized.
        weight: Optional per-channel scale of shape ``(C,)``.
        bias: Optional per-channel shift of shape ``(C,)``.
        mean: Optional per-channel mean of shape ``(C,)``.
        var: Optional per-channel variance of shape ``(C,)``.
        eps: Positive constant added to the variance before ``rsqrt``.

    Returns:
        ``x.replace(feats=...)`` carrying the normalized features.

    Raises:
        ValueError: If ``eps`` is not positive or any supplied vector does
            not have shape ``(C,)``.
    """
    if eps <= 0:
        raise ValueError('eps must be positive.')

    feats = x.feats
    channels = x.channels

    # Fall back to statistics of the current features when not supplied.
    if mean is None:
        mean = mx.mean(feats, axis=0)
    if var is None:
        var = mx.var(feats, axis=0)

    _require_channel_vector(mean, channels, 'mean')
    _require_channel_vector(var, channels, 'var')

    normalized = (feats - mean) * mx.rsqrt(var + eps)
    return x.replace(feats=_affine(normalized, weight=weight, bias=bias))
[docs]
def layer_norm(
    x: SparseTensor,
    *,
    weight: mx.array | None = None,
    bias: mx.array | None = None,
    eps: float = 1e-5,
) -> SparseTensor:
    """Layer-normalize every sparse row on its own.

    Args:
        x: Sparse tensor whose ``feats`` are normalized.
        weight: Optional scale vector with ``x.channels`` entries.
        bias: Optional shift vector with ``x.channels`` entries.
        eps: Positive stabilizer forwarded to ``mx.fast.layer_norm``.

    Returns:
        ``x.replace(feats=...)`` carrying the normalized features.

    Raises:
        ValueError: If ``eps`` is not positive or a supplied vector does not
            have shape ``(x.channels,)``.
    """
    if eps <= 0:
        raise ValueError('eps must be positive.')
    # Validate affine parameters in the order weight, then bias.
    for label, param in (('weight', weight), ('bias', bias)):
        if param is not None:
            _require_channel_vector(param, x.channels, label)
    normalized = mx.fast.layer_norm(x.feats, weight, bias, eps)
    return x.replace(feats=normalized)
[docs]
def rms_norm(
    x: SparseTensor,
    *,
    weight: mx.array | None = None,
    eps: float = 1e-5,
) -> SparseTensor:
    """RMS-normalize every sparse row on its own.

    Args:
        x: Sparse tensor whose ``feats`` are normalized.
        weight: Optional scale vector with ``x.channels`` entries.
        eps: Positive stabilizer forwarded to ``mx.fast.rms_norm``.

    Returns:
        ``x.replace(feats=...)`` carrying the normalized features.

    Raises:
        ValueError: If ``eps`` is not positive or ``weight`` does not have
            shape ``(x.channels,)``.
    """
    if eps <= 0:
        raise ValueError('eps must be positive.')
    has_weight = weight is not None
    if has_weight:
        _require_channel_vector(weight, x.channels, 'weight')
    normalized = mx.fast.rms_norm(x.feats, weight, eps)
    return x.replace(feats=normalized)
# MARK: - helpers
def _affine(
    feats: mx.array,
    *,
    weight: mx.array | None,
    bias: mx.array | None,
) -> mx.array:
    """Scale ``feats`` by ``weight`` and shift by ``bias`` when provided.

    ``weight`` is checked against ``feats.shape[1]`` before multiplying; the
    bias step (including its own shape check) is delegated to ``_with_bias``.
    """
    if weight is None:
        return _with_bias(feats, bias)
    channel_count = int(feats.shape[1])
    _require_channel_vector(weight, channel_count, 'weight')
    return _with_bias(feats * weight, bias)
def _with_bias(feats: mx.array, bias: mx.array | None) -> mx.array:
    """Return ``feats + bias``, or ``feats`` itself when ``bias`` is ``None``.

    A supplied ``bias`` must be a vector with ``feats.shape[1]`` entries;
    otherwise ``_require_channel_vector`` raises ``ValueError``.
    """
    if bias is not None:
        channel_count = int(feats.shape[1])
        _require_channel_vector(bias, channel_count, 'bias')
        return feats + bias
    return feats
def _require_2d_weight(weight: mx.array) -> None:
    """Raise ``ValueError`` unless ``weight`` has exactly two dimensions."""
    if weight.ndim == 2:
        return
    raise ValueError('weight must have shape (C_out, C_in).')
def _require_channel_vector(
    value: mx.array,
    channels: int,
    name: str,
) -> None:
    """Raise ``ValueError`` unless ``value`` is 1-D with ``channels`` entries.

    ``name`` is interpolated into the error message to identify the argument.
    """
    is_vector = value.ndim == 1
    # Short-circuit: the length is only inspected for 1-D inputs.
    if is_vector and value.shape[0] == channels:
        return
    raise ValueError(f'{name} must have shape ({channels},).')