Pooling modules

Pooling modules wrap two kinds of reduction: local pooling over a sparse kernel relation, and batch-wise global pooling. Local pooling returns sparse tensors; global pooling returns dense arrays with one row per batch item.

Module summary

Module                                         Reduction                        Output type
Pool3d                                         Configurable sum, max, or avg.   SparseTensor
SumPool3d / MaxPool3d / AvgPool3d              Fixed local reduction mode.      SparseTensor
GlobalSumPool / GlobalAvgPool / GlobalMaxPool  Batch-wise dense reduction.      MLX array

class mlx_lattice.nn.pool.AvgPool3d(*, kernel_size=2, stride=2, padding=0, dilation=1)

Bases: Pool3d

Local sparse average-pooling module.

Parameters:
  • kernel_size (int | Sequence[int])

  • stride (int | Sequence[int])

  • padding (int | Sequence[int])

  • dilation (int | Sequence[int])

class mlx_lattice.nn.pool.GlobalAvgPool

Bases: Module

Batch-wise global average-pooling module returning a dense (B, C) array with one row per batch item.

class mlx_lattice.nn.pool.GlobalMaxPool

Bases: Module

Batch-wise global max-pooling module returning a dense (B, C) array with one row per batch item.

class mlx_lattice.nn.pool.GlobalSumPool

Bases: Module

Batch-wise global sum-pooling module returning a dense (B, C) array with one row per batch item.
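The three global modules differ only in their reduction. The sketch below illustrates the batch-wise semantics with NumPy rather than with mlx_lattice itself. It assumes a sparse tensor carries its features as an (N, C) array plus an (N,) array of per-point batch indices; that layout is an assumption, and the function global_pool is hypothetical, not part of the library.

```python
import numpy as np


def global_pool(features, batch_idx, num_batches, mode="sum"):
    """Reduce per-point features (N, C) to one dense row per batch item (B, C).

    Hypothetical helper: assumes features is an (N, C) array and batch_idx is
    an (N,) array of non-negative batch indices, one per point.
    """
    features = np.asarray(features, dtype=np.float64)
    batch_idx = np.asarray(batch_idx, dtype=np.int64)
    channels = features.shape[1]
    if mode in ("sum", "avg"):
        out = np.zeros((num_batches, channels))
        # Unbuffered scatter-add: rows sharing a batch index accumulate.
        np.add.at(out, batch_idx, features)
        if mode == "avg":
            counts = np.bincount(batch_idx, minlength=num_batches)
            # Empty batch items keep a zero row instead of dividing by zero.
            out /= np.maximum(counts, 1)[:, None]
        return out
    if mode == "max":
        # Empty batch items keep -inf rows in this sketch.
        out = np.full((num_batches, channels), -np.inf)
        np.maximum.at(out, batch_idx, features)
        return out
    raise ValueError(f"unknown mode: {mode!r}")


feats = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
batch = np.array([0, 0, 1])  # first two points in item 0, last point in item 1
print(global_pool(feats, batch, 2, "sum"))  # rows [4, 6] and [5, 6]
print(global_pool(feats, batch, 2, "avg"))  # rows [2, 3] and [5, 6]
print(global_pool(feats, batch, 2, "max"))  # rows [3, 4] and [5, 6]
```

How the real modules treat a batch item with no points is not stated here; the zero and -inf rows above are choices made for this sketch only.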

class mlx_lattice.nn.pool.MaxPool3d(*, kernel_size=2, stride=2, padding=0, dilation=1)

Bases: Pool3d

Local sparse max-pooling module.

Parameters:
  • kernel_size (int | Sequence[int])

  • stride (int | Sequence[int])

  • padding (int | Sequence[int])

  • dilation (int | Sequence[int])

class mlx_lattice.nn.pool.Pool3d(*, mode='sum', kernel_size=2, stride=2, padding=0, dilation=1)

Bases: Module

Configurable local sparse 3D pooling module.

The mode parameter selects sum, max, or avg reduction over a sparse kernel relation. The module returns a sparse tensor whose tensor stride is the input tensor stride multiplied by the pooling stride.
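For the default arguments (kernel_size=2, stride=2, padding=0, dilation=1) the kernel relation is non-overlapping, so every occupied input voxel falls into exactly one output cell. The NumPy sketch below models only that case. It assumes coordinates are integer voxel indices at the current resolution and that batch items are pooled independently; the library's actual coordinate convention, and its handling of overlapping kernels, padding, and dilation, are not modeled. The function pool3d_nonoverlapping is hypothetical.

```python
import numpy as np


def pool3d_nonoverlapping(coords, batch_idx, features, stride=2,
                          tensor_stride=1, mode="sum"):
    """Sketch of local sparse pooling with kernel_size == stride (int).

    Assumes coords are (N, 3) integer voxel indices at the current
    resolution, padding=0 and dilation=1. Returns the occupied output cells,
    their batch indices, pooled features, and the new tensor stride.
    """
    coords = np.asarray(coords, dtype=np.int64)
    batch_idx = np.asarray(batch_idx, dtype=np.int64)
    features = np.asarray(features, dtype=np.float64)
    # With kernel_size == stride, each input voxel maps to exactly one cell.
    cells = np.floor_divide(coords, stride)
    keys = np.concatenate([batch_idx[:, None], cells], axis=1)
    uniq, inverse = np.unique(keys, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # inverse shape differs across NumPy versions
    m, c = uniq.shape[0], features.shape[1]
    if mode in ("sum", "avg"):
        out = np.zeros((m, c))
        np.add.at(out, inverse, features)
        if mode == "avg":
            # Every occupied cell has at least one input voxel.
            out /= np.bincount(inverse, minlength=m)[:, None]
    elif mode == "max":
        out = np.full((m, c), -np.inf)
        np.maximum.at(out, inverse, features)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    # Output tensor stride = input tensor stride * pooling stride.
    return uniq[:, 1:], uniq[:, 0], out, tensor_stride * stride


coords = np.array([[0, 0, 0], [1, 1, 0], [2, 0, 0], [3, 1, 1], [0, 0, 0]])
batch = np.array([0, 0, 0, 0, 1])
feats = np.array([[1.0], [2.0], [3.0], [4.0], [10.0]])
out_coords, out_batch, out_feats, out_stride = pool3d_nonoverlapping(
    coords, batch, feats, stride=2, tensor_stride=1, mode="sum")
print(out_coords.tolist(), out_batch.tolist(), out_feats.ravel().tolist(), out_stride)
```

Here the first two points of batch item 0 share cell (0, 0, 0) and the next two share cell (1, 0, 0), giving sums 3 and 7; the batch item 1 point stays in its own cell, and the tensor stride doubles from 1 to 2.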

Parameters:
  • mode (PoolMode)

  • kernel_size (int | Sequence[int])

  • stride (int | Sequence[int])

  • padding (int | Sequence[int])

  • dilation (int | Sequence[int])
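Every pooling module accepts int | Sequence[int] for kernel_size, stride, padding, and dilation. A common convention, assumed here rather than confirmed by this page, is that a scalar applies to all three spatial axes and a sequence gives one value per axis. The hypothetical helper to_triple sketches that normalization; the library's own validation may differ.

```python
from collections.abc import Sequence


def to_triple(value, name="value"):
    """Normalize an int | Sequence[int] argument to one int per spatial axis.

    Hypothetical helper: assumes a scalar applies to all three axes and a
    sequence lists one value per axis.
    """
    def is_int(v):
        # bool is a subclass of int; reject it explicitly.
        return isinstance(v, int) and not isinstance(v, bool)

    if is_int(value):
        return (value, value, value)
    if isinstance(value, Sequence) and len(value) == 3 and all(is_int(v) for v in value):
        return tuple(value)
    raise ValueError(f"{name} must be an int or a sequence of 3 ints, got {value!r}")


print(to_triple(2, "kernel_size"))     # (2, 2, 2)
print(to_triple([1, 2, 2], "stride"))  # (1, 2, 2)
```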

class mlx_lattice.nn.pool.SumPool3d(*, kernel_size=2, stride=2, padding=0, dilation=1)

Bases: Pool3d

Local sparse sum-pooling module.

Parameters:
  • kernel_size (int | Sequence[int])

  • stride (int | Sequence[int])

  • padding (int | Sequence[int])

  • dilation (int | Sequence[int])