ViT5ResidualBlock
- class ViT5ResidualBlock(
- sequence_mixer_cfg,
- sequence_mixer_norm_cfg,
- mlp_cfg,
- mlp_norm_cfg,
- hidden_dim,
- layer_scale_init=1e-4,
- drop_path_rate=0.0,
- register_pooling_cfg=None,
- num_registers=0,
- register_start_idx=1,
- grn_cfg=None,
- )
Bases: Module
ViT-5 style residual block with LayerScale and stochastic depth.
Implements the two-branch pre-norm transformer block used in the ViT-5 family. Each forward pass executes:
    # Branch 1: sequence mixer
    x_normed = input_norm(x)
    [cond = register_pooling(x_normed[:, s:s+R, :])    # optional]
    mixer_out = sequence_mixer(x_normed[, conditioning=cond])
    [mixer_out = grn(mixer_out)                        # optional]
    x = x + drop_path(ls_attn(mixer_out))

    # Branch 2: MLP
    x = x + drop_path(ls_mlp(mlp(mlp_norm(x))))
Differences vs. the generic ResidualBlock:
- No condition-mixer branch. Conditioning is handled by register pooling inside branch 1 (see register_pooling_cfg).
- Each branch has its own LayerScale (ls_attn / ls_mlp) rather than a single shared dropout.
- A single DropPath instance (drop_path) is shared across both branches.
- Input is always [B, T, C]: the spatial dimensions are fully flattened into the token axis, and register tokens occupy known index positions.
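As a rough illustration of the two-branch structure and the differences listed above, here is a minimal PyTorch sketch. `ToyLayerScale`, `ToyDropPath`, and `ToyViT5Block` are hypothetical stand-ins written for this example, not the library's classes. The sequence mixer is replaced by a plain `nn.Linear`, `nn.LayerNorm` is used for both norms, and register pooling and GRN are left out.

```python
import torch
from torch import nn


class ToyLayerScale(nn.Module):
    """Hypothetical stand-in: one learnable scale per channel."""

    def __init__(self, dim: int, init: float) -> None:
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), float(init)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma  # broadcasts over [B, T, C]


class ToyDropPath(nn.Module):
    """Hypothetical stand-in for stochastic depth (per-sample branch drop)."""

    def __init__(self, rate: float) -> None:
        super().__init__()
        self.rate = rate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.rate == 0.0:
            return x
        keep = 1.0 - self.rate
        # One Bernoulli draw per sample, broadcast over tokens and channels.
        shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(shape).bernoulli_(keep)
        return x * mask / keep


class ToyViT5Block(nn.Module):
    def __init__(self, dim: int, layer_scale_init: float = 1e-4,
                 drop_path_rate: float = 0.0) -> None:
        super().__init__()
        self.input_norm = nn.LayerNorm(dim)
        self.sequence_mixer = nn.Linear(dim, dim)  # placeholder, not attention
        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )
        use_ls = layer_scale_init != 0
        # Each branch gets its own scale; a zero init falls back to Identity.
        self.ls_attn = ToyLayerScale(dim, layer_scale_init) if use_ls else nn.Identity()
        self.ls_mlp = ToyLayerScale(dim, layer_scale_init) if use_ls else nn.Identity()
        # A single drop-path module, reused by both branches.
        self.drop_path = ToyDropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path(self.ls_attn(self.sequence_mixer(self.input_norm(x))))
        x = x + self.drop_path(self.ls_mlp(self.mlp(self.mlp_norm(x))))
        return x


block = ToyViT5Block(dim=8).eval()
x = torch.randn(2, 5, 8)  # [B, T, C]
y = block(x)              # same shape as x
```

With a small `layer_scale_init`, both branches start as near-zero perturbations of the residual stream, so a freshly built block is close to the identity map.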
Register-token conditioning:
When register_pooling_cfg is not None and num_registers > 0, register tokens are extracted from the normalised input at positions [register_start_idx : register_start_idx + num_registers], pooled by register_pooling into a (B, C) conditioning vector, and passed to the sequence mixer as conditioning=<vector>. This lets the mixer (typically ViT5Attention or ViT5HyenaAdapter) apply FiLM modulation to its internal kernel.
- Parameters:
sequence_mixer_cfg (LazyConfig)
sequence_mixer_norm_cfg (LazyConfig)
mlp_cfg (LazyConfig)
mlp_norm_cfg (LazyConfig)
hidden_dim (int)
layer_scale_init (float)
drop_path_rate (float)
register_pooling_cfg (LazyConfig | None)
num_registers (int)
register_start_idx (int)
grn_cfg (LazyConfig | None)
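The register-token conditioning path described above (slice, pool, pass as `conditioning`) can be sketched as follows. This is only an illustration under assumptions: `MeanRegisterPooling` and `ToyFiLMMixer` are hypothetical modules invented for the example, mean pooling is just one possible pooling choice, and the toy mixer applies FiLM to its output features rather than to an internal kernel as the real mixers are described as doing.

```python
import torch
from torch import nn


class MeanRegisterPooling(nn.Module):
    """Hypothetical pooling: average the register tokens, [B, R, C] -> [B, C]."""

    def forward(self, regs: torch.Tensor) -> torch.Tensor:
        return regs.mean(dim=1)


class ToyFiLMMixer(nn.Module):
    """Hypothetical mixer that accepts an optional `conditioning` vector."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.film = nn.Linear(dim, 2 * dim)  # predicts (scale, shift)

    def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
        out = self.proj(x)
        if conditioning is not None:
            scale, shift = self.film(conditioning).chunk(2, dim=-1)  # each [B, C]
            # FiLM: per-channel affine modulation, broadcast over tokens.
            out = out * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return out


B, T, C = 2, 10, 8
num_registers, register_start_idx = 3, 5  # illustrative layout values

x_normed = nn.LayerNorm(C)(torch.randn(B, T, C))
regs = x_normed[:, register_start_idx:register_start_idx + num_registers, :]
cond = MeanRegisterPooling()(regs)                # [B, C]
out = ToyFiLMMixer(C)(x_normed, conditioning=cond)  # [B, T, C]
```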
- input_norm
Pre-norm applied before the sequence mixer. Parameters are tagged _no_weight_decay = True.
- Type:
- sequence_mixer
Instantiated sequence-mixing operator, typically a ViT5Attention or ViT5HyenaAdapter (both include their own QKV / output projections).
- Type:
- mlp_norm
Pre-norm applied before the MLP. Parameters are tagged _no_weight_decay = True.
- Type:
- mlp
Position-wise MLP applied after mlp_norm.
- Type:
- ls_attn
Per-element learnable scale for the sequence mixer branch. nn.Identity when layer_scale_init == 0.
- Type:
- ls_mlp
Per-element learnable scale for the MLP branch. nn.Identity when layer_scale_init == 0.
- Type:
- drop_path
Stochastic depth applied after both LayerScale modules. nn.Identity when drop_path_rate == 0.
- Type:
- register_pooling
Optional module that maps [B, num_registers, C] to a (B, C) conditioning vector. None when register-based conditioning is disabled.
- Type: torch.nn.Module | None
- grn
Optional Global Response Normalization (ConvNeXt V2) applied to the sequence mixer output before ls_attn. None when disabled.
- Type: torch.nn.Module | None
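For readers unfamiliar with Global Response Normalization, the sketch below adapts the ConvNeXt V2 formulation to a `[B, T, C]` token tensor. It is an assumption that the library's `GlobalResponseNorm` aggregates over the token axis in this way; `ToyGRN` is a hypothetical name used only here.

```python
import torch
from torch import nn


class ToyGRN(nn.Module):
    """Hypothetical GRN over tokens, following the ConvNeXt V2 recipe."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Zero init makes the layer an identity map at the start of training.
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Global aggregation: L2 norm of each channel over the token axis.
        gx = torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)  # [B, 1, C]
        # Divisive normalisation across channels (feature competition).
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)          # [B, 1, C]
        # Calibrate the input and keep a residual path.
        return self.gamma * (x * nx) + self.beta + x


grn = ToyGRN(dim=8)
x = torch.randn(2, 5, 8)  # [B, T, C]
y = grn(x)
```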
- __init__(
- sequence_mixer_cfg,
- sequence_mixer_norm_cfg,
- mlp_cfg,
- mlp_norm_cfg,
- hidden_dim,
- layer_scale_init=1e-4,
- drop_path_rate=0.0,
- register_pooling_cfg=None,
- num_registers=0,
- register_start_idx=1,
- grn_cfg=None,
- )
Instantiate norms, sequence mixer, MLP, and optional register pooling.
- Parameters:
sequence_mixer_cfg (LazyConfig) – LazyConfig for the sequence mixer. Typical targets are ViT5Attention or ViT5HyenaAdapter. Both include QKV and output projections internally (unlike the generic block, where projections live in QKVSequenceMixer).
sequence_mixer_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the sequence mixer, e.g. RMSNorm(hidden_dim).
mlp_cfg (LazyConfig) – LazyConfig for the position-wise MLP, e.g. MLP.
mlp_norm_cfg (LazyConfig) – LazyConfig for the pre-norm applied before the MLP, e.g. RMSNorm(hidden_dim).
hidden_dim (int) – Channel dimension C shared by all sub-modules. Used to size LayerScale (one learnable scalar per channel).
layer_scale_init (float) – Initial value for both ls_attn and ls_mlp LayerScale gammas. Set to 0 to replace LayerScale with nn.Identity (disables per-channel learned scaling entirely). Typical values: 1e-4 (early training) or 1.0 (fine-tune from a strong checkpoint).
drop_path_rate (float) – Stochastic depth probability. Set to 0.0 to replace DropPath with nn.Identity (no drop during training). A single DropPath instance is shared between both branches.
register_pooling_cfg (LazyConfig | None) – Optional LazyConfig for a register pooling module whose forward(regs) accepts [B, num_registers, C] and returns (B, C). When None or when num_registers == 0, register conditioning is disabled and the sequence mixer is called without a conditioning kwarg.
num_registers (int) – Number of register tokens R in the sequence. Must be consistent with the token layout baked into sequence_mixer_cfg (e.g. ViT5Attention.num_registers). Only used when register_pooling_cfg is not None.
register_start_idx (int) – Zero-based token index at which the register block begins. With the standard ViT-5 token layout [patches, CLS, registers], this equals num_patches_h * num_patches_w + 1 for CLS-readout models and num_patches_h * num_patches_w for GAP-readout models. Typically injected by the network constructor.
grn_cfg (LazyConfig | None) – Optional LazyConfig for a GlobalResponseNorm module. When provided, GRN is applied to the sequence mixer output ([B, T, C]) before ls_attn, promoting inter-channel feature competition (ConvNeXt V2 recipe).
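The register_start_idx rule above is simple arithmetic on the token layout. The helper below is a hypothetical illustration of that rule (the function name and signature are not part of the library), useful mainly as a sanity check when wiring a block by hand.

```python
def compute_register_start_idx(num_patches_h: int, num_patches_w: int,
                               has_cls: bool) -> int:
    """Index of the first register token in [patches, (CLS,) registers]."""
    # Patches come first; a CLS token, if present, occupies one more slot.
    return num_patches_h * num_patches_w + (1 if has_cls else 0)


# 14 x 14 patch grid, e.g. a 224 x 224 image with 16 x 16 patches.
cls_idx = compute_register_start_idx(14, 14, has_cls=True)   # 197
gap_idx = compute_register_start_idx(14, 14, has_cls=False)  # 196
```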
- flop_count(num_tokens, inference=False)
Count FLOPs for one ViT-5 residual block.
Counts MACs multiplied by 2 (multiply + add) for every sub-module that exposes a flop_count method, and falls back to 0 for modules that do not (e.g. GRN, when flop_count is absent).
Computation graph (one forward pass):
    input_norm       → flop_count(T)
    register_pooling → flop_count(D)        [when enabled]
    sequence_mixer   → flop_count(T, inf)
    grn              → flop_count(T)        [when enabled; 0 if absent]
    ls_attn          → flop_count(T)        [when LayerScale, not Identity]
    mlp_norm         → flop_count(T)
    mlp              → flop_count(T)
    ls_mlp           → flop_count(T)        [when LayerScale, not Identity]
drop_path contributes 0 FLOPs (stochastic identity: no arithmetic on active samples beyond the residual add, which is counted separately by the caller if desired).
- Parameters:
num_tokens (int) – Sequence length T passed to each sub-module. This should equal num_patches + (1 if has_cls else 0) + num_registers.
inference (bool) – Passed through to self.sequence_mixer.flop_count. Some sequence mixers (e.g. Hyena with precomputed kernels) have lower inference FLOPs than training FLOPs.
- Returns:
Total FLOPs as an integer. Does not include the two residual adds (2 * T * C FLOPs each), which are typically negligible.
- Return type: int
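The aggregation pattern used by flop_count (sum over sub-modules that expose flop_count, with 0 for those that do not) might look roughly like the sketch below. All classes and numbers here are hypothetical stubs: `MixerStub` halves its count at inference purely for illustration, and the norm and LayerScale terms are omitted for brevity.

```python
class MixerStub:
    """Hypothetical mixer: one C x C projection per token."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def flop_count(self, num_tokens: int, inference: bool = False) -> int:
        flops = 2 * num_tokens * self.dim * self.dim  # 2 FLOPs per MAC
        # Illustrative only: pretend inference skips half of the work.
        return flops // 2 if inference else flops


class MlpStub:
    """Hypothetical MLP: C -> 4C -> C."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def flop_count(self, num_tokens: int) -> int:
        return 2 * num_tokens * (self.dim * 4 * self.dim) * 2


def safe_flops(module, *args, **kwargs) -> int:
    """Return module.flop_count(...) if it exists, else 0."""
    fn = getattr(module, "flop_count", None)
    return int(fn(*args, **kwargs)) if callable(fn) else 0


def block_flops(mixer, mlp, grn, num_tokens: int, inference: bool = False) -> int:
    total = safe_flops(mixer, num_tokens, inference=inference)
    if grn is not None:
        total += safe_flops(grn, num_tokens)  # 0 when flop_count is absent
    total += safe_flops(mlp, num_tokens)
    return total


grn_without_count = object()  # stands in for a module with no flop_count
train_flops = block_flops(MixerStub(8), MlpStub(8), grn_without_count, num_tokens=4)
infer_flops = block_flops(MixerStub(8), MlpStub(8), None, num_tokens=4, inference=True)
```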
- forward(x, condition=None)
Apply the ViT-5 residual block.
Executes two residual branches in sequence:
1. Sequence mixer branch: normalise, optionally extract a register conditioning vector, run the sequence mixer (and optional GRN), scale with LayerScale, apply stochastic depth, add residual.
2. MLP branch: normalise, run MLP, scale with LayerScale, apply stochastic depth, add residual.
Register conditioning detail: when self.register_pooling is not None, the slice x_normed[:, register_start_idx : register_start_idx + num_registers, :] is extracted from the normalised input and pooled to shape (B, C). This vector is forwarded to the sequence mixer as conditioning=<vector>, which the mixer uses for FiLM modulation (e.g. scaling SIREN kernel features).
- Parameters:
x (Tensor) – Input token sequence of shape [B, T, C]. B is the batch size and C is the channel (hidden) dimension. T = num_patches + (1 if has_cls else 0) + num_registers (+ pad_size for Hyena blocks) is the total token count following the ViT-5 layout [patches, (CLS,) registers, (padding,)]. Attention blocks receive the unpadded sequence; Hyena blocks receive the zero-padded sequence so T % grid_w == 0.
condition (Tensor) – Accepted for API compatibility with ResidualBlock but always ignored in this class. ViT-5 conditioning is routed through register pooling, not through this argument. Pass None (the default) when calling directly.
- Returns:
Output tensor of shape [B, T, C], the same shape as x.
- Return type: Tensor
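To make the token count T from the description of x concrete, the hypothetical helper below computes it for both block types. Choosing the smallest pad_size that satisfies T % grid_w == 0 is an assumption made for this sketch; the library may compute its padding differently.

```python
def vit5_token_count(num_patches_h: int, num_patches_w: int, has_cls: bool,
                     num_registers: int, grid_w=None):
    """Return (T, pad_size) for the layout [patches, (CLS,) registers, (padding,)]."""
    unpadded = num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers
    # Attention blocks: no padding. Hyena blocks: pad up to a multiple of grid_w.
    pad_size = 0 if grid_w is None else (-unpadded) % grid_w
    return unpadded + pad_size, pad_size


# 14 x 14 patches, CLS token, 4 registers.
attn_T, attn_pad = vit5_token_count(14, 14, True, 4)               # (201, 0)
hyena_T, hyena_pad = vit5_token_count(14, 14, True, 4, grid_w=14)  # (210, 9)
```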