Source code for mlx_lattice.nn.quantized_feature

from __future__ import annotations

import mlx.core as mx
import mlx.nn as mxnn

from mlx_lattice.core import QuantizedWeight, SparseTensor, quantize_weight
from mlx_lattice.ops import feature as F

# Names exported by ``from <this module> import *``; only the module class is public.
__all__ = ['QuantizedLinear']


class QuantizedLinear(mxnn.Module):
    """Sparse-feature linear layer with affine weight quantization.

    Only the weight is quantized: the module keeps the packed int4/int8
    weight together with its per-group ``scales`` and ``biases`` arrays,
    while activations stay in floating point. Applying the module to a
    ``SparseTensor`` leaves coordinate identity untouched and projects
    ``x.feats`` to ``output_dims`` channels.

    Args:
        input_dims: Number of input feature channels.
        output_dims: Number of output feature channels.
        bias: When true, keep the additive bias of the freshly initialized
            floating-point layer.
        group_size: Quantization group size forwarded to
            ``quantize_weight``; ``None`` is passed through unchanged.
        bits: Quantization bit width forwarded to ``quantize_weight``.
        mode: Quantization mode; anything other than ``'affine'`` raises
            ``ValueError``.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int | None = None,
        bits: int = 4,
        mode: str = 'affine',
    ) -> None:
        """Initialize a floating-point ``Linear`` and store it quantized."""
        super().__init__()
        _require_affine(mode)

        # NOTE(review): imported locally rather than at module top --
        # presumably to avoid a circular import; confirm before hoisting.
        from mlx_lattice.nn.feature import Linear

        # The dense layer only supplies initial values; it is not retained.
        dense = Linear(input_dims, output_dims, bias=bias)
        self._assign_quantized(dense.weight, group_size, bits)
        if bias:
            self.bias = dense.bias

        self.freeze()

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Apply the quantized projection to the features of ``x``."""
        packed = self._quantized_weight()
        # ``bias`` exists as a module entry only when the layer was built
        # with one, so membership is tested before the attribute is read.
        if 'bias' in self:
            offset = self.bias
        else:
            offset = None
        return F.linear(x, packed, offset)

    @classmethod
    def from_linear(
        cls,
        linear: mxnn.Module,
        group_size: int | None = None,
        bits: int = 4,
        mode: str = 'affine',
    ) -> QuantizedLinear:
        """Build a quantized layer from an existing linear module.

        Args:
            linear: Module exposing a two-dimensional ``weight`` laid out as
                ``(output_dims, input_dims)`` and, optionally, a ``bias``.
            group_size: Quantization group size forwarded to
                ``quantize_weight``.
            bits: Quantization bit width forwarded to ``quantize_weight``.
            mode: Quantization mode; must be ``'affine'``.

        Returns:
            A frozen ``QuantizedLinear`` carrying the quantized form of
            ``linear.weight`` and, when present, ``linear.bias``.

        Raises:
            ValueError: If ``mode`` is not ``'affine'``, or if
                ``linear.weight.shape`` does not unpack into two dimensions.
        """
        _require_affine(mode)

        has_bias = 'bias' in linear
        n_out, n_in = (int(dim) for dim in linear.weight.shape)

        layer = cls(
            n_in,
            n_out,
            bias=has_bias,
            group_size=group_size,
            bits=bits,
            mode=mode,
        )

        # The constructor quantized freshly initialized weights; replace
        # them (and the bias) with the values taken from ``linear``.
        layer._assign_quantized(linear.weight, group_size, bits)
        if has_bias:
            layer.bias = linear.bias

        layer.freeze()
        return layer

    def _assign_quantized(
        self,
        weight: mx.array,
        group_size: int | None,
        bits: int,
    ) -> None:
        """Quantize ``weight`` and store the result on this module."""
        packed = quantize_weight(weight, group_size=group_size, bits=bits)

        # The packed arrays keep the attribute names used by the result.
        for field in ('weight', 'scales', 'biases'):
            setattr(self, field, getattr(packed, field))

        # Metadata needed to rebuild a ``QuantizedWeight`` at call time.
        self.group_size = packed.group_size
        self.bits = packed.bits
        self.input_dims = packed.in_channels
        self.output_dims = packed.out_channels

    def _quantized_weight(self) -> QuantizedWeight:
        """Reassemble the stored fields into a ``QuantizedWeight``."""
        arrays = (self.weight, self.scales, self.biases)
        layout = (
            self.group_size,
            self.bits,
            self.input_dims,
            self.output_dims,
        )
        # NOTE(review): the trailing ``(1, 1, 1)`` and ``'linear'`` are
        # presumably a unit kernel size and an op tag -- confirm against
        # the ``QuantizedWeight`` definition.
        return QuantizedWeight(*arrays, *layout, (1, 1, 1), 'linear')
def _require_affine(mode: str) -> None:
    """Reject every quantization mode except ``'affine'``.

    Args:
        mode: Requested quantization mode.

    Raises:
        ValueError: If ``mode`` is not equal to ``'affine'``.
    """
    if mode == 'affine':
        return
    raise ValueError('QuantizedLinear currently supports affine mode.')