AdaLNZeroResidualBlock

class AdaLNZeroResidualBlock(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
condition_norm_cfg,
dropout_cfg,
hidden_dim,
)

Bases: Module

Pre-norm residual block with AdaLN-Zero conditioning (DiT-style).

Replaces fixed LayerNorm with Adaptive LayerNorm-Zero (AdaLN-Zero) modulation, following the Diffusion Transformer (DiT) recipe of Peebles & Xie (2023), "Scalable Diffusion Models with Transformers". A single zero-initialised linear projection maps the conditioning vector to six affine parameters (a shift, scale, and gate for each of the two branches). Because every gate is zero at initialisation, each branch contributes exactly zero, so the block reduces to the identity map: its output equals its input and the residual stream is unchanged.
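The zero-initialisation argument can be checked directly. The PyTorch sketch below is an illustration, not this class's actual code: it assumes the six parameters come from chunking the (B, 6C) projection in the order listed under "Forward computation", and the sizes and the tanh stand-in for a branch are purely illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, C = 2, 5, 8  # batch, tokens, hidden_dim (illustrative sizes)

# SiLU -> Linear(C, 6C) with zero-initialised weight and bias (AdaLN-Zero).
condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C))
nn.init.zeros_(condition_proj[1].weight)
nn.init.zeros_(condition_proj[1].bias)

cond = torch.randn(B, C)
params = condition_proj(cond)  # (B, 6C), all zeros at init for any cond
# Assumed ordering: (shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp).
shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp = params.chunk(6, dim=-1)

x = torch.randn(B, N, C)
x_norm = nn.functional.layer_norm(x, (C,))  # pre-norm
x_mod = x_norm * (1 + scale_seq[:, None]) + shift_seq[:, None]  # equals x_norm at init
branch_out = torch.tanh(x_mod)  # stand-in for the sequence mixer
x_next = x + branch_out * gate_seq[:, None]  # gate is zero, so x_next equals x
```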

Unlike ResidualBlock, this block has no separate cross-attention / condition-mixer branch. All conditioning is routed through the AdaLN-Zero projection plus a conditioning kwarg forwarded directly into the sequence mixer (e.g. for FiLM inside Hyena). There are therefore two conditioning pathways in a single forward pass:

  1. AdaLN-Zero affine modulation: shift/scale applied to the pre-norm outputs that feed the sequence-mixer and MLP branches, and a gate applied to each branch's output.

  2. Inner-mixer conditioning — the pooled conditioning vector cond is also passed as conditioning=cond to sequence_mixer.forward, allowing the inner operator (e.g. Hyena with FiLM kernel conditioning) to use it independently.
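To make the inner-mixer contract concrete, here is a hypothetical toy mixer. FiLMTokenMixer is an illustrative name, not part of this library and not Hyena; it only shows a forward that accepts the conditioning keyword and uses the pooled (B, C) vector for FiLM, i.e. a per-channel scale and shift.

```python
import torch
import torch.nn as nn


class FiLMTokenMixer(nn.Module):
    """Hypothetical mixer honouring the `conditioning` keyword contract."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.mix = nn.Linear(hidden_dim, hidden_dim)       # per-token channel mixing
        self.film = nn.Linear(hidden_dim, 2 * hidden_dim)  # cond -> (gamma, beta)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # x: (B, *spatial, C); conditioning: pooled (B, C) vector.
        spatial_axes = tuple(range(1, x.ndim - 1))
        context = x.mean(dim=spatial_axes, keepdim=True)  # global token average
        h = self.mix(x + context)
        gamma, beta = self.film(conditioning).chunk(2, dim=-1)  # each (B, C)
        view = (x.shape[0],) + (1,) * (x.ndim - 2) + (x.shape[-1],)
        return h * (1 + gamma.reshape(view)) + beta.reshape(view)


torch.manual_seed(0)
mixer = FiLMTokenMixer(hidden_dim=16)
x = torch.randn(2, 4, 4, 16)  # (B, H, W, C)
out_a = mixer(x, conditioning=torch.randn(2, 16))
out_b = mixer(x, conditioning=torch.randn(2, 16))  # different cond, different output
```

Any operator supplied through sequence_mixer_cfg needs a forward with this keyword; what it does with the vector is up to the operator.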

Forward computation (one block):

cond → [optional spatial mean] → condition_norm
     → SiLU → Linear(C, 6C) → split into 6 × (B, C)
       (shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp)
     → reshape each to (B, 1, ..., 1, C) so it broadcasts over the spatial axes of x

# Sequence mixer branch
x_norm = sequence_norm(x)                         # pre-norm
x_mod  = x_norm * (1 + scale_seq) + shift_seq    # AdaLN modulation
seq_out = sequence_mixer(x_mod, conditioning=cond)  # cond also forwarded
seq_out = dropout(seq_out) * gate_seq             # zero-init gate
x = x + seq_out

# MLP branch
x_norm = mlp_norm(x)
x_mod  = x_norm * (1 + scale_mlp) + shift_mlp
mlp_out = mlp(x_mod)
mlp_out = dropout(mlp_out) * gate_mlp
x = x + mlp_out
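The pseudocode can be turned into a runnable PyTorch sketch. This is a hypothetical simplification, not the library's implementation: sub-modules are passed in already built (no LazyConfig), all norms are plain LayerNorms, and which version of cond reaches the mixer (before or after condition_norm) is an assumption noted in the code.

```python
import torch
import torch.nn as nn


class AdaLNZeroBlockSketch(nn.Module):
    """Minimal AdaLN-Zero residual block following the pseudocode (hypothetical)."""

    def __init__(self, sequence_mixer: nn.Module, mlp: nn.Module, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.sequence_mixer = sequence_mixer
        self.sequence_norm = nn.LayerNorm(hidden_dim)
        self.mlp = mlp
        self.mlp_norm = nn.LayerNorm(hidden_dim)
        self.condition_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
        nn.init.zeros_(self.condition_proj[1].weight)
        nn.init.zeros_(self.condition_proj[1].bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        if condition is None:
            raise ValueError("condition must not be None")
        cond = condition
        if cond.ndim > 2:  # mean over all non-batch, non-channel axes -> (B, C)
            cond = cond.mean(dim=tuple(range(1, cond.ndim - 1)))
        params = self.condition_proj(self.condition_norm(cond))  # (B, 6C)
        # Reshape each (B, C) chunk to (B, 1, ..., 1, C) to broadcast over x's spatial axes.
        view = (x.shape[0],) + (1,) * (x.ndim - 2) + (x.shape[-1],)
        shift_seq, scale_seq, gate_seq, shift_mlp, scale_mlp, gate_mlp = (
            p.reshape(view) for p in params.chunk(6, dim=-1)
        )

        # Sequence mixer branch. Assumption: the mixer receives the pooled,
        # un-normalised cond; the pseudocode does not pin this down.
        x_mod = self.sequence_norm(x) * (1 + scale_seq) + shift_seq
        seq_out = self.sequence_mixer(x_mod, conditioning=cond)
        x = x + self.dropout(seq_out) * gate_seq

        # MLP branch
        x_mod = self.mlp_norm(x) * (1 + scale_mlp) + shift_mlp
        x = x + self.dropout(self.mlp(x_mod)) * gate_mlp
        return x


class LinearMixer(nn.Module):
    """Hypothetical mixer that accepts, but ignores, the conditioning kwarg."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


torch.manual_seed(0)
C = 16
mlp = nn.Sequential(nn.Linear(C, 4 * C), nn.GELU(), nn.Linear(4 * C, C))
block = AdaLNZeroBlockSketch(LinearMixer(C), mlp, hidden_dim=C)
x = torch.randn(2, 8, 8, C)        # (B, H, W, C)
condition = torch.randn(2, 3, C)   # spatial condition, pooled to (B, C) in forward
y = block(x, condition)            # identical to x at initialisation
```

Replacing LinearMixer with a mixer that actually reads conditioning exercises the second pathway; the AdaLN-Zero pathway is unaffected by that choice.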

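The gates are applied after dropout. Each gate has shape (B, C), so it scales every channel of a sample by one value shared across all spatial positions: this is per-sample, per-channel scaling, not per-token scaling. At initialisation the gates are zero because condition_proj's weights and biases are zero-initialised, so the block reduces to a pure skip connection.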
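A small numeric illustration of that broadcasting, assuming a 2-D (H, W) spatial layout and arbitrary gate values: a (B, C) gate reshaped to (B, 1, 1, C) applies the same per-channel factors at every position of a sample.

```python
import torch

B, H, W, C = 2, 3, 3, 4
branch_out = torch.ones(B, H, W, C)
gate = torch.tensor([[0.0, 0.5, 1.0, 2.0],
                     [1.0, 1.0, 0.0, 0.0]])  # (B, C): one value per sample and channel
gated = branch_out * gate[:, None, None, :]  # broadcast over the H x W positions
```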

See also

ResidualNetwork and ClassificationResNet, the canonical consumers of this block.

Attributes:
sequence_mixer

Instantiated sequence-mixing operator. Its forward must accept a conditioning keyword argument (forwarded from the pooled conditioning vector cond).

Type:

torch.nn.Module

sequence_norm

Pre-norm applied to x before AdaLN modulation in the sequence mixer branch.

Type:

torch.nn.Module

mlp

Position-wise MLP instantiated from mlp_cfg.

Type:

torch.nn.Module

mlp_norm

Pre-norm applied to x before AdaLN modulation in the MLP branch.

Type:

torch.nn.Module

condition_norm

Normalisation applied to the conditioning vector before it is projected by condition_proj; set condition_norm_cfg to torch.nn.Identity to disable it. Tagged _no_weight_decay.

Type:

torch.nn.Module
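How the _no_weight_decay tag is consumed is not specified here. One common pattern, sketched under the assumption that the tag is a boolean attribute set on each parameter, is to split parameters into decay and no-decay groups when building the optimizer; the 0.05 decay value and the AdamW choice are arbitrary.

```python
import torch
import torch.nn as nn

hidden_dim = 8
condition_norm = nn.LayerNorm(hidden_dim)
condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))

# Assumed tagging convention: a boolean attribute on each parameter.
for p in condition_norm.parameters():
    p._no_weight_decay = True

params = list(condition_norm.parameters()) + list(condition_proj.parameters())
decay = [p for p in params if not getattr(p, "_no_weight_decay", False)]
no_decay = [p for p in params if getattr(p, "_no_weight_decay", False)]
optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.05},    # Linear weight and bias
        {"params": no_decay, "weight_decay": 0.0},  # tagged LayerNorm weight and bias
    ],
    lr=1e-3,
)
```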

dropout

Dropout / stochastic-depth applied after each branch, before the gate multiply.

Type:

torch.nn.Module
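The source does not say which stochastic-depth variant dropout_cfg builds. As an illustration, assuming a per-sample "drop path" formulation (DropPathSketch is a hypothetical name), stochastic depth zeroes a branch's entire output for randomly chosen samples and rescales the survivors; the gate multiply happens after this module.

```python
import torch
import torch.nn as nn


class DropPathSketch(nn.Module):
    """Hypothetical per-sample stochastic depth for a residual branch."""

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.drop_prob == 0.0:
            return x  # identity at inference time
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining axes.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.bernoulli(torch.full(mask_shape, keep_prob, dtype=x.dtype, device=x.device))
        return x * mask / keep_prob  # rescale survivors to keep the expectation


torch.manual_seed(0)
drop = DropPathSketch(drop_prob=0.5)
drop.train()
out = drop(torch.ones(8, 4, 4, 16))  # each sample is all 0.0 or all 2.0
drop.eval()
out_eval = drop(torch.ones(8, 4, 4, 16))
```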

condition_proj

SiLU → Linear(C, 6C) with zero-initialised weights and biases, producing the six AdaLN-Zero parameters. Zero initialisation makes the block a pure skip (identity) connection at the start of training.

Type:

torch.nn.Sequential
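Zero initialisation does not stall learning. A sketch of one backward step, restricted to the sequence-mixer branch with an nn.Linear standing in for the mixer and a sum-of-squares loss (all illustrative choices), shows that at step 0 only the gate rows of the projection receive gradient; the shift and scale rows start receiving gradient once the gates move away from zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, C = 2, 5, 4

proj = nn.Sequential(nn.SiLU(), nn.Linear(C, 6 * C))  # condition_proj
nn.init.zeros_(proj[1].weight)
nn.init.zeros_(proj[1].bias)
norm = nn.LayerNorm(C)   # pre-norm
mixer = nn.Linear(C, C)  # stand-in sequence mixer

x = torch.randn(B, N, C)
cond = torch.randn(B, C)
# Only the first three chunks (sequence branch) are used in this sketch.
shift, scale, gate, *_ = proj(cond).chunk(6, dim=-1)
x_mod = norm(x) * (1 + scale[:, None]) + shift[:, None]
out = x + mixer(x_mod) * gate[:, None]
out.pow(2).sum().backward()

# Weight rows grouped by chunk: [shift, scale, gate, shift_mlp, scale_mlp, gate_mlp].
grad = proj[1].weight.grad.view(6, C, C)
```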

__init__(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
condition_norm_cfg,
dropout_cfg,
hidden_dim,
)

Initialise the AdaLNZeroResidualBlock.

Parameters:
  • sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence-mixing operator. The instantiated module’s forward must accept a conditioning keyword argument (it receives the spatially pooled conditioning vector cond of shape (B, C)).

  • sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the AdaLN modulation of the sequence mixer branch (e.g. LayerNorm(hidden_dim)).

  • mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP.

  • mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the AdaLN modulation of the MLP branch.

  • condition_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied to the conditioning vector before condition_proj. Use torch.nn.Identity to skip conditioning normalisation.

  • dropout_cfg (LazyConfig) – LazyConfig for the dropout / stochastic-depth module applied after each branch and before the gate multiply.

  • hidden_dim (int) – Channel dimension C shared by all sub-modules. Used to size the condition_proj linear layer (Linear(C, 6*C)). This value must match the channel dimension baked into sequence_mixer_cfg and mlp_cfg. There is no runtime check: a mismatch surfaces only as a shape error during the forward pass, typically inside condition_proj or when the AdaLN parameters are broadcast against x.
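Because nothing checks hidden_dim against the channel width in the configs, a mismatch only shows up once data flows through. The sketch below reproduces the failure inside a Linear(C, 6C) projection and adds a hypothetical eager check (check_channels is an illustrative helper, not part of this API) that a caller could run before the first forward pass.

```python
import torch
import torch.nn as nn

hidden_dim = 16       # value passed to the block
actual_channels = 32  # width the sub-module configs were actually built with

condition_proj = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 6 * hidden_dim))
cond = torch.randn(2, actual_channels)

try:
    condition_proj(cond)  # last dim 32 != in_features 16 -> RuntimeError
    proj_error = None
except RuntimeError as exc:
    proj_error = exc


def check_channels(x: torch.Tensor, condition: torch.Tensor, hidden_dim: int) -> None:
    """Hypothetical eager check that the block itself does not perform."""
    for name, tensor in (("x", x), ("condition", condition)):
        if tensor.shape[-1] != hidden_dim:
            raise ValueError(
                f"{name} has {tensor.shape[-1]} channels, expected hidden_dim={hidden_dim}"
            )
```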

forward(x, condition)

Apply AdaLN-Zero residual mixing conditioned on the provided tensor.

The conditioning tensor is reduced to a single latent vector per sample (spatial mean if it has spatial axes) before being projected to six affine parameters. These parameters modulate the pre-norm outputs of both branches via element-wise affine transforms, and gate each branch’s output before the residual add.

There are two conditioning pathways:

  1. AdaLN-Zero: shift/scale parameters derived from condition_proj(cond) modulate the pre-norm outputs of both the sequence-mixer and MLP branches, and the gate parameters scale each branch's output before the residual add.

  2. Inner-mixer conditioning — the same pooled cond vector is also forwarded as conditioning=cond into self.sequence_mixer, allowing operators such as Hyena to apply additional FiLM modulation to their kernel networks.

Parameters:
  • x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C) where B is the batch size, spatial_dims are one or more spatial axes, and C = hidden_dim.

  • condition (Tensor | None) –

    Required conditioning tensor. Shape may be:

    • (B, C) — a pre-pooled global conditioning vector (e.g. a timestep / class embedding from a diffusion model).

    • (B, *spatial_dims_cond, C) — any spatial layout; the forward pass reduces it to (B, C) via a mean over all non-batch, non-channel axes.

    Must not be None.

Returns:

Output tensor of shape (B, *spatial_dims, C), the same shape as x.

Return type:

torch.Tensor

Raises:

ValueError – If condition is None.