Pooling operations
Pooling operations reduce sparse feature rows through a kernel relation or through batch metadata.
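Local pooling computes, for each output row o of the forward kernel relation R,
$$y_o = \operatorname{reduce}_{(i, o) \in R} \, x_i, \qquad \operatorname{reduce} \in \{\mathrm{sum}, \mathrm{max}, \mathrm{avg}\},$$
where x_i is the feature row of input row i. The max reduction is taken channel-wise, and avg divides the sum by the number of edges (i, o) in R.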
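The sum and avg modes treat an output row with no contributing input rows as a zero-valued reduction. The max mode requires at least one contributing row for every output row. Global pooling instead uses the batch_counts metadata of the input sparse tensor and returns a dense (B, C) MLX array, where B is the number of batches and C is the number of feature channels.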
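A plain-Python reference model can make the local reduction rule concrete. It is a sketch, not the mlx_lattice implementation: the relation is assumed to be a list of hypothetical (input_row, output_row) edges, features are nested lists instead of MLX arrays, and the helper name pool_over_relation is illustrative only.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def pool_over_relation(
    features: Sequence[Row],
    edges: Sequence[Tuple[int, int]],
    num_out: int,
    mode: str = "sum",
) -> List[Row]:
    """Reduce input rows into output rows along explicit relation edges.

    Hypothetical reference model: `edges` holds (input_row, output_row) pairs.
    """
    if mode not in ("sum", "avg", "max"):
        raise ValueError(f"unknown mode: {mode!r}")
    channels = len(features[0]) if features else 0
    # Collect the feature rows that reach each output row.
    gathered: List[List[Row]] = [[] for _ in range(num_out)]
    for i, o in edges:
        gathered[o].append(features[i])

    out: List[Row] = []
    for o, rows in enumerate(gathered):
        if not rows:
            if mode == "max":
                # max has no finite neutral value, so an empty row is an error.
                raise ValueError(f"output row {o} has no contributing rows")
            out.append([0.0] * channels)  # sum/avg: empty row reduces to zero
        elif mode == "max":
            out.append([max(col) for col in zip(*rows)])  # channel-wise max
        else:
            sums = [sum(col) for col in zip(*rows)]
            if mode == "avg":
                sums = [s / len(rows) for s in sums]  # divide by edge count
            out.append(sums)
    return out


feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0]]
edges = [(0, 0), (1, 0), (2, 1)]
print(pool_over_relation(feats, edges, 3, "sum"))  # [[4.0, 6.0], [5.0, 0.0], [0.0, 0.0]]
print(pool_over_relation(feats, edges, 3, "avg"))  # [[2.0, 3.0], [5.0, 0.0], [0.0, 0.0]]
```

Output row 2 has no edges here, so sum and avg return a zero row for it while max raises.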
Related pages
Backend reduction routes: Pooling routes
Relation model: Coordinates and relations
Batch metadata: Sparse tensor model
Module wrappers: Pooling modules
- mlx_lattice.ops.pool.avg_pool3d(x, *, kernel_size=2, stride=2, padding=0, dilation=1)
Apply local sparse average pooling.
The result feature at each output row is the sparse sum divided by the number of contributing relation edges for that output row.
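Because the divisor is the number of contributing edges rather than the kernel volume, this average differs from a dense average pool that zero-fills empty voxels and divides by the window size (8 cells for kernel_size=2 in 3D). The sketch below contrasts the two; the helper avg_two_ways and its arguments are hypothetical.

```python
from typing import List, Sequence, Tuple


def avg_two_ways(
    contributing: Sequence[List[float]], kernel_volume: int
) -> Tuple[List[float], List[float]]:
    """Return (sparse_avg, dense_zero_filled_avg) for one output row."""
    sums = [sum(col) for col in zip(*contributing)]
    sparse = [s / len(contributing) for s in sums]  # divisor: edge count
    dense = [s / kernel_volume for s in sums]  # divisor: full window size
    return sparse, dense


# Two occupied voxels in a 2x2x2 window, one channel each.
sparse, dense = avg_two_ways([[2.0], [4.0]], kernel_volume=8)
print(sparse, dense)  # [3.0] [0.75]
```

The two averages agree only when every cell of the window contributes exactly one edge.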
- mlx_lattice.ops.pool.global_avg_pool(x)
Average features independently for each batch.
Requires x.batch_counts and returns a dense (B, C) MLX array. A batch with no rows produces an all-zero output row.
- Return type:
  array
- Parameters:
  x (SparseTensor)
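A reference sketch of the per-batch mean follows. It assumes, without confirmation from this page, that rows are stored contiguously batch by batch in the order of batch_counts; global_avg_reference is a hypothetical name, and plain lists stand in for MLX arrays.

```python
from typing import List, Sequence


def global_avg_reference(
    features: Sequence[List[float]], batch_counts: Sequence[int]
) -> List[List[float]]:
    """Mean of each batch's rows; an empty batch yields a zero row."""
    if sum(batch_counts) != len(features):
        raise ValueError("batch_counts must sum to the number of rows")
    channels = len(features[0]) if features else 0
    out: List[List[float]] = []
    start = 0
    for count in batch_counts:
        rows = features[start : start + count]  # assumed contiguous segment
        start += count
        if count == 0:
            out.append([0.0] * channels)
        else:
            out.append([sum(col) / count for col in zip(*rows)])
    return out  # shape (B, C) as nested lists


print(global_avg_reference([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 0, 1]))
# [[2.0, 3.0], [0.0, 0.0], [5.0, 6.0]]
```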
- mlx_lattice.ops.pool.global_max_pool(x)
Max-reduce features independently for each batch.
Requires x.batch_counts and returns a dense (B, C) MLX array. A batch with no rows is rejected, because max has no finite neutral element that could fill its output row.
- Return type:
  array
- Parameters:
  x (SparseTensor)
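The reason for rejecting empty batches is that max has no finite neutral element: 0.0 is not neutral because max(0.0, -3.0) is 0.0, and the only true identity is negative infinity. The sketch below computes segment boundaries from batch_counts with itertools.accumulate and raises on an empty segment. Contiguous per-batch storage is an assumption, and global_max_reference is a hypothetical name.

```python
from itertools import accumulate
from typing import List, Sequence


def global_max_reference(
    features: Sequence[List[float]], batch_counts: Sequence[int]
) -> List[List[float]]:
    """Channel-wise max per batch; raises on an empty batch."""
    if sum(batch_counts) != len(features):
        raise ValueError("batch_counts must sum to the number of rows")
    ends = list(accumulate(batch_counts))  # exclusive end offset of each batch
    starts = [0] + ends[:-1]
    out: List[List[float]] = []
    for b, (s, e) in enumerate(zip(starts, ends)):
        if s == e:
            raise ValueError(f"batch {b} is empty; max has no finite neutral value")
        out.append([max(col) for col in zip(*features[s:e])])
    return out


print(global_max_reference([[1.0, 9.0], [4.0, 2.0], [-3.0, -1.0]], [2, 1]))
# [[4.0, 9.0], [-3.0, -1.0]]
```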
- mlx_lattice.ops.pool.global_sum_pool(x)
Sum features independently for each batch.
Requires x.batch_counts and returns a dense (B, C) MLX array.
- Return type:
  array
- Parameters:
  x (SparseTensor)
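As a sketch, the per-batch sum can be written as a scatter-add keyed by a row-to-batch index expanded from batch_counts. Contiguous per-batch storage is an assumption and global_sum_reference is a hypothetical name. The behavior of global_sum_pool on an empty batch is not stated above; in this sketch an empty batch simply sums to zero.

```python
from typing import List, Sequence


def global_sum_reference(
    features: Sequence[List[float]], batch_counts: Sequence[int]
) -> List[List[float]]:
    """Sum each batch's rows with a scatter-add over batch ids."""
    # One batch id per row, assuming rows are grouped by batch in order.
    batch_ids = [b for b, c in enumerate(batch_counts) for _ in range(c)]
    if len(batch_ids) != len(features):
        raise ValueError("batch_counts must sum to the number of rows")
    channels = len(features[0]) if features else 0
    out = [[0.0] * channels for _ in batch_counts]
    for row, b in zip(features, batch_ids):
        for ch, value in enumerate(row):
            out[b][ch] += value  # accumulate into the row's batch
    return out


print(global_sum_reference([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 0, 1]))
# [[4.0, 6.0], [0.0, 0.0], [5.0, 6.0]]
```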
- mlx_lattice.ops.pool.max_pool3d(x, *, kernel_size=2, stride=2, padding=0, dilation=1)
Apply local sparse max pooling.
The result feature at each output row is the channel-wise maximum over contributing input rows in the sparse kernel relation.
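Channel-wise means each channel is maximized independently, so the pooled row need not equal any single input row. A minimal sketch with a hypothetical helper, where rows stands for the feature rows that reach one output row:

```python
from typing import List, Sequence


def channelwise_max(rows: Sequence[List[float]]) -> List[float]:
    """Maximum of each channel over the rows that reach one output row."""
    if not rows:
        raise ValueError("max pooling needs at least one contributing row")
    return [max(column) for column in zip(*rows)]


rows = [[1.0, 7.0, -2.0], [5.0, 0.0, -4.0]]
print(channelwise_max(rows))  # [5.0, 7.0, -2.0], which matches neither input row
```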
- mlx_lattice.ops.pool.pool3d(x, *, mode='sum', kernel_size=2, stride=2, padding=0, dilation=1)
Apply local sparse 3D pooling with sum, max, or avg mode.
Local pooling builds a forward kernel relation and reduces input features that contribute to each output coordinate. The output sparse stride is x.stride * stride. Current native pooling routes accept float32 features; Metal routes additionally require int32 coordinates.
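The exact construction of the forward kernel relation is not spelled out here. The sketch below assumes the usual dense-pooling convention applied per axis, in units of the input's own stride: input index c feeds output index o when c = o * stride - padding + k * dilation for some kernel offset k in [0, kernel_size). The function name, the (input_row, output_row) edge format, and the convention itself are assumptions, and the sketch does not clip outputs to any spatial extent.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Coord = Tuple[int, int, int]


def build_pool_relation(
    coords: Sequence[Coord],
    kernel_size: int = 2,
    stride: int = 2,
    padding: int = 0,
    dilation: int = 1,
) -> Tuple[List[Coord], List[Tuple[int, int]]]:
    """Return (output_coords, edges) with edges as (input_row, output_row)."""
    out_index: Dict[Coord, int] = {}
    out_coords: List[Coord] = []
    edges: List[Tuple[int, int]] = []
    for i, c in enumerate(coords):
        # Try every kernel offset k = (kx, ky, kz) in the 3D window.
        for k in product(range(kernel_size), repeat=3):
            # Solve c = o * stride - padding + k * dilation for o on each axis.
            shifted = [c[a] + padding - k[a] * dilation for a in range(3)]
            if any(s % stride != 0 for s in shifted):
                continue  # this offset does not land on an output grid point
            o = (shifted[0] // stride, shifted[1] // stride, shifted[2] // stride)
            if o not in out_index:  # first time this output coordinate appears
                out_index[o] = len(out_coords)
                out_coords.append(o)
            edges.append((i, out_index[o]))
    return out_coords, edges


outs, edges = build_pool_relation([(0, 0, 0), (1, 1, 0), (2, 0, 1)])
print(outs)   # [(0, 0, 0), (1, 0, 0)]
print(edges)  # [(0, 0), (1, 0), (2, 1)]
```

At the defaults (kernel_size=2, stride=2, padding=0, dilation=1), each input coordinate maps to floor(c / 2) on every axis, so the windows do not overlap. The returned coordinates are in pooled-grid units, which is consistent with an output sparse stride of x.stride * stride.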