Source code for mlx_lattice.core.quantized

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal

import mlx.core as mx

from mlx_lattice.core.types import Triple

QuantizedWeightLayout = Literal['linear', 'kernel_major', 'dense_5d']

__all__ = [
    'QuantizedWeight',
    'QuantizedWeightLayout',
    'dequantize_weight',
    'quantize_weight',
]


[docs] @dataclass(frozen=True, slots=True) class QuantizedWeight: """Packed affine INT4/INT8 inference weight. The object stores packed ``uint32`` integer codes plus per-group affine ``scales`` and ``biases``. Logical values are reconstructed as ``scale * code + bias`` by quantized linear and convolution paths. ``layout`` records the logical source shape: ``linear`` for ``(C_out, C_in)``, ``kernel_major`` for ``(K, C_in, C_out)``, and ``dense_5d`` for ``(C_out, Kx, Ky, Kz, C_in)``. """ weight: mx.array scales: mx.array biases: mx.array group_size: int bits: int in_channels: int out_channels: int kernel_size: Triple layout: QuantizedWeightLayout def __post_init__(self) -> None: if self.bits not in (4, 8): raise ValueError('quantized weight bits must be 4 or 8.') if self.group_size not in (32, 64, 128): raise ValueError( 'quantized weight group_size must be 32, 64, or 128.' ) if self.weight.dtype != mx.uint32 or self.weight.ndim != 3: raise ValueError( 'packed weight must be a three-dimensional uint32 array.' ) if self.scales.dtype not in (mx.float16, mx.float32): raise ValueError('quantized scales must be float16 or float32.') if self.biases.dtype != self.scales.dtype: raise ValueError('quantized biases must match scales dtype.') if self.scales.ndim != 3 or self.biases.shape != self.scales.shape: raise ValueError( 'scales and biases must have shape (K, C_out, G).' ) if self.in_channels <= 0 or self.out_channels <= 0: raise ValueError('quantized weight channels must be positive.') if any(size <= 0 for size in self.kernel_size): raise ValueError( 'quantized weight kernel dimensions must be positive.' ) kernel_rows = _volume(self.kernel_size) if self.weight.shape[0] != kernel_rows: raise ValueError( 'packed weight kernel rows do not match kernel_size.' ) if self.is_pointwise: if self.weight.shape[1] != self.out_channels: raise ValueError( 'pointwise packed weight output channels do not match.' ) if self.scales.shape[:2] != self.weight.shape[:2]: raise ValueError( 'pointwise weight and quantization rows must match.' ) else: if self.weight.shape[2] != self.out_channels: raise ValueError( 'spatial packed weight output channels do not match.' ) if self.scales.shape[0] != self.weight.shape[0] or ( self.scales.shape[2] != self.out_channels ): raise ValueError( 'spatial weight and quantization rows must match.' ) if self.storage_in_channels % self.group_size != 0: raise ValueError( 'packed storage channels must be divisible by group_size.' ) group_dim = 2 if self.is_pointwise else 1 if ( self.scales.shape[group_dim] != self.storage_in_channels // self.group_size ): raise ValueError( 'quantization group count does not match storage.' ) if self.storage_in_channels < self.in_channels: raise ValueError( 'packed storage channels cannot be smaller than logical channels.' ) @property def storage_in_channels(self) -> int: packed_dim = 2 if self.is_pointwise else 1 return int(self.weight.shape[packed_dim]) * 32 // self.bits @property def is_pointwise(self) -> bool: return _volume(self.kernel_size) == 1 @property def nbytes(self) -> int: return self.weight.nbytes + self.scales.nbytes + self.biases.nbytes
[docs] def quantize_weight( weight: mx.array, *, group_size: int | None = None, bits: int = 4, ) -> QuantizedWeight: """Pack a linear or sparse-convolution weight for inference. Args: weight: Floating ``float16`` or ``float32`` weight. Accepted shapes are ``(C_out, C_in)``, ``(K, C_in, C_out)``, or ``(C_out, Kx, Ky, Kz, C_in)``. group_size: Quantization group size. ``None`` chooses ``64`` for ``C_in >= 64`` and ``32`` otherwise. bits: Packed integer width, either ``4`` or ``8``. Returns: A ``QuantizedWeight`` containing packed storage and affine metadata. Input channels are padded in storage to the selected group size when needed; logical ``in_channels`` remains the original channel count. """ if weight.dtype not in (mx.float16, mx.float32): raise ValueError( 'weight quantization supports float16 and float32.' ) if bits not in (4, 8): raise ValueError('weight quantization supports only 4 or 8 bits.') rows, layout, kernel_size, in_channels, out_channels = _weight_rows( weight ) group_size = _resolve_group_size(in_channels, group_size) storage_in_channels = _round_up(in_channels, group_size) if storage_in_channels != in_channels: rows = mx.pad( rows, [(0, 0), (0, storage_in_channels - in_channels)] ) packed, scales, biases = mx.quantize( rows, group_size=group_size, bits=bits, mode='affine', ) kernel_rows = _volume(kernel_size) packed = packed.reshape((kernel_rows, out_channels, -1)) scales = scales.reshape((kernel_rows, out_channels, -1)) biases = biases.reshape((kernel_rows, out_channels, -1)) if kernel_rows > 1: packed = mx.contiguous(packed.transpose(0, 2, 1)) scales = mx.contiguous(scales.transpose(0, 2, 1)) biases = mx.contiguous(biases.transpose(0, 2, 1)) return QuantizedWeight( weight=packed, scales=scales, biases=biases, group_size=group_size, bits=bits, in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, layout=layout, )
[docs] def dequantize_weight(weight: QuantizedWeight) -> mx.array: """Restore the logical floating-point weight represented by ``weight``. The returned array uses the original logical layout recorded by ``weight.layout`` and slices away any padded storage channels. """ kernel_rows = _volume(weight.kernel_size) packed = weight.weight scales = weight.scales biases = weight.biases if not weight.is_pointwise: packed = packed.transpose(0, 2, 1) scales = scales.transpose(0, 2, 1) biases = biases.transpose(0, 2, 1) rows = mx.dequantize( packed.reshape((kernel_rows * weight.out_channels, -1)), scales.reshape((kernel_rows * weight.out_channels, -1)), biases.reshape((kernel_rows * weight.out_channels, -1)), group_size=weight.group_size, bits=weight.bits, mode='affine', )[:, : weight.in_channels] if weight.layout == 'linear': return rows.reshape((weight.out_channels, weight.in_channels)) mapped = rows.reshape( (kernel_rows, weight.out_channels, weight.in_channels) ) if weight.layout == 'kernel_major': return mapped.transpose(0, 2, 1) return mapped.reshape( (*weight.kernel_size, weight.out_channels, weight.in_channels) ).transpose(3, 0, 1, 2, 4)
def _weight_rows( weight: mx.array, ) -> tuple[mx.array, QuantizedWeightLayout, Triple, int, int]: if weight.ndim == 2: out_channels, in_channels = map(int, weight.shape) return ( weight, 'linear', (1, 1, 1), in_channels, out_channels, ) if weight.ndim == 3: kernel_rows, in_channels, out_channels = map(int, weight.shape) return ( weight.transpose(0, 2, 1).reshape( (kernel_rows * out_channels, in_channels) ), 'kernel_major', (kernel_rows, 1, 1), in_channels, out_channels, ) if weight.ndim == 5: out_channels, kx, ky, kz, in_channels = map(int, weight.shape) return ( weight.transpose(1, 2, 3, 0, 4).reshape( (kx * ky * kz * out_channels, in_channels) ), 'dense_5d', (kx, ky, kz), in_channels, out_channels, ) raise ValueError( 'weight must have shape (C_out, C_in), (K, C_in, C_out), ' 'or (C_out, Kx, Ky, Kz, C_in).' ) def _resolve_group_size(in_channels: int, group_size: int | None) -> int: if group_size is None: return 64 if in_channels >= 64 else 32 if group_size not in (32, 64, 128): raise ValueError('group_size must be 32, 64, or 128.') return group_size def _round_up(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple def _volume(size: Triple) -> int: return size[0] * size[1] * size[2]