Source code for mlx_lattice.ops.entropy
from __future__ import annotations
import mlx.core as mx
from mlx_lattice import _ext
[docs]
def normalized_cdf(prob: mx.array) -> mx.array:
    """Build int16 normalized CDF rows from a table of probabilities.

    Parameters
    ----------
    prob:
        Two-dimensional probability table holding one distribution per row.

    Returns
    -------
    mx.array
        Normalized CDF rows computed by the native extension, in the form
        expected by the range-coding helpers of this module.

    Raises
    ------
    ValueError
        If ``prob`` is not two-dimensional.
    """
    rank = prob.ndim
    if rank == 2:
        return _ext.normalized_cdf(prob)
    raise ValueError('prob must be a two-dimensional array.')
[docs]
def range_encode(cdf: mx.array, symbols: mx.array) -> bytes:
    """Range-encode ``symbols`` against normalized CDF rows.

    Parameters
    ----------
    cdf:
        Two-dimensional table of normalized CDF rows, such as the output of
        ``normalized_cdf``.
    symbols:
        One-dimensional array of symbols. It is cast to ``int32`` before
        being handed to the native encoder.

    Returns
    -------
    bytes
        The encoded byte stream.

    Raises
    ------
    ValueError
        If ``cdf`` is not two-dimensional, or if ``symbols`` is not
        one-dimensional. ``cdf`` is checked first.
    """
    # Checked in order: the first failing rank requirement is reported.
    rank_checks = (
        (cdf, 2, 'cdf must be a two-dimensional array.'),
        (symbols, 1, 'symbols must be a one-dimensional array.'),
    )
    for arr, expected_rank, message in rank_checks:
        if arr.ndim != expected_rank:
            raise ValueError(message)

    symbols_i32 = symbols.astype(mx.int32)
    return _ext.range_encode(cdf, symbols_i32)
[docs]
def range_decode(cdf: mx.array, stream: bytes) -> mx.array:
    """Recover symbols from a range-coded byte stream.

    Parameters
    ----------
    cdf:
        Two-dimensional table of normalized CDF rows used for decoding.
    stream:
        Byte stream produced by a range encoder.

    Returns
    -------
    mx.array
        The decoded symbols, as returned by the native decoder.

    Raises
    ------
    ValueError
        If ``cdf`` is not two-dimensional.
    """
    ndim = cdf.ndim
    if ndim != 2:
        raise ValueError('cdf must be a two-dimensional array.')
    decoded = _ext.range_decode(cdf, stream)
    return decoded
[docs]
def range_encode_from_prob(prob: mx.array, symbols: mx.array) -> bytes:
    """Range-encode ``symbols`` directly from probability rows.

    The normalization of ``prob`` and the encoding both happen inside a
    single native call.

    Parameters
    ----------
    prob:
        Two-dimensional probability table, one distribution per row.
    symbols:
        One-dimensional array of symbols. It is cast to ``int32`` before
        encoding.

    Returns
    -------
    bytes
        The encoded byte stream.

    Raises
    ------
    ValueError
        If ``prob`` is not two-dimensional, or if ``symbols`` is not
        one-dimensional. ``prob`` is checked first.
    """
    prob_ok = prob.ndim == 2
    if not prob_ok:
        raise ValueError('prob must be a two-dimensional array.')
    symbols_ok = symbols.ndim == 1
    if not symbols_ok:
        raise ValueError('symbols must be a one-dimensional array.')

    int_symbols = symbols.astype(mx.int32)
    return _ext.range_encode_from_prob(prob, int_symbols)
[docs]
def range_decode_from_prob(prob: mx.array, stream: bytes) -> mx.array:
    """Range-decode ``stream`` directly from probability rows.

    The normalization of ``prob`` and the decoding both happen inside a
    single native call.

    Parameters
    ----------
    prob:
        Two-dimensional probability table, one distribution per row.
    stream:
        Range-coded byte stream to decode.

    Returns
    -------
    mx.array
        The decoded symbols, as returned by the native decoder.

    Raises
    ------
    ValueError
        If ``prob`` is not two-dimensional.
    """
    if prob.ndim == 2:
        return _ext.range_decode_from_prob(prob, stream)
    raise ValueError('prob must be a two-dimensional array.')
[docs]
def rans_encode_from_prob(prob: mx.array, symbols: mx.array) -> bytes:
    """Encode ``symbols`` with byte-oriented rANS using probability rows.

    Parameters
    ----------
    prob:
        Two-dimensional probability table, one distribution per row.
    symbols:
        One-dimensional array of symbols. It is cast to ``int32`` before
        encoding.

    Returns
    -------
    bytes
        The rANS-encoded byte stream. It is presumably meant to be decoded
        with ``rans_decode_from_prob`` and the same ``prob``; confirm this
        against the extension.

    Raises
    ------
    ValueError
        If ``prob`` is not two-dimensional, or if ``symbols`` is not
        one-dimensional. ``prob`` is checked first.
    """
    # Each entry is (array, required rank, error message); order matters.
    requirements = [
        (prob, 2, 'prob must be a two-dimensional array.'),
        (symbols, 1, 'symbols must be a one-dimensional array.'),
    ]
    for candidate, rank, error_text in requirements:
        if candidate.ndim != rank:
            raise ValueError(error_text)

    encoded_input = symbols.astype(mx.int32)
    return _ext.rans_encode_from_prob(prob, encoded_input)
[docs]
def rans_decode_from_prob(prob: mx.array, stream: bytes) -> mx.array:
    """Decode a byte-oriented rANS stream using probability rows.

    Parameters
    ----------
    prob:
        Two-dimensional probability table, one distribution per row.
    stream:
        rANS-encoded byte stream to decode.

    Returns
    -------
    mx.array
        The decoded symbols, as returned by the native decoder.

    Raises
    ------
    ValueError
        If ``prob`` is not two-dimensional.
    """
    table_rank = prob.ndim
    if table_rank != 2:
        raise ValueError('prob must be a two-dimensional array.')
    symbols_out = _ext.rans_decode_from_prob(prob, stream)
    return symbols_out