Pooling modules
Pooling modules wrap local relation pooling and batch-wise global pooling. Local pooling returns sparse tensors; global pooling returns dense batch rows.
Module summary
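| Module | Reduction | Output type |
|---|---|---|
| `Pool3d` | Configurable | Sparse tensor |
| `AvgPool3d`, `MaxPool3d` | Fixed local reduction mode. | Sparse tensor |
| `GlobalAvgPool`, `GlobalMaxPool`, `GlobalSumPool` | Batch-wise dense reduction. | MLX array |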
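To make the "batch-wise dense reduction" row concrete, here is a minimal pure-Python sketch of reducing per-point feature rows to one dense row per batch element. It is not the mlx_lattice implementation: the function name `global_pool`, the list-based inputs, and the per-point `batch_indices` layout are assumptions made for illustration only.

```python
from collections import defaultdict


def global_pool(features, batch_indices, mode="avg"):
    """Reduce N feature rows to B dense rows, one per batch element.

    Hypothetical helper, not part of mlx_lattice.
    features:      list of N rows, each a list of C floats
    batch_indices: list of N ints, the batch element each row belongs to
    Returns a list of B rows, each a list of C floats, i.e. shape (B, C).
    """
    if mode not in ("sum", "max", "avg"):
        raise ValueError(f"unsupported mode: {mode!r}")
    if len(features) != len(batch_indices):
        raise ValueError("features and batch_indices must have equal length")

    # Group the feature rows by the batch element they belong to.
    groups = defaultdict(list)
    for row, b in zip(features, batch_indices):
        groups[b].append(row)

    num_batches = max(groups) + 1 if groups else 0
    out = []
    for b in range(num_batches):
        rows = groups.get(b)
        if not rows:
            # Assumption: every batch element owns at least one point.
            raise ValueError(f"batch element {b} has no points")
        columns = list(zip(*rows))  # transpose to per-channel columns
        if mode == "sum":
            out.append([sum(c) for c in columns])
        elif mode == "max":
            out.append([max(c) for c in columns])
        else:  # "avg"
            out.append([sum(c) / len(c) for c in columns])
    return out


feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
batch = [0, 0, 1]
print(global_pool(feats, batch, mode="avg"))  # [[2.0, 3.0], [5.0, 6.0]]
```

The result always has one row per batch element regardless of how many active points each element contains, which is why the global modules can return a dense array rather than a sparse tensor.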
- class mlx_lattice.nn.pool.AvgPool3d(*, kernel_size=2, stride=2, padding=0, dilation=1)
Bases: Pool3d
Local sparse average-pooling module.
- class mlx_lattice.nn.pool.GlobalAvgPool
Bases: Module
Batch-wise global average-pooling module returning dense `(B, C)` rows.
- class mlx_lattice.nn.pool.GlobalMaxPool
Bases: Module
Batch-wise global max-pooling module returning dense `(B, C)` rows.
- class mlx_lattice.nn.pool.GlobalSumPool
Bases: Module
Batch-wise global sum-pooling module returning dense `(B, C)` rows.
- class mlx_lattice.nn.pool.MaxPool3d(*, kernel_size=2, stride=2, padding=0, dilation=1)
Bases: Pool3d
Local sparse max-pooling module.
- class mlx_lattice.nn.pool.Pool3d(*, mode='sum', kernel_size=2, stride=2, padding=0, dilation=1)
Bases: Module
Configurable local sparse 3D pooling module.
`mode` selects `sum`, `max`, or `avg` reduction over a sparse kernel relation. The module returns a sparse tensor whose stride is the input stride multiplied by the pooling stride.
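The local reduction and the stride bookkeeping can be illustrated with a small pure-Python sketch. It is not the library's kernel: the function name `sparse_pool3d` and the dict-based tensor layout are hypothetical, and it assumes the simplest case where the kernel size equals the stride with no padding or dilation, so windows do not overlap. It also assumes that `avg` divides by the number of active points in a window; the real module may normalize differently.

```python
def sparse_pool3d(coords_to_feats, mode="sum", stride=2, in_stride=1):
    """Pool active points into non-overlapping stride^3 windows.

    Hypothetical helper, not part of mlx_lattice.
    coords_to_feats: dict mapping (x, y, z) int tuples to lists of C floats,
                     with coordinates expressed in units of the input stride
    Returns (pooled, out_stride), where pooled maps output coordinates to
    reduced feature rows and out_stride = in_stride * stride.
    """
    if mode not in ("sum", "max", "avg"):
        raise ValueError(f"unsupported mode: {mode!r}")

    # Assumption: kernel_size == stride, padding == 0, dilation == 1, so each
    # input point falls into exactly one window, found by floor division.
    windows = {}
    for (x, y, z), feats in coords_to_feats.items():
        key = (x // stride, y // stride, z // stride)
        windows.setdefault(key, []).append(feats)

    pooled = {}
    for key, rows in windows.items():
        columns = list(zip(*rows))  # transpose to per-channel columns
        if mode == "sum":
            pooled[key] = [sum(c) for c in columns]
        elif mode == "max":
            pooled[key] = [max(c) for c in columns]
        else:  # "avg" over active points only (an assumption)
            pooled[key] = [sum(c) / len(c) for c in columns]

    # Only windows containing at least one active point appear in the output,
    # so the result stays sparse.
    return pooled, in_stride * stride


points = {(0, 0, 0): [1.0], (1, 1, 1): [3.0], (2, 0, 0): [5.0]}
pooled, out_stride = sparse_pool3d(points, mode="sum", stride=2, in_stride=1)
print(pooled, out_stride)
```

Applying the sketch twice with `stride=2` would give an output stride of 4, which matches the multiplicative stride rule described above.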