ResidualBlock

class ResidualBlock(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
condition_mixer_cfg,
condition_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
dropout_cfg,
)

Bases: Module

Standard pre-norm residual block for N-D signals.

Each forward pass executes up to three residual sub-branches:

x  ──[input_norm]────────────► sequence_mixer ───► dropout ──►(+)──► x'
x' ──[condition_mixer_norm]──► condition_mixer ──► dropout ──►(+)──► x''   (optional)
x''──[mlp_norm]──────────────► mlp ──────────────► dropout ──►(+)──► output

Any branch is bypassed entirely (not just zeroed) when its *_cfg target is torch.nn.Identity. This design lets the same class serve pure-sequence networks (condition branch disabled), cross-attention encoder-decoders (condition branch enabled), and MLP-only ablations (sequence-mixer branch disabled).
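The bypass behaviour can be illustrated with a minimal sketch. `TinyResidualBlock` below is a hypothetical, simplified stand-in, not the actual ResidualBlock implementation: it takes already-instantiated modules rather than LazyConfigs and omits the conditioning branch. It shows why a disabled branch has to be skipped rather than evaluated.

```python
import torch
from torch import nn


class TinyResidualBlock(nn.Module):
    """Hypothetical, simplified stand-in for ResidualBlock (not its real code).

    Takes already-instantiated modules instead of LazyConfigs and leaves out
    the conditioning branch to stay short.
    """

    def __init__(self, mixer: nn.Module, mixer_norm: nn.Module,
                 mlp: nn.Module, mlp_norm: nn.Module, dropout: nn.Module):
        super().__init__()
        self.sequence_mixer = mixer
        self.input_norm = mixer_norm
        self.mlp = mlp
        self.mlp_norm = mlp_norm
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Evaluating an nn.Identity branch would add a copy of the (normalised)
        # input back onto x instead of leaving x unchanged, so a disabled
        # branch is skipped outright: no norm, no dropout, no addition.
        if not isinstance(self.sequence_mixer, nn.Identity):
            x = x + self.dropout(self.sequence_mixer(self.input_norm(x)))
        if not isinstance(self.mlp, nn.Identity):
            x = x + self.dropout(self.mlp(self.mlp_norm(x)))
        return x


# MLP-only ablation: the sequence-mixer branch is disabled.
C = 16
mlp = nn.Sequential(nn.Linear(C, 4 * C), nn.GELU(), nn.Linear(4 * C, C))
block = TinyResidualBlock(nn.Identity(), nn.Identity(), mlp, nn.LayerNorm(C), nn.Dropout(0.1))
block.eval()
x = torch.randn(2, 8, 8, C)  # (B, H, W, C), channels last
out = block(x)
```

With every branch disabled, the sketch reduces to the identity map, which is what makes per-branch ablations free.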

The Identity constraint is one-directional: if a mixer/MLP config targets Identity, the corresponding norm config must also be Identity (enforced by assertion at init). The reverse — a norm set to Identity while the mixer is active — is not checked and is the caller’s responsibility.
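The one-directional rule can be expressed as a short check. `check_identity_pairing` is a hypothetical helper written for illustration; the real class performs its assertion inside `__init__`, and its structure and message may differ. Python `assert` statements are stripped under `python -O`, so a check of this form does not fire in optimised mode.

```python
from torch import nn


def check_identity_pairing(mixer: nn.Module, norm: nn.Module, name: str) -> None:
    """Hypothetical helper mirroring the documented one-directional rule.

    mixer is Identity  =>  norm must be Identity.
    The converse (norm is Identity while mixer is active) is not checked.
    """
    if isinstance(mixer, nn.Identity):
        assert isinstance(norm, nn.Identity), (
            f"{name}: mixer is Identity, so its norm must be Identity too, "
            f"got {type(norm).__name__}"
        )


# Allowed: both halves of the pair disabled.
check_identity_pairing(nn.Identity(), nn.Identity(), "mlp")
# Allowed (unchecked): an active mixer with no pre-norm.
check_identity_pairing(nn.Linear(4, 4), nn.Identity(), "mlp")

# Rejected: a disabled mixer paired with a real norm.
try:
    check_identity_pairing(nn.Identity(), nn.LayerNorm(4), "mlp")
except AssertionError as err:
    print(err)
```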

The parameters of all normalisation modules (input_norm, condition_mixer_norm, mlp_norm) are tagged with the attribute _no_weight_decay = True so that the optimiser can exclude them from weight-decay groups.
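One way an optimiser setup could consume the tag is to split parameters into two groups. The sketch below tags a toy `LayerNorm` by hand to imitate the documented attribute; the toy model, the choice of AdamW, and the `weight_decay=0.05` and `lr=1e-3` values are illustrative assumptions, not the library's training code.

```python
import torch
from torch import nn

# Toy model: a Linear plus a LayerNorm whose parameters are tagged by hand,
# imitating the documented `_no_weight_decay = True` attribute.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
for p in model[1].parameters():
    p._no_weight_decay = True

decay, no_decay = [], []
for p in model.parameters():
    # Parameters without the attribute default to receiving weight decay.
    (no_decay if getattr(p, "_no_weight_decay", False) else decay).append(p)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.05},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```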

See also

ResidualNetwork and ClassificationResNet, the canonical consumers of this block.

Attributes:
sequence_mixer

The instantiated sequence-mixing operator (e.g. QKVSequenceMixer wrapping Hyena, Attention, CKConv, or Mamba). May be torch.nn.Identity when the mixer branch is disabled.

Type:

torch.nn.Module

input_norm

Pre-norm applied before the sequence mixer (e.g. LayerNorm or RMSNorm), instantiated from sequence_mixer_norm_cfg. May be torch.nn.Identity.

Type:

torch.nn.Module

condition_mixer

Cross-attention or conditioning operator applied after the sequence mixer. May be torch.nn.Identity to disable the conditioning branch entirely. When not Identity, its forward must accept two positional arguments (x, condition): the residual stream (after condition_mixer_norm) and a conditioning tensor whose channel dimension matches C, the hidden channel dimension of x.
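As an illustration of the `forward(x, condition)` contract, the hypothetical `CrossAttentionConditionMixer` below flattens the spatial axes of `x` into a token axis and attends to conditioning tokens of shape `(B, S, C)` using `torch.nn.MultiheadAttention`. It is a sketch under these assumptions only; as noted under condition_mixer_cfg, no concrete cross-attention class is part of the public API yet.

```python
import torch
from torch import nn


class CrossAttentionConditionMixer(nn.Module):
    """Hypothetical conditioning operator satisfying forward(x, condition).

    x:         (B, *spatial_dims, C)  residual stream after condition_mixer_norm
    condition: (B, S, C)              S conditioning tokens with the same C as x
    Returns a tensor shaped like x, which the residual block adds back.
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        b, c = x.shape[0], x.shape[-1]
        queries = x.reshape(b, -1, c)                      # flatten spatial dims -> (B, N, C)
        out, _ = self.attn(queries, condition, condition)  # keys/values come from condition
        return out.reshape(x.shape)                        # restore (B, *spatial_dims, C)


mixer = CrossAttentionConditionMixer(channels=32)
x = torch.randn(2, 6, 5, 32)       # (B, H, W, C)
condition = torch.randn(2, 7, 32)  # (B, S, C)
y = mixer(x, condition)
```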

Type:

torch.nn.Module

condition_mixer_norm

Pre-norm applied before condition_mixer. Must be torch.nn.Identity when condition_mixer is torch.nn.Identity.

Type:

torch.nn.Module

mlp

Position-wise MLP (e.g. the MLP class). May be torch.nn.Identity to disable the MLP branch.

Type:

torch.nn.Module

mlp_norm

Pre-norm applied before the MLP. Must be torch.nn.Identity when mlp is torch.nn.Identity.

Type:

torch.nn.Module

dropout

Dropout (or stochastic depth) applied after each active branch. Typically DropPath or torch.nn.Dropout.
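Stochastic depth differs from `torch.nn.Dropout` in granularity: instead of zeroing individual elements, it drops a residual branch's whole output for a given sample. `DropPathSketch` is a minimal illustration written for this note, not the library's `DropPath`, whose details (for example, the scaling or the mask shape) may differ.

```python
import torch
from torch import nn


class DropPathSketch(nn.Module):
    """Minimal stochastic depth ("drop path"), for illustration only.

    In training, each sample's branch output is zeroed with probability
    drop_prob and otherwise scaled by 1 / (1 - drop_prob), so its expected
    value is unchanged. In eval mode it is the identity.
    """

    def __init__(self, drop_prob: float = 0.1):
        super().__init__()
        assert 0.0 <= drop_prob < 1.0
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over spatial and channel dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        return x * mask / keep_prob


drop = DropPathSketch(0.5)
branch = torch.ones(4, 3, 8)  # (B, T, C) branch output
drop.train()
out = drop(branch)            # each sample is either all 0.0 or all 2.0
```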

Type:

torch.nn.Module

__init__(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
condition_mixer_cfg,
condition_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
dropout_cfg,
)

Initialise the ResidualBlock.

All sub-modules are supplied as LazyConfig objects and instantiated here. Passing torch.nn.Identity as the target for a *_cfg / *_norm_cfg pair disables that branch at zero cost (no forward computation, no parameters).
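To make the "zero cost" claim concrete, the sketch below uses `LazyCfgStandIn`, a hypothetical dataclass that only mimics the idea of a lazy config (a target plus keyword arguments, instantiated on demand); it is not the actual LazyConfig API. `torch.nn.Identity` owns no parameters and accepts and ignores any constructor arguments, so in this sketch a config's target can be switched to it without deleting its keyword arguments.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

from torch import nn


@dataclass
class LazyCfgStandIn:
    """Hypothetical stand-in for a LazyConfig (not the library's class)."""
    target: Callable[..., nn.Module]
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def instantiate(self) -> nn.Module:
        return self.target(**self.kwargs)


def num_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


C = 64
seq_norm_cfg = LazyCfgStandIn(nn.LayerNorm, {"normalized_shape": C})
mlp_cfg = LazyCfgStandIn(nn.Identity)                               # MLP branch disabled
mlp_norm_cfg = LazyCfgStandIn(nn.Identity, {"normalized_shape": C})  # kwargs ignored by Identity

seq_norm = seq_norm_cfg.instantiate()
mlp, mlp_norm = mlp_cfg.instantiate(), mlp_norm_cfg.instantiate()
```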

Parameters:
  • sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence mixer. Typical targets: QKVSequenceMixer (which internally wraps Hyena, Attention, CKConv, or Mamba), or torch.nn.Identity to skip the mixer branch entirely.

  • sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the sequence mixer. Must be torch.nn.Identity when sequence_mixer_cfg targets torch.nn.Identity.

  • condition_mixer_cfg (LazyConfig) – LazyConfig for the conditioning / cross-attention operator. Pass torch.nn.Identity (the default in most configs) to disable the conditioning branch. When active, the instantiated module’s forward must accept (x, condition) positional arguments where condition has the same channel dimension C as x. As of this writing, torch.nn.Identity is effectively the only production-ready value; a concrete cross-attention class is planned but not yet part of the public API.

  • condition_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before condition_mixer. Must be torch.nn.Identity when condition_mixer_cfg targets torch.nn.Identity.

  • mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP. Typical target: MLP. Pass torch.nn.Identity to skip the MLP branch.

  • mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the MLP. Must be torch.nn.Identity when mlp_cfg targets torch.nn.Identity.

  • dropout_cfg (LazyConfig) – LazyConfig for the dropout / stochastic-depth module applied after each active branch. Typical targets: torch.nn.Dropout, DropPath, or torch.nn.Identity for no dropout.

Raises:

AssertionError – If a mixer/MLP config targets torch.nn.Identity but the corresponding norm config does not. The constraint is one-directional: mixer=Identity ⟹ norm=Identity. The reverse (norm=Identity while mixer is active) is not checked.

forward(x, condition)

Apply the residual block to the input tensor.

Executes up to three residual sub-branches in order: sequence mixer, conditioning mixer (optional), and MLP (optional). A branch is skipped entirely when its corresponding module is torch.nn.Identity; the condition argument is only consumed when the conditioning branch is active.

Path without conditioning (condition_mixer is Identity):

x  → input_norm → sequence_mixer → dropout →(+)→ x'
x' → mlp_norm → mlp → dropout →(+)→ output

Path with conditioning (condition_mixer is not Identity):

x   → input_norm → sequence_mixer → dropout →(+)→ x'
x'  → condition_mixer_norm → condition_mixer(·, condition) → dropout →(+)→ x''
x'' → mlp_norm → mlp → dropout →(+)→ output

Here · stands for condition_mixer_norm(x'), and each (+) adds the branch output back to that row's input.

Parameters:
  • x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C) where B is the batch size, spatial_dims are one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C is the hidden channel dimension.

  • condition (Tensor | None) – Conditioning tensor used by condition_mixer. Its shape depends on the conditioning operator — a common choice is (B, *spatial_dims_condition, C) for cross-attention, or (B, C) for a global conditioning vector. The channel dimension C must match that of x. This argument is ignored (and may safely be None) when condition_mixer is torch.nn.Identity.

Returns:

Output tensor of the same shape as x: (B, *spatial_dims, C).

Return type:

torch.Tensor

Raises:

AssertionError – If condition is None but condition_mixer is not torch.nn.Identity (i.e. a conditioning tensor is required but was not provided).
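The conditioning sub-branch and its `None` check can be sketched as follows. `AddGlobalCondition` is a hypothetical operator that consumes a global `(B, C)` vector, one of the shapes mentioned for `condition` above, and `condition_branch` is a hypothetical helper isolating this single step; neither is part of the library.

```python
from typing import Optional

import torch
from torch import nn


class AddGlobalCondition(nn.Module):
    """Hypothetical conditioning operator with the forward(x, condition) shape.

    Projects a global (B, C) vector and broadcasts it over every position of
    x, which has shape (B, *spatial_dims, C).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.proj = nn.Linear(channels, channels)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # (B, C) -> (B, 1, ..., 1, C) so it broadcasts against x.
        cond = self.proj(condition).reshape(
            condition.shape[0], *([1] * (x.ndim - 2)), x.shape[-1]
        )
        return cond.expand_as(x)


def condition_branch(x: torch.Tensor, condition: Optional[torch.Tensor],
                     norm: nn.Module, mixer: nn.Module, dropout: nn.Module) -> torch.Tensor:
    """Sketch of x'' = x' + dropout(mixer(norm(x'), condition))."""
    if isinstance(mixer, nn.Identity):
        return x  # branch bypassed: condition is ignored and may be None
    assert condition is not None, "condition_mixer is active but condition is None"
    return x + dropout(mixer(norm(x), condition))


C = 8
x = torch.randn(2, 5, C)  # (B, T, C): a 1-D sequence
g = torch.randn(2, C)     # (B, C): a global conditioning vector
out = condition_branch(x, g, nn.LayerNorm(C), AddGlobalCondition(C), nn.Identity())
```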