Source code for mlx_lattice.ops._quantized
from __future__ import annotations
import mlx.core as mx
from mlx_lattice.core.quantized import QuantizedWeight
[docs]
def quantized_matmul(
feats: mx.array,
weight: QuantizedWeight,
) -> mx.array:
"""Apply affine packed-weight matrix multiplication to feature rows."""
if weight.kernel_size != (1, 1, 1):
raise ValueError('quantized matmul requires a pointwise weight.')
if feats.ndim != 2 or feats.shape[1] != weight.in_channels:
raise ValueError(
'features must have shape (N, quantized_weight.in_channels).'
)
if feats.dtype != weight.scales.dtype:
raise ValueError('features must match quantized scale dtype.')
if weight.storage_in_channels != weight.in_channels:
feats = mx.pad(
feats,
[
(0, 0),
(0, weight.storage_in_channels - weight.in_channels),
],
)
return mx.quantized_matmul(
feats,
weight.weight[0],
weight.scales[0],
weight.biases[0],
transpose=True,
group_size=weight.group_size,
bits=weight.bits,
mode='affine',
)