ViT5ResidualBlock#

class ViT5ResidualBlock(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
hidden_dim,
layer_scale_init=1e-4,
drop_path_rate=0.0,
register_pooling_cfg=None,
num_registers=0,
register_start_idx=1,
grn_cfg=None,
)#

Bases: Module

ViT-5 style residual block with LayerScale and stochastic depth.

Implements the two-branch pre-norm transformer block used in the ViT-5 family. Each forward pass executes:

# Branch 1 — sequence mixer
x_normed = input_norm(x)
# s = register_start_idx, R = num_registers
[cond = register_pooling(x_normed[:, s:s+R, :])  # optional]
mixer_out = sequence_mixer(x_normed[, conditioning=cond])
[mixer_out = grn(mixer_out)                       # optional]
x = x + drop_path(ls_attn(mixer_out))

# Branch 2 — MLP
x = x + drop_path(ls_mlp(mlp(mlp_norm(x))))

Differences vs. the generic ResidualBlock:

  • No condition-mixer branch. Conditioning is handled by register pooling inside branch 1 (see register_pooling_cfg).

  • Each branch has its own LayerScale (ls_attn / ls_mlp) rather than a single shared dropout.

  • A single DropPath instance is shared across both branches (drop_path).

  • Input is always [B, T, C] — the spatial dimensions are fully flattened into the token axis, and register tokens occupy known index positions.

Register-token conditioning:

When register_pooling_cfg is not None and num_registers > 0, register tokens are extracted from the normalised input at positions [register_start_idx : register_start_idx + num_registers], pooled by register_pooling into a (B, C) conditioning vector, and passed to the sequence mixer as conditioning=<vector>. This lets the mixer (typically ViT5Attention or ViT5HyenaAdapter) apply FiLM modulation to its internal kernel.
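The slicing and pooling step described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: mean pooling stands in for whatever module register_pooling_cfg instantiates, and the helper name pool_registers is hypothetical.

```python
import torch


def pool_registers(x_normed, register_start_idx, num_registers):
    """Slice the register tokens out of [B, T, C] and pool them to (B, C).

    Mean pooling is an assumption; the real pooling module is configurable.
    """
    s, r = register_start_idx, num_registers
    regs = x_normed[:, s : s + r, :]  # [B, R, C]
    return regs.mean(dim=1)           # (B, C)


x_normed = torch.randn(2, 10, 8)  # B=2, T=10, C=8
cond = pool_registers(x_normed, register_start_idx=1, num_registers=4)
print(cond.shape)  # torch.Size([2, 8])
```

The resulting vector is what the block would hand to the sequence mixer through the conditioning keyword.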

Attributes:
input_norm#

Pre-norm applied before the sequence mixer. Parameters are tagged _no_weight_decay = True.

Type:

torch.nn.Module

sequence_mixer#

Instantiated sequence-mixing operator — typically a ViT5Attention or ViT5HyenaAdapter (both include their own QKV / output projections).

Type:

torch.nn.Module

mlp_norm#

Pre-norm applied before the MLP. Parameters are tagged _no_weight_decay = True.

Type:

torch.nn.Module

mlp#

Position-wise MLP applied after mlp_norm.

Type:

torch.nn.Module

ls_attn#

Per-channel learnable scale (one scalar per channel) for the sequence mixer branch. nn.Identity when layer_scale_init == 0.

Type:

LayerScale | torch.nn.Identity

ls_mlp#

Per-channel learnable scale (one scalar per channel) for the MLP branch. nn.Identity when layer_scale_init == 0.

Type:

LayerScale | torch.nn.Identity

drop_path#

Stochastic depth applied to the output of each LayerScale module, before the residual add. One instance is shared by both branches. nn.Identity when drop_path_rate == 0.

Type:

DropPath | torch.nn.Identity

register_pooling#

Optional module that maps [B, num_registers, C] to a (B, C) conditioning vector. None when register-based conditioning is disabled.

Type:

torch.nn.Module | None

grn#

Optional Global Response Normalization (ConvNeXt V2) applied to the sequence mixer output before ls_attn. None when disabled.

Type:

torch.nn.Module | None

num_registers#

Number of register tokens per sample.

Type:

int

register_start_idx#

Token index at which register tokens begin.

Type:

int
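To make the ls_attn / ls_mlp and drop_path attributes concrete, here is a hedged sketch of how such modules are commonly written. LayerScaleSketch, DropPathSketch, make_layer_scale, and make_drop_path are hypothetical names, and the per-sample rescaling by 1 / keep_prob is the usual stochastic-depth formulation, assumed here rather than confirmed for this library.

```python
import torch
from torch import nn


class LayerScaleSketch(nn.Module):
    """Per-channel learnable scale: gamma has shape [C]."""

    def __init__(self, hidden_dim, init_value):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((hidden_dim,), float(init_value)))

    def forward(self, x):
        return x * self.gamma  # broadcasts over [B, T, C]


class DropPathSketch(nn.Module):
    """Per-sample stochastic depth (common formulation, assumed here)."""

    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over tokens and channels.
        probs = x.new_full((x.shape[0], 1, 1), keep_prob)
        mask = torch.bernoulli(probs)
        return x * mask / keep_prob


def make_layer_scale(hidden_dim, layer_scale_init):
    # Mirrors the documented rule: 0 replaces LayerScale with nn.Identity.
    if layer_scale_init == 0:
        return nn.Identity()
    return LayerScaleSketch(hidden_dim, layer_scale_init)


def make_drop_path(drop_path_rate):
    # Mirrors the documented rule: 0 replaces DropPath with nn.Identity.
    if drop_path_rate == 0:
        return nn.Identity()
    return DropPathSketch(drop_path_rate)
```

Because the drop decision is made per sample, every token of a dropped sample is zeroed for that branch and only the residual path carries it forward.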

__init__(
sequence_mixer_cfg,
sequence_mixer_norm_cfg,
mlp_cfg,
mlp_norm_cfg,
hidden_dim,
layer_scale_init=1e-4,
drop_path_rate=0.0,
register_pooling_cfg=None,
num_registers=0,
register_start_idx=1,
grn_cfg=None,
)#

Instantiate norms, sequence mixer, MLP, and optional register pooling.

Parameters:
  • sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence mixer. Typical targets are ViT5Attention or ViT5HyenaAdapter. Both include QKV and output projections internally (unlike the generic block where projections live in QKVSequenceMixer).

  • sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the sequence mixer, e.g. RMSNorm(hidden_dim).

  • mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP, e.g. MLP.

  • mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the MLP, e.g. RMSNorm(hidden_dim).

  • hidden_dim (int) – Channel dimension C shared by all sub-modules. Used to size LayerScale (one learnable scalar per channel).

  • layer_scale_init (float) – Initial value for both ls_attn and ls_mlp LayerScale gammas. Set to 0 to replace LayerScale with nn.Identity (disables per-channel learned scaling entirely). Typical values: 1e-4 (early training) or 1.0 (fine-tune from a strong checkpoint).

  • drop_path_rate (float) – Stochastic depth probability. Set to 0.0 to replace DropPath with nn.Identity (no drop during training). A single DropPath instance is shared between both branches.

  • register_pooling_cfg (LazyConfig | None) – Optional LazyConfig for a register pooling module whose forward(regs) accepts [B, num_registers, C] and returns (B, C). When None or when num_registers == 0, register conditioning is disabled and the sequence mixer is called without a conditioning kwarg.

  • num_registers (int) – Number of register tokens R in the sequence. Must be consistent with the token layout baked into sequence_mixer_cfg (e.g. ViT5Attention.num_registers). Only used when register_pooling_cfg is not None.

  • register_start_idx (int) – Zero-based token index at which the register block begins. With the standard ViT-5 token layout [patches, CLS, registers], this equals num_patches_h * num_patches_w + 1 for CLS-readout models and num_patches_h * num_patches_w for GAP-readout models. Typically injected by the network constructor.

  • grn_cfg (LazyConfig | None) – Optional LazyConfig for a GlobalResponseNorm module. When provided, GRN is applied to the sequence mixer output ([B, T, C]) before ls_attn, promoting inter-channel feature competition (ConvNeXt V2 recipe).
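As a worked example of the register_start_idx rule for the [patches, CLS, registers] layout, the following sketch computes the index from the patch grid. The helper name compute_register_start_idx and the has_cls flag are hypothetical; in practice the value is typically injected by the network constructor.

```python
def compute_register_start_idx(num_patches_h, num_patches_w, has_cls):
    """Zero-based index of the first register token in [patches, (CLS,) registers]."""
    num_patches = num_patches_h * num_patches_w
    # Patches occupy indices 0 .. num_patches - 1; an optional CLS token follows.
    return num_patches + (1 if has_cls else 0)


# 14 x 14 patch grid, e.g. a 224-pixel image with 16-pixel patches.
print(compute_register_start_idx(14, 14, has_cls=True))   # 197 (CLS readout)
print(compute_register_start_idx(14, 14, has_cls=False))  # 196 (GAP readout)
```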

flop_count(num_tokens, inference=False)#

Count FLOPs for one ViT-5 residual block.

Sums the FLOPs reported by every sub-module that exposes a flop_count method, counting each multiply-accumulate (MAC) as 2 FLOPs (one multiply plus one add). Sub-modules without a flop_count method (e.g. a GRN implementation that lacks one) contribute 0.

Computation graph (one forward pass):

input_norm        → flop_count(T)
register_pooling  → flop_count(D)        [when enabled]
sequence_mixer    → flop_count(T, inference)
grn               → flop_count(T)        [when enabled; 0 if absent]
ls_attn           → flop_count(T)        [when LayerScale, not Identity]
mlp_norm          → flop_count(T)
mlp               → flop_count(T)
ls_mlp            → flop_count(T)        [when LayerScale, not Identity]

drop_path contributes 0 FLOPs (stochastic identity — no arithmetic on active samples beyond the residual add, which is counted separately by the caller if desired).
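The aggregation with a fall-back of 0 can be sketched as below. This is an illustration only: safe_flop_count, block_flops, ToyNorm, and ToyGRN are hypothetical stand-ins, and the 2-FLOPs-per-element cost assigned to ToyNorm is an arbitrary assumption.

```python
def safe_flop_count(module, *args):
    """Return module.flop_count(*args), or 0 if the module has no such method."""
    fn = getattr(module, "flop_count", None)
    return int(fn(*args)) if callable(fn) else 0


class ToyNorm:
    """Stand-in sub-module; assumes 2 FLOPs per element (hypothetical cost)."""

    def __init__(self, hidden_dim):
        self.hidden_dim = hidden_dim

    def flop_count(self, num_tokens):
        return 2 * num_tokens * self.hidden_dim


class ToyGRN:
    """Stand-in module with no flop_count method, so it contributes 0."""


def block_flops(sub_modules, num_tokens):
    # None entries model disabled optional modules (e.g. grn, register_pooling).
    return sum(safe_flop_count(m, num_tokens) for m in sub_modules if m is not None)


total = block_flops([ToyNorm(8), ToyGRN(), None, ToyNorm(8)], num_tokens=10)
print(total)  # 320
```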

Parameters:
  • num_tokens (int) – Sequence length T passed to each sub-module. This should equal num_patches + (1 if has_cls else 0) + num_registers.

  • inference (bool) – Passed through to self.sequence_mixer.flop_count. Some sequence mixers (e.g. Hyena with precomputed kernels) have lower inference FLOPs than training FLOPs.

Returns:

Total FLOPs as an integer. Does not include the two residual adds (T * C FLOPs each, 2 * T * C in total), which are typically negligible.

Return type:

int

forward(x, condition=None)#

Apply the ViT-5 residual block.

Executes two residual branches in sequence:

  1. Sequence mixer branch — normalise, optionally extract a register conditioning vector, run the sequence mixer (and optional GRN), scale with LayerScale, apply stochastic depth, add residual.

  2. MLP branch — normalise, run MLP, scale with LayerScale, apply stochastic depth, add residual.

Register conditioning detail: when self.register_pooling is not None, the slice x_normed[:, register_start_idx : register_start_idx + num_registers, :] is extracted from the normalised input and pooled to shape (B, C). This vector is forwarded to the sequence mixer as conditioning=<vector>, which the mixer uses for FiLM modulation (e.g. scaling SIREN kernel features).

Parameters:
  • x (Tensor) – Input token sequence of shape [B, T, C], where B is the batch size, T the total token count, and C the channel (hidden) dimension. Tokens follow the ViT-5 layout [patches, (CLS,) registers, (padding)], so T = num_patches + (1 if has_cls else 0) + num_registers + pad_size, with pad_size = 0 for attention blocks. Attention blocks receive the unpadded sequence; Hyena blocks receive a zero-padded sequence so that T % grid_w == 0.

  • condition (Tensor | None) – Accepted for API compatibility with ResidualBlock but always ignored by this class. ViT-5 conditioning is routed through register pooling, not through this argument. Pass None (the default) when calling directly.
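The token-count arithmetic for x can be checked with a short sketch. The helper total_tokens is hypothetical, and padding to the next multiple of grid_w is an assumption; the documentation only requires that T % grid_w == 0 for Hyena blocks.

```python
def total_tokens(num_patches, has_cls, num_registers, grid_w=None):
    """Return (T, pad_size) for the layout [patches, (CLS,) registers, (padding)].

    grid_w=None models an attention block (no padding). An integer grid_w
    models a Hyena block, padded so that T % grid_w == 0.
    """
    unpadded = num_patches + (1 if has_cls else 0) + num_registers
    pad_size = 0 if grid_w is None else (-unpadded) % grid_w
    return unpadded + pad_size, pad_size


print(total_tokens(196, True, 4))             # (201, 0)
print(total_tokens(196, True, 4, grid_w=14))  # (210, 9)
```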

Returns:

Output tensor of shape [B, T, C], the same shape as x.

Return type:

torch.Tensor
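For orientation, the two-branch forward pass can be approximated by the sketch below. It is a simplified stand-in, not the class itself: BlockSketch and ToyMixer are hypothetical, nn.LayerNorm and a linear mixer replace the configurable norms and ViT5Attention / ViT5HyenaAdapter, mean pooling replaces the configurable register pooling, and GRN and stochastic depth are omitted.

```python
import torch
from torch import nn


class ToyMixer(nn.Module):
    """Stand-in sequence mixer: linear projection with optional FiLM-style scaling."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x, conditioning=None):
        out = self.proj(x)
        if conditioning is not None:
            out = out * (1.0 + conditioning.unsqueeze(1))  # (B, C) -> [B, 1, C]
        return out


class BlockSketch(nn.Module):
    def __init__(self, hidden_dim, num_registers=0, register_start_idx=1,
                 layer_scale_init=1e-4):
        super().__init__()
        self.input_norm = nn.LayerNorm(hidden_dim)
        self.mlp_norm = nn.LayerNorm(hidden_dim)
        self.sequence_mixer = ToyMixer(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim),
        )
        # Per-channel LayerScale gammas, one set per branch.
        self.gamma_attn = nn.Parameter(torch.full((hidden_dim,), layer_scale_init))
        self.gamma_mlp = nn.Parameter(torch.full((hidden_dim,), layer_scale_init))
        self.num_registers = num_registers
        self.register_start_idx = register_start_idx

    def forward(self, x, condition=None):  # `condition` is ignored, as documented
        # Branch 1: sequence mixer, optionally conditioned on pooled registers.
        x_normed = self.input_norm(x)
        if self.num_registers > 0:
            s = self.register_start_idx
            cond = x_normed[:, s : s + self.num_registers, :].mean(dim=1)
            mixer_out = self.sequence_mixer(x_normed, conditioning=cond)
        else:
            mixer_out = self.sequence_mixer(x_normed)
        x = x + self.gamma_attn * mixer_out
        # Branch 2: position-wise MLP.
        return x + self.gamma_mlp * self.mlp(self.mlp_norm(x))


block = BlockSketch(hidden_dim=8, num_registers=2)
x = torch.randn(2, 7, 8)  # [B, T, C]
y = block(x)
print(y.shape)  # torch.Size([2, 7, 8])
```

Passing a tensor as condition leaves the output unchanged in this sketch, matching the documented behaviour that the argument is accepted but ignored.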