AdaLNZeroResidualBlock
- class AdaLNZeroResidualBlock(
- sequence_mixer_cfg,
- sequence_mixer_norm_cfg,
- mlp_cfg,
- mlp_norm_cfg,
- condition_norm_cfg,
- dropout_cfg,
- hidden_dim,
- )
Bases: Module
Pre-norm residual block with AdaLN-Zero conditioning (DiT-style).
Replaces fixed LayerNorm with Adaptive LayerNorm-Zero (AdaLN-Zero) modulation, following the Diffusion Transformer (DiT) recipe of "Scalable Diffusion Models with Transformers" (Peebles & Xie, 2023). A single zero-initialised linear projection maps the conditioning vector to six affine parameters (shift, scale, and gate for each of the two branches), so that at initialisation both branches contribute exactly zero and the block reduces to the identity (the residual stream is unchanged).
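The zero-initialised projection described above can be illustrated with a minimal PyTorch sketch. The variable names mirror this page, but the construction is an assumption for illustration (in particular the `nn.Sequential` layout and the order in which the six chunks are split), not the class's actual source.

```python
import torch
from torch import nn

hidden_dim = 8  # illustrative value for C

# Assumed layout: SiLU followed by Linear(C, 6C), weights and biases set to zero.
condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
nn.init.zeros_(condition_proj[1].weight)
nn.init.zeros_(condition_proj[1].bias)

cond = torch.randn(4, hidden_dim)   # (B, C) pooled conditioning vector
params = condition_proj(cond)       # (B, 6C)

# Split into the six (B, C) modulation parameters (split order is an assumption).
shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=-1)

# With a zero projection every parameter is zero, so both gates close their branch.
print(gate_seq.abs().max().item(), gate_mlp.abs().max().item())
```

Because the gates multiply each branch output, zero gates make the residual update vanish regardless of what the sequence mixer or MLP computes.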
Unlike ResidualBlock, this block has no separate cross-attention / condition-mixer branch. All conditioning is routed through the AdaLN-Zero projection plus a conditioning kwarg forwarded directly into the sequence mixer (e.g. for FiLM inside Hyena). There are therefore two conditioning pathways in a single forward pass:
- AdaLN-Zero affine modulation: shift and scale are applied to the pre-norm outputs of the sequence mixer and MLP branches, and the gate scales each branch output before the residual add.
- Inner-mixer conditioning: the pooled conditioning vector cond is also passed as conditioning=cond to sequence_mixer.forward, allowing the inner operator (e.g. Hyena with FiLM kernel conditioning) to use it independently.
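To make the inner-mixer pathway concrete, here is a hedged sketch of a module whose forward accepts the `conditioning` keyword. `ToyFiLMMixer` is a hypothetical stand-in written for this example: it is not Hyena, it does no real sequence mixing, and it applies FiLM to activations rather than to kernel networks. It only shows the calling convention the block relies on.

```python
from typing import Optional

import torch
from torch import nn


class ToyFiLMMixer(nn.Module):
    """Hypothetical mixer that accepts a `conditioning` keyword argument."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.mix = nn.Linear(hidden_dim, hidden_dim)
        # FiLM head: predicts a per-channel scale and shift from the (B, C) vector.
        self.film = nn.Linear(hidden_dim, 2 * hidden_dim)

    def forward(self, x: torch.Tensor, conditioning: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = self.mix(x)  # (B, L, C)
        if conditioning is not None:
            scale, shift = self.film(conditioning).chunk(2, dim=-1)  # each (B, C)
            # Insert a length axis so the (B, C) parameters broadcast over L.
            h = h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return h


mixer = ToyFiLMMixer(hidden_dim=8)
x = torch.randn(2, 5, 8)       # (B, L, C)
cond = torch.randn(2, 8)       # (B, C) pooled conditioning vector
out = mixer(x, conditioning=cond)
print(out.shape)
```

Any operator with this signature could be slotted in as the sequence mixer, since the block only needs `forward(x, conditioning=cond)` to be accepted.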
Forward computation (one block):
    cond → [optional spatial mean] → condition_norm → SiLU → Linear(C, 6C)
         → split into 6 × (B, C)
           (shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp)

    # Sequence mixer branch
    x_norm  = sequence_norm(x)                          # pre-norm
    x_mod   = x_norm * (1 + scale_seq) + shift_seq      # AdaLN modulation
    seq_out = sequence_mixer(x_mod, conditioning=cond)  # cond also forwarded
    seq_out = dropout(seq_out) * gate_seq               # zero-init gate
    x       = x + seq_out

    # MLP branch
    x_norm  = mlp_norm(x)
    x_mod   = x_norm * (1 + scale_mlp) + shift_mlp
    mlp_out = mlp(x_mod)
    mlp_out = dropout(mlp_out) * gate_mlp
    x       = x + mlp_out
The gate vectors have shape (B, C) and are multiplied after dropout, so they provide per-sample, per-channel scaling that is broadcast over all spatial positions. At init they are zero (because the condition_proj weights and biases are zero-initialised), so the block is a skip connection.
See also ResidualNetwork and ClassificationResNet for the canonical consumers of this block.
- Parameters:
sequence_mixer_cfg (LazyConfig)
sequence_mixer_norm_cfg (LazyConfig)
mlp_cfg (LazyConfig)
mlp_norm_cfg (LazyConfig)
condition_norm_cfg (LazyConfig)
dropout_cfg (LazyConfig)
hidden_dim (int)
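The forward computation listed earlier can be sketched end to end in plain PyTorch. This is an illustrative re-implementation under stated assumptions, not the library's class: `AdaLNZeroBlockSketch` and `PlaceholderMixer` are hypothetical names, sub-modules are built directly instead of from LazyConfig objects, the LayerNorms are assumed to have no learned affine, and the un-normalised pooled vector is assumed to be the one forwarded to the mixer.

```python
import torch
from torch import nn


class PlaceholderMixer(nn.Module):
    """Hypothetical mixer; accepts `conditioning` but ignores it."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, conditioning=None):
        return self.proj(x)


class AdaLNZeroBlockSketch(nn.Module):
    """Illustrative AdaLN-Zero residual block (assumptions noted in the lead-in)."""

    def __init__(self, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.sequence_mixer = PlaceholderMixer(hidden_dim)
        self.sequence_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(), nn.Linear(4 * hidden_dim, hidden_dim)
        )
        self.mlp_norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.condition_norm = nn.Identity()
        self.dropout = nn.Dropout(dropout)
        self.condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
        nn.init.zeros_(self.condition_proj[1].weight)
        nn.init.zeros_(self.condition_proj[1].bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        if condition is None:
            raise ValueError("condition must not be None")
        cond = condition
        if cond.dim() > 2:
            # Mean over all non-batch, non-channel axes -> (B, C).
            cond = cond.mean(dim=tuple(range(1, cond.dim() - 1)))
        params = self.condition_proj(self.condition_norm(cond))  # (B, 6C)
        # Add singleton spatial axes so the parameters broadcast against x.
        params = params.reshape(params.shape[0], *([1] * (x.dim() - 2)), -1)
        shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=-1)

        # Sequence mixer branch.
        x_mod = self.sequence_norm(x) * (1 + scale_seq) + shift_seq
        x = x + self.dropout(self.sequence_mixer(x_mod, conditioning=cond)) * gate_seq
        # MLP branch.
        x_mod = self.mlp_norm(x) * (1 + scale_mlp) + shift_mlp
        x = x + self.dropout(self.mlp(x_mod)) * gate_mlp
        return x


block = AdaLNZeroBlockSketch(hidden_dim=8)
x = torch.randn(2, 5, 8)          # (B, L, C)
condition = torch.randn(2, 3, 8)  # (B, L_cond, C), pooled inside forward
out = block(x, condition)
print(torch.allclose(out, x))     # the block acts as a skip connection at init
```

The explicit reshape is one way to broadcast the (B, C) parameters over an arbitrary number of spatial axes; the real implementation may handle this differently.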
- sequence_mixer
Instantiated sequence-mixing operator. Its forward must accept a conditioning keyword argument (forwarded from the pooled conditioning vector cond).
- Type:
- sequence_norm
Pre-norm applied to x before AdaLN modulation in the sequence mixer branch.
- Type:
- mlp
Position-wise MLP instantiated from mlp_cfg.
- Type:
- mlp_norm
Pre-norm applied to x before AdaLN modulation in the MLP branch.
- Type:
- condition_norm
Optional normalisation applied to the conditioning vector before it is projected by condition_proj. Tagged _no_weight_decay.
- Type:
- dropout
Dropout / stochastic-depth applied after each branch, before the gate multiply.
- Type:
- condition_proj
SiLU → Linear(C, 6C) with zero-initialised weights and biases, producing the six AdaLN-Zero parameters. Zero init ensures the block is a pure residual connection at the start of training.
- Type:
- __init__(
- sequence_mixer_cfg,
- sequence_mixer_norm_cfg,
- mlp_cfg,
- mlp_norm_cfg,
- condition_norm_cfg,
- dropout_cfg,
- hidden_dim,
- )
Initialise the AdaLNZeroResidualBlock.
- Parameters:
sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence-mixing operator. The instantiated module’s forward must accept a conditioning keyword argument (it receives the spatially-pooled conditioning vector cond of shape (B, C)).
sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the AdaLN modulation of the sequence mixer branch (e.g. LayerNorm(hidden_dim)).
mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP.
mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the AdaLN modulation of the MLP branch.
condition_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied to the conditioning vector before condition_proj. Use torch.nn.Identity to skip conditioning normalisation.
dropout_cfg (LazyConfig) – LazyConfig for the dropout / stochastic-depth module applied after each branch and before the gate multiply.
hidden_dim (int) – Channel dimension C shared by all sub-modules. Used to size the condition_proj linear layer (Linear(C, 6*C)). This value must match the channel dimension baked into sequence_mixer_cfg and mlp_cfg. There is no runtime check, so a mismatch only surfaces as a shape error at forward time (for example inside condition_proj).
- forward(x, condition)
Apply AdaLN-Zero residual mixing conditioned on the provided tensor.
The conditioning tensor is reduced to a single latent vector per sample (spatial mean if it has spatial axes) before being projected to six affine parameters. These parameters modulate the pre-norm outputs of both branches via element-wise affine transforms, and gate each branch’s output before the residual add.
There are two conditioning pathways:
- AdaLN-Zero: shift and scale parameters derived from condition_proj(cond) are applied to the pre-norm outputs of both the sequence-mixer and MLP branches, and the gate parameters scale each branch output before the residual add.
- Inner-mixer conditioning: the same pooled cond vector is also forwarded as conditioning=cond into self.sequence_mixer, allowing operators such as Hyena to apply additional FiLM modulation to their kernel networks.
- Parameters:
x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C) where B is the batch size, spatial_dims are one or more spatial axes, and C = hidden_dim.
condition (Tensor | None) – Required conditioning tensor. Shape may be:
- (B, C): a pre-pooled global conditioning vector (e.g. a timestep / class embedding from a diffusion model).
- (B, *spatial_dims_cond, C): any spatial layout; the forward pass reduces it to (B, C) via a mean over all non-batch, non-channel axes.
Must not be None.
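The two accepted condition layouts can be illustrated with a small standalone helper. `pool_condition` is a hypothetical name used only for this sketch; it assumes a channels-last layout as described above and is not taken from the library.

```python
import torch


def pool_condition(condition: torch.Tensor) -> torch.Tensor:
    """Reduce a (B, *spatial, C) tensor to (B, C); pass (B, C) through unchanged."""
    if condition.dim() < 2:
        raise ValueError("condition must have at least batch and channel axes")
    if condition.dim() == 2:
        return condition  # already a pooled global vector
    # Axes 1 .. ndim-2 are the spatial axes in a channels-last layout.
    spatial_axes = tuple(range(1, condition.dim() - 1))
    return condition.mean(dim=spatial_axes)


vec = torch.randn(4, 16)         # (B, C)
grid = torch.randn(4, 7, 9, 16)  # (B, H, W, C)
pooled_vec = pool_condition(vec)
pooled_grid = pool_condition(grid)
print(pooled_vec.shape, pooled_grid.shape)
```

Note that the spatial layout of the condition does not have to match that of `x`, since only the pooled (B, C) vector is used downstream.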
- Returns:
Output tensor of shape (B, *spatial_dims, C), the same shape as x.
- Return type:
- Raises:
ValueError – If condition is None.