ResidualBlock
- class ResidualBlock(
    sequence_mixer_cfg,
    sequence_mixer_norm_cfg,
    condition_mixer_cfg,
    condition_mixer_norm_cfg,
    mlp_cfg,
    mlp_norm_cfg,
    dropout_cfg,
)
Bases: Module

Standard pre-norm residual block for N-D signals.
Each forward pass executes up to three residual sub-branches:

    x   ──[input_norm]──► sequence_mixer ───► dropout ──►(+)──► x'
    x'  ──[cond_norm]───► condition_mixer ──► dropout ──►(+)──► x''    (optional)
    x'' ──[mlp_norm]────► mlp ──────────────► dropout ──►(+)──► output
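The branch structure above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: ResidualBlockSketch is a hypothetical name, it takes already-instantiated modules instead of LazyConfig objects, and it assumes a disabled branch is detected with isinstance(module, nn.Identity). The nn.Linear used as a "sequence mixer" is only a stand-in, since a real mixer would mix information across positions.

```python
import torch
from torch import nn


class ResidualBlockSketch(nn.Module):
    """Hypothetical pre-norm residual block with three optional branches."""

    def __init__(self, sequence_mixer, input_norm, condition_mixer,
                 condition_mixer_norm, mlp, mlp_norm, dropout):
        super().__init__()
        self.sequence_mixer = sequence_mixer
        self.input_norm = input_norm
        self.condition_mixer = condition_mixer
        self.condition_mixer_norm = condition_mixer_norm
        self.mlp = mlp
        self.mlp_norm = mlp_norm
        self.dropout = dropout

    @staticmethod
    def _active(module: nn.Module) -> bool:
        # Assumption: a branch whose module is Identity is bypassed, not evaluated.
        return not isinstance(module, nn.Identity)

    def forward(self, x: torch.Tensor, condition=None) -> torch.Tensor:
        # Branch 1: sequence mixer on the pre-normed residual stream.
        if self._active(self.sequence_mixer):
            x = x + self.dropout(self.sequence_mixer(self.input_norm(x)))
        # Branch 2: conditioning mixer; only this branch consumes `condition`.
        if self._active(self.condition_mixer):
            assert condition is not None, "condition_mixer is active but condition is None"
            x = x + self.dropout(self.condition_mixer(self.condition_mixer_norm(x), condition))
        # Branch 3: position-wise MLP.
        if self._active(self.mlp):
            x = x + self.dropout(self.mlp(self.mlp_norm(x)))
        return x


C = 8
block = ResidualBlockSketch(
    sequence_mixer=nn.Linear(C, C),      # stand-in only: a real mixer mixes across positions
    input_norm=nn.LayerNorm(C),
    condition_mixer=nn.Identity(),       # conditioning branch disabled
    condition_mixer_norm=nn.Identity(),
    mlp=nn.Sequential(nn.Linear(C, 4 * C), nn.GELU(), nn.Linear(4 * C, C)),
    mlp_norm=nn.LayerNorm(C),
    dropout=nn.Dropout(0.1),
)
x = torch.randn(2, 16, C)  # (B, T, C)
y = block(x)               # condition may be omitted: its branch is bypassed
```

The explicit skip matters: if an Identity branch were simply evaluated, the residual update would compute x + dropout(x), roughly doubling the signal instead of leaving it unchanged.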
Any branch is bypassed entirely (not just zeroed) when its *_cfg target is torch.nn.Identity. This design lets the same class serve pure-sequence networks (condition branch disabled), cross-attention encoder-decoders (condition branch enabled), and MLP-only ablations (sequence-mixer branch disabled).

The Identity constraint is one-directional: if a mixer/MLP config targets Identity, the corresponding norm config must also be Identity (enforced by an assertion at init). The reverse case, a norm set to Identity while its mixer is active, is not checked and is the caller's responsibility.

All normalisation parameters (input_norm, condition_mixer_norm, mlp_norm) are tagged with _no_weight_decay = True so that the optimiser can exclude them from weight-decay groups.

See also: ResidualNetwork and ClassificationResNet, the canonical consumers of this block.

- Parameters:
sequence_mixer_cfg (LazyConfig)
sequence_mixer_norm_cfg (LazyConfig)
condition_mixer_cfg (LazyConfig)
condition_mixer_norm_cfg (LazyConfig)
mlp_cfg (LazyConfig)
mlp_norm_cfg (LazyConfig)
dropout_cfg (LazyConfig)
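The _no_weight_decay tag only has an effect if the optimiser setup reads it. A minimal sketch, assuming the tag is a plain Python attribute set on each nn.Parameter of the norm modules (the text above does not say exactly where it is stored); tag_no_weight_decay and build_param_groups are hypothetical helper names, and a LayerNorm followed by a Linear stands in for a full block.

```python
import torch
from torch import nn


def tag_no_weight_decay(*modules: nn.Module) -> None:
    # Assumption: the tag is a plain attribute on every parameter of the module.
    for module in modules:
        for param in module.parameters():
            param._no_weight_decay = True


def build_param_groups(model: nn.Module, weight_decay: float) -> list:
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        target = no_decay if getattr(param, "_no_weight_decay", False) else decay
        target.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


C = 8
norm = nn.LayerNorm(C)
proj = nn.Linear(C, C)
model = nn.Sequential(norm, proj)
tag_no_weight_decay(norm)  # e.g. input_norm, condition_mixer_norm, mlp_norm
groups = build_param_groups(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```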
- sequence_mixer
The instantiated sequence-mixing operator (e.g. QKVSequenceMixer wrapping Hyena, Attention, CKConv, or Mamba). May be torch.nn.Identity when the mixer branch is disabled.
- Type:
- input_norm
Pre-norm applied before the sequence mixer (e.g. LayerNorm or RMSNorm), built from sequence_mixer_norm_cfg. May be torch.nn.Identity.
- Type:
- condition_mixer
Cross-attention or conditioning operator applied after the sequence mixer. May be torch.nn.Identity to disable the conditioning branch entirely. When not Identity, its forward(x, condition) signature must accept the residual stream and a conditioning tensor whose channel dimension matches C (the hidden channel dimension of x).
- Type:
- condition_mixer_norm
Pre-norm applied before condition_mixer. Must be torch.nn.Identity when condition_mixer is torch.nn.Identity.
- Type:
- mlp_norm
Pre-norm applied before the MLP. Must be torch.nn.Identity when mlp is torch.nn.Identity.
- Type:
- dropout
Dropout (or stochastic depth) applied after each active branch. Typically DropPath or torch.nn.Dropout.
- Type:
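Stochastic depth differs from torch.nn.Dropout in that it drops the entire branch output for a given sample rather than individual elements. The sketch below uses the widely used per-sample formulation; DropPathSketch is a hypothetical name and may not match the library's DropPath in detail (for example, it assumes 0 <= drop_prob < 1).

```python
import torch
from torch import nn


class DropPathSketch(nn.Module):
    """Per-sample stochastic depth; assumes 0 <= drop_prob < 1."""

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x  # identity at eval time or when disabled
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over every other axis.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        # Rescale kept samples so the expected output matches eval mode.
        return x * mask / keep_prob


drop = DropPathSketch(drop_prob=0.5)
branch_out = torch.ones(4, 10, 8)  # (B, T, C) output of one branch
drop.train()
y = drop(branch_out)  # each sample is either all 0.0 or all 2.0
drop.eval()
y_eval = drop(branch_out)
```

Because a dropped sample contributes nothing from that branch, the residual connection passes its input through unchanged, which is what makes stochastic depth a natural fit for the dropout slot of a residual block.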
- __init__(
    sequence_mixer_cfg,
    sequence_mixer_norm_cfg,
    condition_mixer_cfg,
    condition_mixer_norm_cfg,
    mlp_cfg,
    mlp_norm_cfg,
    dropout_cfg,
)
Initialise the ResidualBlock.
All sub-modules are supplied as LazyConfig objects (positional arguments) and instantiated here. Passing torch.nn.Identity as the target for a *_cfg/*_norm_cfg pair disables that branch at zero cost (no forward computation, no parameters).

- Parameters:
sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence mixer. Typical targets: QKVSequenceMixer (which internally wraps Hyena, Attention, CKConv, or Mamba), or torch.nn.Identity to skip the mixer branch entirely.
sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the sequence mixer. Must be torch.nn.Identity when sequence_mixer_cfg targets torch.nn.Identity.
condition_mixer_cfg (LazyConfig) – LazyConfig for the conditioning / cross-attention operator. Pass torch.nn.Identity (the default in most configs) to disable the conditioning branch. When active, the instantiated module's forward must accept (x, condition) positional arguments, where condition has the same channel dimension C as x. As of this writing, torch.nn.Identity is effectively the only production-ready value; a concrete cross-attention class is planned but not yet part of the public API.
condition_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before condition_mixer. Must be torch.nn.Identity when condition_mixer_cfg targets torch.nn.Identity.
mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP. Typical target: MLP. Pass torch.nn.Identity to skip the MLP branch.
mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the MLP. Must be torch.nn.Identity when mlp_cfg targets torch.nn.Identity.
dropout_cfg (LazyConfig) – LazyConfig for the dropout / stochastic-depth module applied after each active branch. Typical targets: torch.nn.Dropout, DropPath, or torch.nn.Identity for no dropout.
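LazyConfig objects defer construction until __init__ instantiates them. The library's LazyConfig API is not documented here, so the sketch below uses a hypothetical LazyConfigSketch dataclass (a target callable plus keyword arguments) to illustrate the pattern, including the fact that Identity targets contribute zero parameters. nn.Linear again stands in for a real sequence mixer and MLP.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

from torch import nn


@dataclass
class LazyConfigSketch:
    """Hypothetical stand-in for a LazyConfig: a target plus deferred kwargs."""
    target: Callable[..., nn.Module]
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def instantiate(self) -> nn.Module:
        # Construction happens only here, not when the config is declared.
        return self.target(**self.kwargs)


C = 8
cfgs = {
    "sequence_mixer_cfg": LazyConfigSketch(nn.Linear, {"in_features": C, "out_features": C}),
    "sequence_mixer_norm_cfg": LazyConfigSketch(nn.LayerNorm, {"normalized_shape": C}),
    # Conditioning branch disabled: both the mixer and its norm target Identity.
    "condition_mixer_cfg": LazyConfigSketch(nn.Identity),
    "condition_mixer_norm_cfg": LazyConfigSketch(nn.Identity),
    "mlp_cfg": LazyConfigSketch(nn.Linear, {"in_features": C, "out_features": C}),
    "mlp_norm_cfg": LazyConfigSketch(nn.LayerNorm, {"normalized_shape": C}),
    "dropout_cfg": LazyConfigSketch(nn.Dropout, {"p": 0.1}),
}
modules = {name: cfg.instantiate() for name, cfg in cfgs.items()}
n_params = {name: sum(p.numel() for p in m.parameters()) for name, m in modules.items()}
```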
- Raises:
AssertionError – If a mixer/MLP config targets torch.nn.Identity but the corresponding norm config does not. The constraint is one-directional: mixer=Identity ⟹ norm=Identity. The reverse (norm=Identity while the mixer is active) is not checked.
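The one-directional check can be expressed as a small helper. check_identity_pairs is a hypothetical name, and it assumes a disabled branch is recognised with isinstance(module, nn.Identity); the library's actual assertion may be written differently.

```python
from torch import nn


def check_identity_pairs(pairs) -> None:
    """Enforce 'mixer is Identity implies norm is Identity' for (name, mixer, norm) triples."""
    for name, mixer, norm in pairs:
        if isinstance(mixer, nn.Identity):
            assert isinstance(norm, nn.Identity), (
                f"{name}: mixer is Identity but its norm is {type(norm).__name__}"
            )
        # The reverse (norm is Identity while the mixer is active) is deliberately unchecked.


C = 8
# Accepted: an active mixer with an Identity norm falls in the unchecked direction.
check_identity_pairs([("mlp", nn.Linear(C, C), nn.Identity())])

# Rejected: an Identity mixer paired with a real norm.
try:
    check_identity_pairs([("mlp", nn.Identity(), nn.LayerNorm(C))])
    rejected = False
except AssertionError:
    rejected = True
```

Because the check relies on assert, it is stripped when Python runs with -O; an implementation that must always validate its configs might raise ValueError instead.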
- forward(x, condition)
Apply the residual block to the input tensor.
Executes up to three residual sub-branches in order: sequence mixer, conditioning mixer (optional), and MLP (optional). A branch is skipped entirely when its corresponding module is torch.nn.Identity; the condition argument is only consumed when the conditioning branch is active.

Path without conditioning (condition_mixer is Identity):

    x  → input_norm → sequence_mixer → dropout →(+)→ x'
    x' → mlp_norm → mlp → dropout →(+)→ output

Path with conditioning (condition_mixer is not Identity):

    x   → input_norm → sequence_mixer → dropout →(+)→ x'
    x'  → condition_mixer_norm → condition_mixer(x', condition) → dropout →(+)→ x''
    x'' → mlp_norm → mlp → dropout →(+)→ output
- Parameters:
x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C), where B is the batch size, spatial_dims are one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C is the hidden channel dimension.
condition (Tensor | None) – Conditioning tensor used by condition_mixer. Its shape depends on the conditioning operator: a common choice is (B, *spatial_dims_condition, C) for cross-attention, or (B, C) for a global conditioning vector. The channel dimension C must match that of x. This argument is ignored (and may safely be None) when condition_mixer is torch.nn.Identity.
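For a global conditioning vector of shape (B, C), the conditioning operator has to broadcast it over the spatial axes of x. GlobalFiLMSketch is a hypothetical example (a FiLM-style per-sample scale and shift) that satisfies the forward(x, condition) contract described above; it is not the planned cross-attention class.

```python
import torch
from torch import nn


class GlobalFiLMSketch(nn.Module):
    """Hypothetical conditioning operator for a global (B, C) condition vector."""

    def __init__(self, channels: int):
        super().__init__()
        # Predict a per-sample scale and shift from the condition vector.
        self.to_scale_shift = nn.Linear(channels, 2 * channels)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        # x: (B, *spatial_dims, C); condition: (B, C) with the same C.
        assert condition.shape[-1] == x.shape[-1], "channel dims of x and condition must match"
        scale, shift = self.to_scale_shift(condition).chunk(2, dim=-1)  # each (B, C)
        # Insert singleton spatial axes so (B, C) broadcasts against (B, *spatial_dims, C).
        bcast_shape = (x.shape[0],) + (1,) * (x.ndim - 2) + (x.shape[-1],)
        return x * scale.reshape(bcast_shape) + shift.reshape(bcast_shape)


B, C = 2, 8
mixer = GlobalFiLMSketch(C)
cond = torch.randn(B, C)
out_2d = mixer(torch.randn(B, 4, 5, C), cond)  # (B, H, W, C)
out_1d = mixer(torch.randn(B, 7, C), cond)     # (B, T, C)
```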
- Returns:
Output tensor of the same shape as x: (B, *spatial_dims, C).
- Return type:
Tensor
- Raises:
AssertionError – If condition is None but condition_mixer is not torch.nn.Identity (i.e. a conditioning tensor is required but was not provided).
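The output keeps the input shape for any number of spatial axes as long as the norm and MLP act only on the trailing channel axis, which holds for nn.LayerNorm(C) and nn.Linear. A minimal check of the MLP sub-branch on channels-last inputs; positionwise_mlp is a hypothetical stand-in for the library's MLP, and dropout is omitted.

```python
import torch
from torch import nn


def positionwise_mlp(channels: int, expansion: int = 4) -> nn.Module:
    # Hypothetical stand-in for the MLP target: nn.Linear acts on the last axis only.
    return nn.Sequential(
        nn.Linear(channels, expansion * channels),
        nn.GELU(),
        nn.Linear(expansion * channels, channels),
    )


C = 8
mlp_norm = nn.LayerNorm(C)  # normalises over the trailing channel axis only
mlp = positionwise_mlp(C)
shapes = [(2, 16, C), (2, 4, 4, C), (2, 3, 3, 3, C)]  # 1-D, 2-D, 3-D spatial layouts
outputs = []
for shape in shapes:
    x = torch.randn(shape)
    out = x + mlp(mlp_norm(x))  # MLP residual sub-branch (dropout omitted)
    outputs.append(out)
```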