from __future__ import annotations

from collections.abc import Sequence
from typing import Literal, get_args

import mlx.core as mx
import mlx.nn as mxnn

from mlx_lattice.core import KernelSpec, SparseTensor
from mlx_lattice.ops import (
    avg_pool3d,
    global_avg_pool,
    global_max_pool,
    global_sum_pool,
    max_pool3d,
    pool3d,
    sum_pool3d,
)
# Reduction names accepted by ``Pool3d(mode=...)``; presumably mirrors the
# modes understood by ``pool3d`` — keep the two in sync.
PoolMode = Literal['sum', 'max', 'avg']
# Public API of this module, kept in alphabetical order.
__all__ = [
    # Local (windowed) and global pooling modules.
    'AvgPool3d',
    'GlobalAvgPool',
    'GlobalMaxPool',
    'GlobalSumPool',
    'MaxPool3d',
    'Pool3d',
    'SumPool3d',
]
class Pool3d(mxnn.Module):
    """Configurable local sparse 3D pooling module.

    ``mode`` selects ``sum``, ``max``, or ``avg`` reduction over a sparse kernel
    relation. The module returns a sparse tensor with output stride multiplied
    by the pooling stride.

    Args:
        mode: Reduction applied within each kernel window: ``'sum'``,
            ``'max'`` or ``'avg'``.
        kernel_size: Kernel extent as a single int or a sequence of ints
            (presumably one per spatial axis).
        stride: Pooling stride, as an int or a sequence of ints.
        padding: Padding, as an int or a sequence of ints.
        dilation: Kernel dilation, as an int or a sequence of ints.

    Raises:
        ValueError: If ``mode`` is not one of the values listed in
            ``PoolMode``.
    """

    def __init__(
        self,
        *,
        mode: PoolMode = 'sum',
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> None:
        super().__init__()
        # Reject unknown reductions at construction time so a typo surfaces
        # immediately rather than on the first forward pass.
        supported = get_args(PoolMode)
        if mode not in supported:
            raise ValueError(
                f"mode must be one of {supported}, got {mode!r}"
            )
        self.mode = mode
        # Kernel geometry is bundled into a KernelSpec; __call__ reads it back
        # through the spec's size/stride/padding/dilation attributes.
        self.spec = KernelSpec(
            size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

    def _extra_repr(self) -> str:
        """Describe the pooling configuration for ``mlx.nn.Module.__repr__``."""
        spec = self.spec
        return (
            f"mode={self.mode!r}, kernel_size={spec.size}, "
            f"stride={spec.stride}, padding={spec.padding}, "
            f"dilation={spec.dilation}"
        )

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Pool ``x`` using the configured reduction and kernel geometry."""
        return pool3d(
            x,
            mode=self.mode,
            kernel_size=self.spec.size,
            stride=self.spec.stride,
            padding=self.spec.padding,
            dilation=self.spec.dilation,
        )
class SumPool3d(Pool3d):
    """Sparse 3D pooling that sums features within each kernel window.

    Configured like ``Pool3d(mode='sum', ...)``; the forward pass dispatches
    directly to ``sum_pool3d``.
    """

    def __init__(
        self,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> None:
        geometry = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'dilation': dilation,
        }
        super().__init__(mode='sum', **geometry)

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Sum-pool ``x`` using the stored kernel geometry."""
        spec = self.spec
        return sum_pool3d(
            x,
            kernel_size=spec.size,
            stride=spec.stride,
            padding=spec.padding,
            dilation=spec.dilation,
        )
class MaxPool3d(Pool3d):
    """Sparse 3D pooling that keeps the maximum feature in each kernel window.

    Configured like ``Pool3d(mode='max', ...)``; the forward pass dispatches
    directly to ``max_pool3d``.
    """

    def __init__(
        self,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> None:
        geometry = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'dilation': dilation,
        }
        super().__init__(mode='max', **geometry)

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Max-pool ``x`` using the stored kernel geometry."""
        spec = self.spec
        return max_pool3d(
            x,
            kernel_size=spec.size,
            stride=spec.stride,
            padding=spec.padding,
            dilation=spec.dilation,
        )
class AvgPool3d(Pool3d):
    """Sparse 3D pooling that averages features within each kernel window.

    Configured like ``Pool3d(mode='avg', ...)``; the forward pass dispatches
    directly to ``avg_pool3d``.
    """

    def __init__(
        self,
        *,
        kernel_size: int | Sequence[int] = 2,
        stride: int | Sequence[int] = 2,
        padding: int | Sequence[int] = 0,
        dilation: int | Sequence[int] = 1,
    ) -> None:
        geometry = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'dilation': dilation,
        }
        super().__init__(mode='avg', **geometry)

    def __call__(self, x: SparseTensor) -> SparseTensor:
        """Average-pool ``x`` using the stored kernel geometry."""
        spec = self.spec
        return avg_pool3d(
            x,
            kernel_size=spec.size,
            stride=spec.stride,
            padding=spec.padding,
            dilation=spec.dilation,
        )
class GlobalSumPool(mxnn.Module):
    """Batch-wise global sum pooling that yields dense ``(B, C)`` rows.

    Stateless wrapper around ``global_sum_pool``.
    """

    def __call__(self, x: SparseTensor) -> mx.array:
        """Collapse ``x`` to one summed feature row per batch entry."""
        pooled = global_sum_pool(x)
        return pooled
class GlobalAvgPool(mxnn.Module):
    """Batch-wise global average pooling that yields dense ``(B, C)`` rows.

    Stateless wrapper around ``global_avg_pool``.
    """

    def __call__(self, x: SparseTensor) -> mx.array:
        """Collapse ``x`` to one averaged feature row per batch entry."""
        pooled = global_avg_pool(x)
        return pooled
class GlobalMaxPool(mxnn.Module):
    """Batch-wise global max pooling that yields dense ``(B, C)`` rows.

    Stateless wrapper around ``global_max_pool``.
    """

    def __call__(self, x: SparseTensor) -> mx.array:
        """Collapse ``x`` to one max-reduced feature row per batch entry."""
        pooled = global_max_pool(x)
        return pooled