Mamba

class Mamba(mamba_layer_cfg, bidirectional=False)

Bases: Module

Selective state-space mixer for ND signals.

Wraps a 1D Mamba core layer (e.g. mamba_ssm.Mamba) and extends it to arbitrary spatial rank by flattening all spatial axes into a single sequence dimension before the SSM recurrence and reshaping back afterward.

The SSM recurrence computed by the core layer is:

\[\begin{split}h_t &= \bar{A}_t \, h_{t-1} + \bar{B}_t \, x_t \\ y_t &= C_t \, h_t\end{split}\]

where the discretised state matrix \(\bar{A}_t\), the discretised input matrix \(\bar{B}_t\) and the readout matrix \(C_t\) are all functions of \(x_t\), derived via linear projections inside the core layer. The step size \(\Delta_t\) (also input-dependent) controls the discretisation: \(\bar{A}_t = e^{\Delta_t A}\) (zero-order hold, ZOH) and \(\bar{B}_t = \Delta_t B_t\) (first-order Euler approximation).
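As a minimal illustration of this discretisation, the pure-Python sketch below runs the recurrence for a single scalar state. The values of A, B_t, C_t and Δ_t are made up for illustration; in the core layer they come from learned parameters and input-dependent projections, and the helper names are hypothetical.

```python
import math


def discretise(delta, A, B):
    """Return (A_bar, B_bar) for one time step.

    A_bar uses zero-order hold: exp(delta * A).
    B_bar uses the Euler approximation: delta * B.
    """
    return math.exp(delta * A), delta * B


def ssm_scan(xs, deltas, Bs, Cs, A):
    """Scalar-state recurrence: h_t = A_bar_t * h_{t-1} + B_bar_t * x_t, y_t = C_t * h_t."""
    h = 0.0
    ys = []
    for x, delta, B, C in zip(xs, deltas, Bs, Cs):
        A_bar, B_bar = discretise(delta, A, B)
        h = A_bar * h + B_bar * x
        ys.append(C * h)
    return ys


# Illustrative values only (not taken from any trained model).
A = -1.0
xs = [1.0, 0.0, 0.0]        # a single impulse at t = 0
deltas = [0.5, 0.5, 2.0]    # a larger step size at t = 2
Bs = [1.0, 1.0, 1.0]
Cs = [1.0, 1.0, 1.0]
ys = ssm_scan(xs, deltas, Bs, Cs, A)
```

With a negative A, a larger Δ_t shrinks \(\bar{A}_t\), so the state forgets the earlier impulse faster, while the same Δ_t scales how strongly the current input is written into the state.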

Scan order for ND inputs: spatial axes are flattened in row-major (C-contiguous) order, i.e. for 2D [H, W] the sequence visits tokens as (0,0), (0,1), …, (0,W-1), (1,0), … (raster-scan). For 3D [D,H,W] the depth axis varies slowest. This ordering is fixed (not learned). Vertically adjacent pixels are W steps apart in the flattened sequence; see the module docstring for the anisotropy implication.
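The distance between neighbouring tokens in the flattened sequence follows directly from row-major indexing. The sketch below uses a hypothetical helper, flat_index, and arbitrary grid sizes to show the gaps for 2D and 3D inputs.

```python
def flat_index(coords, shape):
    """Row-major (C-order) flat index of `coords` in a grid of `shape`."""
    idx = 0
    for c, n in zip(coords, shape):
        idx = idx * n + c
    return idx


H, W = 4, 6  # arbitrary example sizes

# Horizontal neighbours are adjacent in the sequence ...
horizontal_gap = flat_index((1, 3), (H, W)) - flat_index((1, 2), (H, W))
# ... while vertical neighbours are W steps apart.
vertical_gap = flat_index((2, 2), (H, W)) - flat_index((1, 2), (H, W))

# 3D [D, H, W]: the depth axis varies slowest, so depth neighbours are H * W apart.
D = 3
depth_gap = flat_index((1, 0, 0), (D, H, W)) - flat_index((0, 0, 0), (D, H, W))
```

Here horizontal_gap is 1, vertical_gap equals W, and depth_gap equals H * W, which is the source of the anisotropy mentioned above.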

Bidirectional mode: when bidirectional=True a second core layer processes the flattened sequence in reverse, and its (re-reversed) output is summed with the forward output. This gives every position a full-sequence receptive field in both causal directions, which is beneficial for non-causal spatial tasks such as image or volume modelling.
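The flip, process, flip-back, sum pattern can be illustrated with a toy causal operator in place of the SSM core. The sketch below uses a running sum as a stand-in (an assumption for illustration only; a real core is a learned selective SSM, and the two directions would use separate weights).

```python
from itertools import accumulate


def causal_core(seq):
    """Toy causal stand-in for an SSM core: output t depends only on inputs 0..t."""
    return list(accumulate(seq))


def bidirectional(seq, core_fwd, core_rev):
    """Sum the forward output with the re-reversed output of the reversed sequence."""
    out_fwd = core_fwd(seq)
    out_rev = core_rev(seq[::-1])[::-1]
    return [f + r for f, r in zip(out_fwd, out_rev)]


x = [1, 10, 100, 1000]
y_fwd = causal_core(x)                             # position t sees x[0..t] only
y_bi = bidirectional(x, causal_core, causal_core)  # position t sees the whole sequence
```

In y_fwd the first position carries information only about x[0], whereas every entry of y_bi contains a contribution from all four inputs.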

Attributes:
bidirectional

Whether to apply a second reversed Mamba pass.

Type:

bool

core_layer

The forward (or only) Mamba core. Must accept input of shape [B, S, C] and return [B, S, C].

Type:

torch.nn.Module

core_layer_rev

The reverse Mamba core, instantiated only when bidirectional=True. When bidirectional=False this attribute is not registered and accessing it raises AttributeError by design, keeping the module’s parameter count and state_dict unaffected.

Type:

torch.nn.Module

Example:

import torch
from nvsubquadratic.lazy_config import LazyConfig
from nvsubquadratic.modules.mamba_nd import Mamba
from mamba_ssm import Mamba as MambaCore

mamba = Mamba(
    mamba_layer_cfg=LazyConfig(MambaCore)(d_model=128, d_state=16, d_conv=4, expand=2),
    bidirectional=True,
)

# 2D input: batch=2, spatial=(16, 16), channels=128
x = torch.randn(2, 16, 16, 128)
y = mamba(x)   # [2, 16, 16, 128]

__init__(mamba_layer_cfg, bidirectional=False)

Initialise the Mamba-ND wrapper.

Parameters:
  • mamba_layer_cfg (LazyConfig) – LazyConfig for the underlying 1D Mamba core. The target class must accept a 3-D tensor of shape [B, S, C] (batch, sequence length, channels) and return a tensor of the same shape. Typical targets include mamba_ssm.Mamba and mamba_ssm.Mamba2. instantiate(mamba_layer_cfg) is called twice when bidirectional=True; each call constructs a fresh nn.Module with newly initialised weights, so the two directions do not share parameters.

  • bidirectional (bool) – If True, run a second Mamba core on the reversed sequence and sum both outputs. This doubles parameter count and compute but gives non-causal coverage of the full sequence – strongly recommended for spatial tasks (images, volumes). Defaults to False.
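The consequence of constructing the core twice from one config can be sketched without the library. In the example below, functools.partial stands in for a deferred config and calling it stands in for instantiate(); both are assumptions for illustration and do not reflect the actual LazyConfig API. ToyCore is a hypothetical class.

```python
import random
from functools import partial


class ToyCore:
    """Hypothetical stand-in for a 1D core; holds one randomly initialised weight vector."""

    def __init__(self, d_model):
        self.weight = [random.gauss(0.0, 1.0) for _ in range(d_model)]


# partial(...) plays the role of the deferred config (assumption, not the LazyConfig API).
cfg = partial(ToyCore, d_model=8)

# Each call builds a fresh object, mirroring two separate instantiate() calls.
core_fwd = cfg()
core_rev = cfg()

shares_storage = core_fwd.weight is core_rev.weight
```

Because each call runs the constructor again, the two objects hold independent weights, which matches the statement that the two directions do not share parameters.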

Raises:

Exception – Propagated from instantiate() if mamba_layer_cfg cannot be constructed. Check that all required constructor arguments are provided in the config. A target class that does not accept [B, S, C] tensors typically surfaces later, when forward() is called, rather than at construction time.

forward(x)

Apply the Mamba SSM to an ND input signal.

The forward pass performs the following steps:

  1. Flatten all spatial axes into one sequence dimension: [B, *spatial, C] to [B, S, C], where S = prod(spatial). The flattening follows row-major (C-contiguous) order.

  2. Forward SSM: out = core_layer(x) – applies the selective SSM recurrence \(y_t = C_t(\bar{A}_t h_{t-1} + \bar{B}_t x_t)\).

  3. Reverse SSM (only when bidirectional=True): out_rev = core_layer_rev(flip(x)) – runs the SSM on the reversed sequence, then flips back and adds to out:

    \[\text{out} \mathrel{+}= \mathrm{flip}(\mathrm{Mamba}_\mathrm{rev}(\mathrm{flip}(x)))\]

  4. Reshape back to the original spatial layout: [B, S, C] to [B, *spatial, C].

Implementation note

The local variable x is rebound to the flattened [B, S, C] view after the rearrange call; the original spatial shape is preserved in x_shape for the final reshape.
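The four steps of the forward pass can be sketched in PyTorch as follows. This is not the library's implementation: ToyCausalCore and MambaNDSketch are hypothetical names, the toy core is a running mean followed by a linear layer rather than a selective SSM, and Tensor.reshape is used in place of the rearrange call mentioned above.

```python
import torch
from torch import nn


class ToyCausalCore(nn.Module):
    """Hypothetical stand-in for a 1D core: [B, S, C] -> [B, S, C], causal along S."""

    def __init__(self, channels):
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, x):
        # A running mean over the sequence axis keeps the toy core causal.
        steps = torch.arange(1, x.shape[1] + 1, device=x.device, dtype=x.dtype)
        return self.proj(torch.cumsum(x, dim=1) / steps.view(1, -1, 1))


class MambaNDSketch(nn.Module):
    """Sketch of the ND wrapper logic under the assumptions stated above."""

    def __init__(self, channels, bidirectional=False):
        super().__init__()
        self.bidirectional = bidirectional
        self.core_layer = ToyCausalCore(channels)
        if bidirectional:
            # Registered only in bidirectional mode.
            self.core_layer_rev = ToyCausalCore(channels)

    def forward(self, x):
        x_shape = x.shape
        # 1. Flatten spatial axes in row-major order: [B, *spatial, C] -> [B, S, C].
        x = x.reshape(x_shape[0], -1, x_shape[-1])
        # 2. Forward pass over the flattened sequence.
        out = self.core_layer(x)
        if self.bidirectional:
            # 3. Reverse pass: flip the sequence, run the second core, flip back, sum.
            out = out + self.core_layer_rev(x.flip(1)).flip(1)
        # 4. Restore the original spatial layout.
        return out.reshape(x_shape)


x = torch.randn(2, 4, 5, 8)  # batch=2, spatial=(4, 5), channels=8
y_uni = MambaNDSketch(8)(x)
y_bi = MambaNDSketch(8, bidirectional=True)(x)
```

Both outputs have the same shape as the input, and the reverse core appears in the module's state_dict only when bidirectional=True.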

Parameters:

x (Tensor) – Input tensor of shape (B, *spatial, C), where B is the batch size, spatial is one or more spatial dimensions (e.g. (T,) for 1D sequences, (H, W) for 2D images, (D, H, W) for 3D volumes), and C is the channel (hidden) dimension. The tensor must be in channels-last (BHC / BHWC) layout, consistent with the rest of the library.

Returns:

Output tensor of shape (B, *spatial, C) – same shape and layout as the input. When bidirectional=True the output is the element-wise sum (not the average) of the forward and reverse SSM outputs, so its magnitude can be up to roughly twice that of a unidirectional pass; downstream normalisation layers (e.g. RMSNorm inside the residual block) absorb this scale.

Return type:

Tensor