End-to-end sparse workflow¶
This page describes the lifecycle of data in a typical mlx-lattice model.
It focuses on how user-level operations connect to coordinate identity,
relations, backend dispatch, and module composition.
Workflow cross-reference¶
Workflow concern |
Concept/backend page |
API page |
|---|---|---|
Sparse support and identity |
||
Kernel relations |
||
Modules and feature blocks |
||
Branch joins |
||
Backend route diagnosis |
Stage 1: choose the sparse support¶
An active coordinate has one meaning throughout a pipeline. Common choices are:
Support type |
Active coordinate means |
Typical source |
|---|---|---|
Voxelized point cloud |
At least one point fell into the voxel |
|
Submanifold network |
Site remains active through feature transforms |
|
Strided encoder |
Downsampled support from a previous level |
|
Sparse decoder |
Generated/expanded site |
Transposed or generative convolution. |
Algebra output |
Coordinate selected by join policy |
|
The library can align coordinates, but it cannot infer that two branches mean the same physical quantity. Make the support semantics explicit in model code.
Stage 2: build tensors with metadata¶
Construct SparseTensor with coordinates, features, stride, and
batch_counts when global pooling or decomposition is needed:
x = SparseTensor(coords, feats, stride=1, batch_counts=(len(coords),))
For batched data, prefer sparse_collate:
from mlx_lattice.ops import sparse_collate
batch = sparse_collate([coords0, coords1], [feats0, feats1])
sparse_collate records batch counts and prepends the batch column. This
avoids a common bug where feature rows are concatenated but batch ownership is
lost.
Stage 3: compose feature-preserving blocks¶
Feature-only operations preserve coordinate identity. This includes
activations, normalization, linear layers, and most mlx_lattice.nn feature
modules:
from mlx_lattice import nn
block = [
nn.Linear(32, 64),
nn.BatchNorm(64),
nn.SiLU(),
nn.LayerNorm(64),
]
h = x
for op in block:
h = op(h)
These modules call functional operations that replace feats but keep the
same coordinate key when row count and stride are unchanged.
Stage 4: insert relation-changing blocks¶
Convolution and pooling build kernel relations. A relation contains the sparse edge set used by the backend:
The relation belongs to the coordinate manager and is cached by coordinate key, target key, relation kind, and kernel specification. Model code does not need to manually cache edge arrays unless it is implementing a custom operator.
encoder = nn.Conv3d(32, 64, kernel_size=3, stride=2)
subm = nn.SubmConv3d(64, 64, kernel_size=3)
pooled = nn.MaxPool3d(kernel_size=2, stride=2)
h = encoder(x)
h = subm(h)
h = pooled(h.astype(mx.float32))
Stage 5: align branches before combining¶
Residual and skip connections are safe only when support is known. If both branches preserve identity, direct sparse addition uses the identity fast path. If supports may differ, choose a join:
from mlx_lattice.ops import sparse_add
skip = sparse_add(decoder_branch, encoder_branch, join='inner')
inner is common for strict skip connections. outer is common for
occupancy-style merging. left and right are useful when one branch owns
the desired output support.
Stage 6: bridge point and voxel domains¶
Point/voxel maps are geometry metadata. The same points, voxel size, origin, batch indices, and interpolation mode must be used when reusing a map.
voxels = voxelize(points, point_feats, voxel_size=0.05)
voxel_feats = model(voxels)
point_feats = devoxelize(points, voxel_feats, voxel_size=0.05)
For point heads that need both raw point features and voxel features, keep the point rows as the dense domain and use devoxelization as the bridge back from the sparse domain.
Stage 7: quantize after the floating contract is correct¶
Quantized modules are drop-in inference modules for supported layers. A typical workflow is:
build and validate the floating module;
construct the quantized module from the floating module;
compare against the dequantized/floating contract on representative sparse supports;
benchmark the public operation on the target backend.
Quantized convolution keeps the same sparse relation semantics. The storage and execution route change; the coordinate support does not.
Stage 8: benchmark by public shape¶
Record:
active rows and pattern;
coordinate dtype and feature dtype;
channel count;
kernel volume and relation kind;
packed quantized bit width and group size;
backend device;
batch metadata when global pooling is involved.
Internal kernel names are useful for maintainers during diagnosis, but public benchmark reports are reproducible from public inputs.