Convolution routes¶
Sparse convolution has the richest backend selection matrix in the project. The public operations are:
mlx_lattice.ops.conv3d();
mlx_lattice.ops.subm_conv3d();
mlx_lattice.ops.conv_transpose3d();
mlx_lattice.ops.generative_conv_transpose3d().
All four routes reduce over relation edges. The difference is how the output coordinate support is produced.
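The shared reduction can be sketched in plain Python. This is a minimal illustration, not the project's kernel: it assumes each relation edge is a triple (output row, input row, kernel offset index) and that weights are stored kernel-major as weights[k][c_in][c_out]. The function name relation_conv is hypothetical.

```python
def relation_conv(features, weights, edges, num_out):
    """Accumulate out[o] += features[i] @ weights[k] for every edge (o, i, k).

    Assumed layouts (illustrative only):
      features[i]       -> list of C_in values for input row i
      weights[k][ci][co] -> kernel-major dense weight
      edges             -> iterable of (out_row, in_row, kernel_index)
    """
    c_out = len(weights[0][0])
    out = [[0.0] * c_out for _ in range(num_out)]
    for o, i, k in edges:
        x = features[i]
        w = weights[k]
        for ci, xv in enumerate(x):
            for co in range(c_out):
                out[o][co] += xv * w[ci][co]
    return out


# Two input rows, two kernel offsets, C_in = 2, C_out = 1.
features = [[1.0, 2.0], [3.0, 4.0]]
weights = [
    [[1.0], [0.0]],  # offset 0 picks channel 0
    [[0.0], [1.0]],  # offset 1 picks channel 1
]
edges = [(0, 0, 0), (0, 1, 1), (1, 1, 0)]
result = relation_conv(features, weights, edges, num_out=2)
```

Only the edge list changes between the four operations; the accumulation itself stays the same once the output support and edges are known.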
Semantic map kinds¶
| Map kind | Builder | Output support |
|---|---|---|
| forward | | Produced from input coordinates, stride, padding, and dilation. |
| target | | Explicit target coordinates supplied to |
| submanifold | | Output support is the input coordinate identity; no coordinate expansion is performed. |
| transposed | | Expanded support for transpose convolution. |
| generative | | Generated support from a transpose-convolution rule. |
The forward, target, and submanifold map kinds are considered by the
sorted implicit-GEMM forward route. Transposed and generative convolutions use
relation traversal on their current public path.
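To make the submanifold case concrete, the sketch below builds relation edges when the output support equals the input coordinates. It is an illustration under assumptions: coordinates are integer (x, y, z) tuples, the kernel is a centered cube, and kernel offsets are enumerated in row-major order. The function name submanifold_edges is hypothetical and is not the project's builder.

```python
from itertools import product


def submanifold_edges(coords, kernel_size=3):
    """Return (out_row, in_row, kernel_index) edges for a submanifold map.

    Output rows are the input rows themselves, so no new coordinates appear.
    An edge exists only where the neighbour at a kernel offset is occupied.
    """
    index = {c: i for i, c in enumerate(coords)}
    r = kernel_size // 2
    # Row-major enumeration of offsets is an assumption of this sketch.
    offsets = list(product(range(-r, r + 1), repeat=3))
    edges = []
    for o, (x, y, z) in enumerate(coords):
        for k, (dx, dy, dz) in enumerate(offsets):
            i = index.get((x + dx, y + dy, z + dz))
            if i is not None:
                edges.append((o, i, k))
    return edges


coords = [(0, 0, 0), (1, 0, 0)]
edges = submanifold_edges(coords)
```

With two adjacent voxels, each output row receives its own center contribution plus one neighbour contribution, and the number of output rows stays at two.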
Floating forward routes¶
| Route | Predicate | Notes |
|---|---|---|
| Pointwise matmul | | Computes |
| Generic relation traversal | Any valid relation convolution not captured by a more specific route | Consumes edge arrays plus output/input/kernel CSR views. |
| Dense-channel Metal kernels | 5D dense weight layout, | Specialized forward kernels for common channel-aligned 3D convolutions. |
| | | Optimizes the small-output-channel case. |
| | fp32 features and | Vectorized output-channel traversal. |
| fp16 gather kernel | fp16 features | Uses gather-style traversal instead of fp32 atomic fallback. |
| fp32 atomic kernel | fp32 features when no gather/vector route is selected | Accumulates by relation edge. |
Sorted fp16 implicit-GEMM¶
The Python predicate for the sorted floating route is:
For a 5D dense weight, the Python layer maps
(C, 3, 3, 3, C) into a contiguous (27, C, C) tensor before dispatch.
The Metal runtime then chooses between:
| Route | Additional predicate | Kernel family |
|---|---|---|
| TensorOps sorted contraction | Neural-accelerator capability, contiguous fp16 features/weights, sorted relation view | Row-stationary TensorOps kernels with 64-row tiles. |
| Direct sorted reference route | Same shape/layout predicate but TensorOps is not preferred | Row-stationary direct Metal kernels for C32/C64. |
The sorted view stores:
reorder_rows maps sorted output rows back to public output order, and
tile_masks stores occupancy masks for 64-row tiles.
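The roles of reorder_rows and tile_masks can be sketched as follows. Both helpers are hypothetical and rest on assumptions: an occupancy mask is taken to be a bitmask of the kernel offsets touched by any row in a 64-row tile, and reorder_rows[s] is taken to be the public row index of sorted row s. The real storage format may differ.

```python
TILE_ROWS = 64  # tile height named in the text


def build_tile_masks(row_masks, tile_rows=TILE_ROWS):
    """OR per-row kernel-offset bitmasks into one occupancy mask per tile."""
    masks = []
    for start in range(0, len(row_masks), tile_rows):
        tile_mask = 0
        for row_mask in row_masks[start:start + tile_rows]:
            tile_mask |= row_mask
        masks.append(tile_mask)
    return masks


def restore_public_order(sorted_rows, reorder_rows):
    """Scatter sorted output rows back to their public positions."""
    out = [None] * len(sorted_rows)
    for sorted_idx, public_idx in enumerate(reorder_rows):
        out[public_idx] = sorted_rows[sorted_idx]
    return out


# 64 rows that touch only offset 0, then two rows touching offsets 2 and 1.
row_masks = [0b001] * 64 + [0b100, 0b010]
tile_masks = build_tile_masks(row_masks)
public = restore_public_order(["a", "b", "c"], [2, 0, 1])
```

A kernel that reads a per-tile mask of this kind can skip kernel offsets whose bit is clear for the whole tile, which is one plausible reason to sort rows before contracting.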
Quantized forward routes¶
Packed quantized convolution is selected by passing a QuantizedWeight.
Supported bit widths are 4 and 8. Packed weights use uint32 storage, affine
scales, affine biases, and a group size of 32, 64, or 128.
| Route | Predicate | Notes |
|---|---|---|
| Direct packed convolution | Any valid quantized relation convolution | Metal kernels dispatch by feature dtype and bit width: fp16/fp32 × int4/int8. |
| Sorted quantized implicit-GEMM | Sorted plan present, fp16 features, | Contracts in sorted order and reorders output rows. |
| TensorOps quantized contraction | fp16 features, | Dequantizes a temporary fp16 weight tile and runs TensorOps contraction. |
The direct packed route computes the affine reconstruction per group:
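Under the usual affine convention, an unpacked code q in a group with scale s and bias b reconstructs as w ≈ s · q + b. The sketch below illustrates that rule in plain Python. It assumes codes are packed into each uint32 word starting from the least significant bits; that packing order and the helper names are assumptions of this example, not a statement about the Metal kernels.

```python
def unpack_uint32(word, bits):
    """Split one uint32 word into 32 // bits unsigned codes, low bits first."""
    if bits not in (4, 8):
        raise ValueError("supported bit widths are 4 and 8")
    mask = (1 << bits) - 1
    return [(word >> (i * bits)) & mask for i in range(32 // bits)]


def dequantize_group(words, scale, bias, bits):
    """Affine reconstruction for one group: w = scale * q + bias."""
    values = []
    for word in words:
        values.extend(scale * q + bias for q in unpack_uint32(word, bits))
    return values


# One 4-bit word holds eight codes; here the codes are 0..7.
codes4 = unpack_uint32(0x76543210, 4)
codes8 = unpack_uint32(0x04030201, 8)
recon = dequantize_group([0x76543210], scale=0.5, bias=-1.0, bits=4)
```

A group of size 32 at 4 bits spans four such words that share one scale and one bias; at 8 bits it spans eight words.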
Autodiff routes¶
Floating sparse convolution defines JVP and VJP for features and weights. Backward execution has its own Metal route selection:
| Backward path | Predicate | Route |
|---|---|---|
| Input gradient TensorOps | | TensorOps input-gradient contraction. |
| Input gradient dense-channel kernels | Dense 5D weight, input capacity at least | Specialized dense-channel kernels; grouped dense route is selected at larger input capacities. |
| Weight gradient TensorOps | | Partitioned TensorOps contraction followed by reduction. |
| Weight gradient classic kernels | fp16, block4-compatible channels, | Classic Metal kernels selected by channel, kernel volume, edge count, and input capacity. |
Quantized convolution is inference-oriented in the public surface. If a training path requires gradients through packed weights, dequantize explicitly and use the floating route.
Weight layouts¶
Floating convolution accepts:
dense 5D layout (C_out, Kx, Ky, Kz, C_in);
mapped kernel-major layout (K, C_in, C_out) for internal sorted routes.
The sorted floating route requires mapped (27, C, C) storage. The public
conv3d call can still receive 5D dense weights; Python maps and caches the
contiguous internal view.
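The layout change from dense 5D to kernel-major can be sketched with nested lists. This is an illustration, not the library's mapping code: it assumes kernel offsets are flattened in row-major order, k = (kx · Ky + ky) · Kz + kz, and the function name is hypothetical. With an array library the same result is a transpose to (Kx, Ky, Kz, C_in, C_out) followed by a reshape to (K, C_in, C_out).

```python
def map_dense_to_kernel_major(weight):
    """Map weight[co][kx][ky][kz][ci] to mapped[k][ci][co]."""
    c_out = len(weight)
    kx_n = len(weight[0])
    ky_n = len(weight[0][0])
    kz_n = len(weight[0][0][0])
    c_in = len(weight[0][0][0][0])
    volume = kx_n * ky_n * kz_n
    mapped = [[[0.0] * c_out for _ in range(c_in)] for _ in range(volume)]
    for co in range(c_out):
        for kx in range(kx_n):
            for ky in range(ky_n):
                for kz in range(kz_n):
                    # Row-major offset index (an assumption of this sketch).
                    k = (kx * ky_n + ky) * kz_n + kz
                    for ci in range(c_in):
                        mapped[k][ci][co] = weight[co][kx][ky][kz][ci]
    return mapped


# Small (C, 3, 3, 3, C) example with C = 2 and a distinct value per element.
C = 2
weight = [
    [[[[((((co * 3 + kx) * 3 + ky) * 3 + kz) * C + ci) for ci in range(C)]
       for kz in range(3)] for ky in range(3)] for kx in range(3)]
    for co in range(C)
]
mapped = map_dense_to_kernel_major(weight)
```

Because the mapping depends only on the weight, caching the mapped view avoids repeating the copy on every call.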
Validation checklist¶
When diagnosing a convolution route, record:
map kind: forward, target, submanifold, transposed, or generative;
feature dtype and coordinate dtype;
dense versus quantized weight;
kernel volume K;
C_in and C_out;
relation output capacity and edge count;
Metal capability tier when running on GPU.
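One way to keep these fields together is a small record type. The sketch below is a hypothetical helper for bug reports, not part of the library; the class name, field names, and example values are assumptions chosen to mirror the checklist.

```python
from dataclasses import asdict, dataclass
from typing import Optional

MAP_KINDS = ("forward", "target", "submanifold", "transposed", "generative")


@dataclass(frozen=True)
class RouteDiagnostics:
    """Checklist fields for one convolution call (illustrative only)."""
    map_kind: str
    feature_dtype: str
    coord_dtype: str
    quantized: bool
    kernel_volume: int
    c_in: int
    c_out: int
    output_capacity: int
    edge_count: int
    metal_tier: Optional[str] = None  # only meaningful on GPU

    def __post_init__(self):
        if self.map_kind not in MAP_KINDS:
            raise ValueError(f"unknown map kind: {self.map_kind!r}")


record = RouteDiagnostics(
    map_kind="submanifold",
    feature_dtype="float16",
    coord_dtype="int32",
    quantized=False,
    kernel_volume=27,
    c_in=64,
    c_out=64,
    output_capacity=4096,
    edge_count=51200,
)
report = asdict(record)
```

Serializing the record alongside a failing case makes it easier to compare two runs that were expected to take the same route.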