Mamba
- class Mamba(mamba_layer_cfg, bidirectional=False)
Bases: Module

Selective state-space mixer for ND signals.
Wraps a 1D Mamba core layer (e.g. mamba_ssm.Mamba) and extends it to arbitrary spatial rank by flattening all spatial axes into a single sequence dimension before the SSM recurrence and reshaping back afterward.

The SSM recurrence computed by the core layer is:
\[\begin{split}h_t &= \bar{A}_t \, h_{t-1} + \bar{B}_t \, x_t \\ y_t &= C_t \, h_t\end{split}\]

where the discretised state matrix \(\bar{A}_t\), the input matrix \(\bar{B}_t\) and the readout matrix \(C_t\) all depend on \(x_t\). Inside the core layer, \(B_t\), \(C_t\) and the step size \(\Delta_t\) are derived from \(x_t\) via linear projections, while \(A\) is a learned, input-independent parameter, so \(\bar{A}_t\) depends on \(x_t\) only through \(\Delta_t\). The step size controls the discretisation: \(\bar{A}_t = e^{\Delta_t A}\) (ZOH) and \(\bar{B}_t = \Delta_t B_t\) (Euler).
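As a concrete reference, the recurrence can be written as a short sequential loop. This is a minimal NumPy sketch, assuming a single channel with a diagonal state matrix and the ZOH/Euler discretisation stated above; the function name selective_scan_ref and the arrays delta, B and C are hypothetical stand-ins for quantities the core layer computes internally. Mamba applies an independent recurrence of this form to each channel (with B_t and C_t shared across channels), evaluates it with a fused parallel-scan kernel rather than a Python loop, and adds a skip connection and gating that are omitted here.

```python
import numpy as np


def selective_scan_ref(x, delta, A, B, C):
    """Hypothetical sequential reference for one channel (not the library's kernel).

    Computes h_t = Abar_t h_{t-1} + Bbar_t x_t and y_t = C_t h_t.

    x:     (S,)   input value at each step
    delta: (S,)   positive, input-dependent step sizes Delta_t
    A:     (N,)   diagonal of the input-independent state matrix (typically negative)
    B, C:  (S, N) input-dependent input and readout vectors
    """
    S, N = B.shape
    h = np.zeros(N)
    y = np.empty(S)
    for t in range(S):
        A_bar = np.exp(delta[t] * A)  # ZOH: Abar_t = exp(Delta_t A), elementwise for diagonal A
        B_bar = delta[t] * B[t]       # Euler: Bbar_t = Delta_t B_t
        h = A_bar * h + B_bar * x[t]  # state update
        y[t] = C[t] @ h               # readout
    return y
```

Because each y_t is computed from h_t, which only accumulates x_0 through x_t, the scan is strictly causal: perturbing a later input never changes earlier outputs.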
Scan order for ND inputs: spatial axes are flattened in row-major (C-contiguous) order. For 2D [H, W], the sequence visits tokens as (0,0), (0,1), …, (0,W-1), (1,0), … (raster scan). For 3D [D, H, W], the depth axis varies slowest. This ordering is fixed (not learned). Vertically adjacent pixels are W steps apart in the flattened sequence; see the module docstring for the anisotropy implication.

Bidirectional mode: when bidirectional=True, a second core layer processes the flattened sequence in reverse, and its output, flipped back to the original order, is summed with the forward output. The forward core gives each position access to its prefix and the reverse core to its suffix, so every position receives information from the entire sequence. This is beneficial for non-causal spatial tasks such as image or volume modelling.

- Parameters:
mamba_layer_cfg (LazyConfig)
bidirectional (bool)
- core_layer
The forward (or only) Mamba core. Must accept input of shape [B, S, C] and return [B, S, C].
- Type: nn.Module
- core_layer_rev
The reverse Mamba core, instantiated only when bidirectional=True. When bidirectional=False, this attribute is not registered and accessing it raises AttributeError by design, keeping the module's parameter count and state_dict unaffected.
- Type: nn.Module
Example:

```python
import torch

from nvsubquadratic.lazy_config import LazyConfig
from nvsubquadratic.modules.mamba_nd import Mamba
from mamba_ssm import Mamba as MambaCore

# mamba_ssm's selective-scan kernels require a CUDA device.
device = "cuda"

mamba = Mamba(
    mamba_layer_cfg=LazyConfig(MambaCore)(d_model=128, d_state=16, d_conv=4, expand=2),
    bidirectional=True,
).to(device)

# 2D input: batch=2, spatial=(16, 16), channels=128 (channels-last)
x = torch.randn(2, 16, 16, 128, device=device)
y = mamba(x)  # [2, 16, 16, 128]
assert y.shape == x.shape
```
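The raster-scan order described above is exactly what a default (C-order) reshape produces. The check below is a small NumPy sketch, not the module's actual rearrange call, and flat_index is a hypothetical helper; it confirms that horizontal neighbours are 1 step apart, vertical neighbours are W steps apart, and that in 3D the depth axis varies slowest.

```python
import numpy as np

H, W = 3, 4
# Tag every pixel with its own (i, j) coordinates: shape (H, W, 2).
coords = np.moveaxis(np.indices((H, W)), 0, -1)
# Row-major flatten, analogous to [B, H, W, C] -> [B, H*W, C].
seq = coords.reshape(H * W, 2)
print(seq[: W + 1].tolist())  # [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]]


def flat_index(i, j, width):
    """Hypothetical helper: position of pixel (i, j) in the raster-scan sequence."""
    return i * width + j


# Horizontal neighbours are 1 step apart; vertical neighbours are W steps apart.
assert flat_index(1, 2, W) - flat_index(1, 1, W) == 1
assert flat_index(2, 1, W) - flat_index(1, 1, W) == W
assert seq[flat_index(2, 1, W)].tolist() == [2, 1]

# 3D [D, H, W]: the depth index changes only once every H*W tokens (slowest axis).
D = 2
depth_of_token = np.indices((D, H, W))[0].reshape(-1)
assert (depth_of_token[: H * W] == 0).all()
assert (depth_of_token[H * W :] == 1).all()
```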
- __init__(mamba_layer_cfg, bidirectional=False)
Initialise the Mamba-ND wrapper.
- Parameters:
mamba_layer_cfg (LazyConfig) – LazyConfig for the underlying 1D Mamba core. The target class must accept a 3-D tensor of shape [B, S, C] (batch, sequence length, channels) and return a tensor of the same shape. Typical targets include mamba_ssm.Mamba and mamba_ssm.Mamba2. instantiate(mamba_layer_cfg) is called twice when bidirectional=True; each call constructs a fresh nn.Module with newly initialised weights, so the two directions do not share parameters.
bidirectional (bool) – If True, run a second Mamba core on the reversed sequence and sum both outputs. This doubles the parameter count and the compute of the core layers but gives every position coverage of the full sequence; strongly recommended for spatial tasks (images, volumes). Defaults to False.
- Raises:
Exception – Propagated from instantiate() if mamba_layer_cfg cannot be constructed. Check that the config's target class is importable and that all required constructor arguments are provided in the config. Whether the target accepts [B, S, C] tensors is only exercised later, in forward().
- forward(x)
Apply the Mamba SSM to an ND input signal.
The forward pass performs the following steps:
1. Flatten all spatial axes into one sequence dimension: [B, *spatial, C] to [B, S, C], where S = prod(spatial). The flattening follows row-major (C-contiguous) order.
2. Forward SSM: out = core_layer(x) applies the selective SSM recurrence \(y_t = C_t(\bar{A}_t h_{t-1} + \bar{B}_t x_t)\).
3. Reverse SSM (only when bidirectional=True): out_rev = core_layer_rev(flip(x)) runs the SSM on the sequence reversed along the S axis; the result is flipped back and added to out:
\[\text{out} \mathrel{+}= \mathrm{flip}(\mathrm{Mamba}_\mathrm{rev}(\mathrm{flip}(x)))\]
4. Reshape back to the original spatial layout: [B, S, C] to [B, *spatial, C].
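The four steps can be sketched in NumPy with the core layers passed in as plain callables on [B, S, C] arrays. The function mamba_nd_forward and the toy causal core (a cumulative sum along the sequence axis, standing in for a real Mamba core) are hypothetical illustrations, not the library's implementation, which uses rearrange on torch tensors; np.flip(..., axis=1) plays the role of torch.flip(x, dims=[1]).

```python
import numpy as np


def mamba_nd_forward(x, core_layer, core_layer_rev=None):
    """Hypothetical sketch: flatten -> forward core -> optional reverse core -> reshape.

    x: array of shape [B, *spatial, C] (channels-last).
    core_layer, core_layer_rev: callables mapping [B, S, C] -> [B, S, C].
    """
    x_shape = x.shape                           # keep the spatial shape for the final reshape
    x = x.reshape(x_shape[0], -1, x_shape[-1])  # 1. [B, *spatial, C] -> [B, S, C], row-major
    out = core_layer(x)                         # 2. forward scan
    if core_layer_rev is not None:              # 3. reverse scan, flipped back, summed
        out = out + np.flip(core_layer_rev(np.flip(x, axis=1)), axis=1)
    return out.reshape(x_shape)                 # 4. [B, S, C] -> [B, *spatial, C]


def causal_cumsum(x):
    """Toy causal stand-in for a Mamba core: position t sees only positions 0..t."""
    return np.cumsum(x, axis=1)


x = np.random.default_rng(0).standard_normal((2, 3, 4, 5))  # [B, H, W, C]
y_uni = mamba_nd_forward(x, causal_cumsum)
y_bi = mamba_nd_forward(x, causal_cumsum, causal_cumsum)
assert y_uni.shape == y_bi.shape == x.shape

# Bidirectional: prefix sum + suffix sum = total + the token itself, at every position.
total = x.sum(axis=(1, 2), keepdims=True)
assert np.allclose(y_bi, total + x)
```

With this causal stand-in, the unidirectional output at the first token sees only that token, while the bidirectional output at every position depends on the whole sequence, which is the coverage argument made for bidirectional mode.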
Implementation note
The local variable x is rebound to the flattened [B, S, C] view after the rearrange call; the original spatial shape is preserved in x_shape for the final reshape.
- param x:
Input tensor of shape (B, *spatial, C), where B is the batch size, spatial is one or more spatial dimensions (e.g. (T,) for 1D sequences, (H, W) for 2D images, (D, H, W) for 3D volumes), and C is the channel (hidden) dimension. The tensor must be in channels-last layout (e.g. [B, T, C] or [B, H, W, C]), consistent with the rest of the library.
- returns:
Output tensor of shape (B, *spatial, C), with the same shape and layout as the input. When bidirectional=True, the output is the element-wise sum of the forward and reverse SSM outputs, which can increase its magnitude relative to a unidirectional pass (up to 2× if the two outputs are identical, roughly √2× in RMS if they are uncorrelated); downstream normalisation layers (e.g. RMSNorm inside the residual block) absorb this scale.
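How much the summed bidirectional output grows depends on how correlated the two directions' outputs are. A quick NumPy check with random stand-in arrays (not real Mamba activations; the shapes are arbitrary) illustrates the two extremes: identical outputs double the RMS, while independent, equally scaled outputs grow it by only about √2.

```python
import numpy as np

rng = np.random.default_rng(0)
fwd = rng.standard_normal((4, 1024, 64))  # stand-in forward-direction output [B, S, C]
rev = rng.standard_normal((4, 1024, 64))  # independent stand-in reverse-direction output


def rms(a):
    """Root-mean-square magnitude over all elements."""
    return np.sqrt(np.mean(a ** 2))


ratio_identical = rms(fwd + fwd) / rms(fwd)  # perfectly correlated directions: 2
ratio_independent = rms(fwd + rev) / rms(fwd)  # uncorrelated directions: close to sqrt(2)
print(ratio_identical, ratio_independent)
```

In a trained model the two directions are neither identical nor fully independent, so the actual growth falls somewhere in between; a following normalisation layer rescales it either way.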