ViT5Attention#
- class ViT5Attention(
- hidden_dim,
- num_heads,
- num_patches_h,
- num_patches_w,
- num_registers=4,
- has_cls=True,
- qk_norm=None,
- rope_base=10000.0,
- reg_rope_base=100.0,
- attn_dropout=0.0,
- proj_dropout=0.0,
- qkv_bias=False,
- out_proj_bias=False,
- scale=None,
- init_fn_qkv_proj=None,
- init_fn_out_proj=None,
Bases: Module
ViT-5 multi-head self-attention with RMSNorm QK-Norm and register-aware RoPE.
This module is the primary sequence-mixing operator for the ViT-5 family of hierarchical vision transformers. It computes standard scaled dot-product attention:
\[\text{head}_i = \text{softmax}\!\left( \frac{Q_i K_i^\top}{\sqrt{d_k}} \right) V_i, \quad \text{out} = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) W_O\]
where \(d_k = C / H\) is the per-head dimension and \(W_O\) is the output projection. Q, K, V are obtained from a single fused linear projection \([Q, K, V] = x W_{QKV}\).
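As a minimal sketch of the formula above, the following plain-PyTorch snippet writes out the per-head attention explicitly and compares it with `F.scaled_dot_product_attention`. The tensor sizes and the random `w_qkv` / `w_o` matrices are illustrative stand-ins, not the module's actual parameters.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, H = 2, 5, 16, 4            # illustrative sizes only
d_k = C // H                        # per-head dimension

x = torch.randn(B, T, C)
w_qkv = torch.randn(C, 3 * C) / C ** 0.5   # stand-in for the fused W_QKV
w_o = torch.randn(C, C) / C ** 0.5         # stand-in for W_O

# One fused matmul, then split into Q, K, V of shape [B, H, T, d_k].
qkv = (x @ w_qkv).reshape(B, T, 3, H, d_k)
q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))

# head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i, written out explicitly.
attn = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)
heads = attn @ v                                      # [B, H, T, d_k]

# out = Concat(head_1, ..., head_H) W_O
out = heads.transpose(1, 2).reshape(B, T, C) @ w_o    # [B, T, C]

# The fused PyTorch kernel should give the same per-head result.
ref = F.scaled_dot_product_attention(q, k, v)
```

Here `heads` and `ref` should agree up to floating-point tolerance, since the default scale of `F.scaled_dot_product_attention` is `1 / sqrt(d_k)`.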
Token layout
Input shape: [B, T, C] where T = num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers. Token ordering within the sequence axis (in this diagram H*W denotes num_patches_h * num_patches_w, not the head count):
    [ patch_0, patch_1, ..., patch_{H*W-1}, (CLS,) reg_0, ..., reg_{R-1} ]
      <--------- H*W patch tokens ---------> <-1--> <--- R registers --->
This ordering must be consistent with the token layout produced by the network's patchify + register-injection layers (see ViT5Classifier).
Positional encoding
Three distinct positional encodings are applied:
- Patch tokens: 2D RoPE with base frequency rope_base (default 10000). The H×W grid is linearised in row-major (Y-then-X) order.
- CLS token: identity rotation (cos=1, sin=0). No positional bias is imposed on the class token.
- Register tokens: 2D RoPE with base reg_rope_base (default 100), treating the R registers as a sqrt(R) × sqrt(R) grid. A lower base value (reg_rope_base=100 vs rope_base=10000) yields higher rotation frequencies (theta decays more slowly across head-dim pairs), giving denser angular spacing for register positions. This reflects their role as global context carriers without fixed spatial meaning.
All three tables are concatenated into a single buffer pair (rope_cos, rope_sin) of shape [T, head_dim] and applied with a single broadcast multiply in forward().
QK normalisation
When qk_norm is provided, two independent norm modules (q_norm, k_norm) are instantiated and applied to Q and K after qkv.unbind() produces tensors of shape [B, T, H, d_k], and before RoPE. The norm is expected to be a learnable RMSNorm or equivalent (accepting input of shape [B, T, H, d_k] and normalising along the last axis). Unlike the generic Attention module, which uses a fixed L2 (cosine) normalisation, the learnable per-head norm here allows the model to control the scale of the dot products.
Note
Norm is applied before RoPE in this module (order: unbind → q_norm/k_norm → rope → SDPA), whereas the generic Attention applies RoPE before L2-norm. The order matters for checkpoint compatibility: swapping the two will change the effective positional encoding applied to normalised queries and keys.
Differences vs. Attention
- Self-contained QKV + output projections (generic uses outer QKVSequenceMixer).
- RMSNorm QK-Norm instead of L2 normalisation.
- Dual-base register-aware RoPE instead of single-base uniform RoPE.
- Fixed [B, T, C] input: no multi-dimensional spatial support, no causal masking, no context-parallelism guard.
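To illustrate the dual-base RoPE described under "Positional encoding", the sketch below compares the frequency schedules produced by the two default bases. It assumes the standard RoPE schedule `theta_i = base ** (-2i / dim)`; the helper `rope_freqs` and the value of `dim` are hypothetical, and the module's exact per-axis split of `head_dim` may differ.

```python
def rope_freqs(base: float, dim: int) -> list[float]:
    """Standard RoPE schedule: theta_i = base ** (-2i / dim) for i = 0 .. dim/2 - 1."""
    return [base ** (-2.0 * i / dim) for i in range(dim // 2)]


dim = 32  # illustrative number of rotated dimensions (assumption)
patch = rope_freqs(10000.0, dim)  # rope_base default
reg = rope_freqs(100.0, dim)      # reg_rope_base default

# Both schedules start at theta_0 = 1; the lower base decays more slowly,
# so every later pair rotates faster for registers than for patches.
for i in (0, 4, 8, 15):
    print(f"pair {i:2d}: patch theta = {patch[i]:.5f}   register theta = {reg[i]:.5f}")
```

Under this assumption the register frequencies are never smaller than the patch frequencies at the same pair index, which matches the "higher rotation frequencies" statement above.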
- hidden_dim#
Total channel dimension C.
- Type:
- has_cls#
Whether the token sequence includes a CLS token between the patch tokens and the register tokens.
- Type:
- attn_dropout#
Dropout probability applied to attention weights during training; set to 0.0 at inference.
- Type:
- qkv#
Fused QKV projection: Linear(C, 3C, bias=qkv_bias).
- Type:
nn.Linear
- proj#
Output projection: Linear(C, C, bias=out_proj_bias).
- Type:
nn.Linear
- proj_drop#
Dropout on the projected output.
- Type:
nn.Dropout | nn.Identity
- q_norm#
Per-head query normaliser. Present only when qk_norm is provided (i.e. self.qk_norm is True).
- Type:
nn.Module
- k_norm#
Per-head key normaliser. Present only when qk_norm is provided.
- Type:
nn.Module
- rope_cos#
Non-persistent buffer of shape [T, head_dim] containing the concatenated patch + CLS + register cosine tables.
- Type:
- rope_sin#
Non-persistent buffer of shape [T, head_dim] containing the concatenated patch + CLS + register sine tables.
- Type:
- Parameters:
hidden_dim (int) – Total hidden dimension C. Must be divisible by num_heads.
num_heads (int) – Number of attention heads H.
num_patches_h (int) – Height of the patch grid (number of patch rows). Used to build the patch 2D RoPE table.
num_patches_w (int) – Width of the patch grid (number of patch columns). Used to build the patch 2D RoPE table.
num_registers (int) – Number of register tokens R appended after the (optional) CLS token. Should be a perfect square when > 0 so that the register RoPE grid is exactly sqrt(R) × sqrt(R). If R is not a perfect square, reg_rope_h = reg_rope_w = int(R**0.5) silently truncates, producing only reg_rope_h * reg_rope_w < R RoPE rows and causing a torch.cat shape mismatch at init time. Defaults to 4.
has_cls (bool) – If True, the token sequence contains one CLS token immediately after the patch tokens. The CLS token receives identity RoPE (cos=1, sin=0). Defaults to True.
qk_norm (LazyConfig | None) – LazyConfig for the per-head QK normalisation module (e.g. RMSNorm(head_dim)). When None, QK normalisation is disabled. Defaults to None.
rope_base (float) – Base frequency \(\theta_0\) for the patch RoPE frequency schedule. Defaults to 10000.0.
reg_rope_base (float) – Base frequency for the register-token RoPE schedule. A lower base (higher frequency) gives denser angular spacing. Defaults to 100.0.
attn_dropout (float) – Dropout rate on attention weights, applied only during training (module.training is True). Defaults to 0.0.
proj_dropout (float) – Dropout rate on the output projection. When 0.0, proj_drop is an nn.Identity. Defaults to 0.0.
qkv_bias (bool) – Whether to include a bias term in the fused QKV projection. Defaults to False.
out_proj_bias (bool) – Whether to include a bias term in the output projection. Defaults to False.
scale (float | None) – Explicit attention logit scale. When None, the scale defaults to head_dim ** -0.5. Defaults to None.
init_fn_qkv_proj (Callable[[Tensor], None] | None) – Optional callable fn(weight: Tensor) -> None applied to self.qkv.weight after construction. The bias, if present, is zero-initialised. When None, PyTorch's default nn.Linear initialisation (Kaiming uniform) is used. Defaults to None.
init_fn_out_proj (Callable[[Tensor], None] | None) – Optional callable fn(weight: Tensor) -> None applied to self.proj.weight after construction. The bias, if present, is zero-initialised. When None, PyTorch's default initialisation is used. Defaults to None.
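Because a non-square num_registers is only caught indirectly, a caller may want to validate it before constructing the module. The following is a minimal sketch of such a check; `register_grid_side` is a hypothetical helper and not part of the documented API.

```python
import math


def register_grid_side(num_registers: int) -> int:
    """Return sqrt(R) when R is a perfect square (including 0), else raise ValueError."""
    side = math.isqrt(num_registers)  # exact integer square root, no float rounding
    if side * side != num_registers:
        raise ValueError(
            f"num_registers={num_registers} is not a perfect square; "
            f"a {side}x{side} grid would cover only {side * side} registers"
        )
    return side


# The truncation described above: int(5 ** 0.5) is 2, so only 2 * 2 = 4 rows
# would be produced for R = 5.
truncated_rows = int(5 ** 0.5) ** 2

side_default = register_grid_side(4)  # the default num_registers=4 gives a 2x2 grid
```

Using `math.isqrt` rather than `int(R ** 0.5)` avoids floating-point rounding for large values, though for realistic register counts either would behave the same.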
- Raises:
AssertionError – If hidden_dim % num_heads != 0.
Example:
    import torch
    from nvsubquadratic.modules.vit5_attention import ViT5Attention

    # 2D patch grid of 14x14 with 4 register tokens and 1 CLS token, no QK norm
    attn = ViT5Attention(
        hidden_dim=384,
        num_heads=6,
        num_patches_h=14,
        num_patches_w=14,
        num_registers=4,
        has_cls=True,
    )
    T = 14 * 14 + 1 + 4           # patches + CLS + registers = 201
    x = torch.randn(2, T, 384)    # [B, T, C]
    out = attn(x)                 # [B, T, C]
    assert out.shape == x.shape

    # To enable QK-norm, pass a LazyConfig targeting any norm module
    # that accepts [B, T, H, d_k] tensors and normalises along the last axis.
- __init__(
- hidden_dim,
- num_heads,
- num_patches_h,
- num_patches_w,
- num_registers=4,
- has_cls=True,
- qk_norm=None,
- rope_base=10000.0,
- reg_rope_base=100.0,
- attn_dropout=0.0,
- proj_dropout=0.0,
- qkv_bias=False,
- out_proj_bias=False,
- scale=None,
- init_fn_qkv_proj=None,
- init_fn_out_proj=None,
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- flop_count(num_tokens, inference=False)#
Count FLOPs for multi-head self-attention on num_tokens tokens.
The inference flag is accepted for API consistency but does not change the count: attention has no cacheable precomputation analogous to SIREN kernels.
Let T = num_tokens, D = self.hidden_dim.
- FLOPs breakdown:
QKV projection (Linear(D, 3D)): 6 * T * D². Three projections packed into one: 2 * T * D * 3D.
QK-Norm (2x RMSNorm on Q and K): delegated to self.q_norm / self.k_norm. Only counted when self.qk_norm is True; 0 otherwise. Each norm module must expose flop_count(num_tokens: int) -> int returning the cost for a sequence of num_tokens tokens across all heads (i.e. for the full [B, T, H, d_k] shaped input).
RoPE on Q and K: 4 * T * D. Each of Q, K: x * cos + rotate(x) * sin = 2 elementwise multiplies per element, over T * D elements, for both Q and K. This assumes full RoPE (all head_dim dimensions rotated), which is the case here: the cos/sin buffers have shape [T, head_dim] and broadcast across all heads. For partial RoPE (only the first rope_dim of each head rotated, remainder passed through), the count would instead be 4 * T * num_heads * rope_dim.
SDPA (Q@K^T + attn@V): 4 * T² * D. Q@K^T: 2 * T * T * D. attn@V: 2 * T * T * D. (Softmax cost ~3 * T² * H is negligible and omitted.)
Output projection (Linear(D, D)): 2 * T * D².
Total: 8 * T * D² + 4 * T² * D + 4 * T * D + qk_norm_flops.
Note
num_tokens should equal num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers to match the actual sequence length seen during forward(). Passing a different value gives the estimate for that sequence length; because the SDPA term is quadratic in T, the total does not scale proportionally with num_tokens.
- Parameters:
- Returns:
Total FLOPs as an integer.
- Return type:
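The breakdown above can be restated as a small standalone function. This is a sketch that mirrors the documented formula term by term; `attention_flops` is a hypothetical helper, not the module's own `flop_count` implementation, and `qk_norm_flops` is passed in by the caller because the norm cost is delegated to the norm modules.

```python
def attention_flops(num_tokens: int, hidden_dim: int, qk_norm_flops: int = 0) -> int:
    """FLOPs for one attention layer, following the documented breakdown."""
    t, d = num_tokens, hidden_dim
    qkv_proj = 6 * t * d * d   # Linear(D, 3D): 2 * T * D * 3D
    rope = 4 * t * d           # 2 multiplies per element, for Q and K
    sdpa = 4 * t * t * d       # Q @ K^T plus attn @ V, 2 * T^2 * D each
    out_proj = 2 * t * d * d   # Linear(D, D)
    return qkv_proj + qk_norm_flops + rope + sdpa + out_proj


# Sequence length from the class-level example: 14 * 14 patches + 1 CLS + 4 registers.
T = 14 * 14 + 1 + 4
total = attention_flops(T, hidden_dim=384)
print(f"T={T}, D=384: {total:,} FLOPs")
```

Doubling `num_tokens` more than doubles the result, because the SDPA term grows with T² while the projection terms grow linearly.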
- forward(x)#
Apply ViT-5 multi-head self-attention to a token sequence.
Executes the following pipeline:
- QKV projection: x W_{QKV} reshaped to [B, T, 3, H, d_k], then split into Q, K, V each of shape [B, T, H, d_k].
- (Optional) QK normalisation: q_norm(Q) and k_norm(K) applied independently along the last (head-dim) axis.
- RoPE: Q' = Q * cos + rotate(Q) * sin and K' = K * cos + rotate(K) * sin, where cos/sin are the precomputed [T, head_dim] buffers broadcast to [1, T, 1, head_dim] over the batch and head axes. Uses _rotate_half_per_axis() (split-half convention).
- Transpose for SDPA: rearrange to [B, H, T, d_k].
- Scaled dot-product attention: delegates to F.scaled_dot_product_attention; PyTorch auto-selects the best backend (CuDNN on H100, FlashAttention on A100, etc.). dropout_p is set to self.attn_dropout during training and 0.0 at inference.
- Merge heads: out.transpose(1, 2).reshape(B, T, C).
- Output projection + dropout: proj_drop(proj(out)).
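The QK normalisation and RoPE steps of this pipeline can be sketched in isolation to show why their order matters. The snippet below is illustrative only: `RMSNorm` is a minimal stand-in for whatever module `qk_norm` instantiates, `rotate_half` is a plain 1D split-half rotation standing in for `_rotate_half_per_axis()`, and the cos/sin tables are simple 1D tables rather than the module's concatenated patch/CLS/register tables.

```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    """Minimal learnable RMSNorm over the last axis (illustrative stand-in)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * inv_rms * self.weight


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split-half rotation: (x1, x2) -> (-x2, x1) along the last axis."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin are [T, d_k]; broadcast to [1, T, 1, d_k] over batch and head axes.
    return x * cos[None, :, None, :] + rotate_half(x) * sin[None, :, None, :]


torch.manual_seed(0)
B, T, H, d_k = 2, 6, 4, 8
q = torch.randn(B, T, H, d_k)

# Simple 1D tables for illustration.
inv_freq = 10000.0 ** (-torch.arange(0, d_k, 2).float() / d_k)   # [d_k / 2]
angles = torch.arange(T).float()[:, None] * inv_freq[None, :]    # [T, d_k / 2]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)            # [T, d_k]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)            # [T, d_k]

norm = RMSNorm(d_k)
with torch.no_grad():
    norm.weight.copy_(torch.linspace(0.5, 1.5, d_k))  # non-uniform learned gain

norm_then_rope = apply_rope(norm(q), cos, sin)   # order documented for this module
rope_then_norm = norm(apply_rope(q, cos, sin))   # swapped order

# With a uniform gain the two orders coincide, because the rotation preserves the RMS.
unit = RMSNorm(d_k)
same_with_unit_gain = torch.allclose(
    apply_rope(unit(q), cos, sin), unit(apply_rope(q, cos, sin)), atol=1e-5
)
differs_with_learned_gain = not torch.allclose(norm_then_rope, rope_then_norm, atol=1e-5)
```

In this sketch the two orders only diverge once the norm's per-channel gain is non-uniform, which is consistent with the note that swapping them affects checkpoints with trained norm weights.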
- Parameters:
x (Tensor) – Input token sequence of shape [B, T, C] where:
- B: batch size,
- T = num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers: total token count following the ViT-5 layout [patches, (CLS,) registers],
- C = hidden_dim: channel dimension.
The spatial dimensions of the patch grid are baked into the precomputed rope_cos/rope_sin buffers; T must match rope_cos.shape[0] exactly.
- Returns:
Output tensor of shape [B, T, C], the same shape as the input.
- Return type:
- Raises:
RuntimeError – If T does not match rope_cos.shape[0], causing a shape mismatch in the broadcast multiply q * cos. The expected value is num_patches_h * num_patches_w + int(has_cls) + num_registers as set at construction time.
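The failure mode described above can be reproduced without the module itself. This sketch uses a dummy `cos` table as a stand-in for `rope_cos` and a hypothetical `check_tokens` helper; it only demonstrates the broadcasting rule, not the module's actual error message.

```python
import torch

num_patches_h, num_patches_w, num_registers, has_cls = 14, 14, 4, True
num_heads, head_dim = 6, 64

# Sequence length fixed at construction time.
expected_T = num_patches_h * num_patches_w + int(has_cls) + num_registers
cos = torch.ones(expected_T, head_dim)  # stand-in for the rope_cos buffer


def check_tokens(q: torch.Tensor) -> bool:
    """Return True if q broadcasts against the cos table, False on a shape mismatch."""
    try:
        _ = q * cos[None, :, None, :]   # [B, T, H, d_k] * [1, T', 1, d_k]
        return True
    except RuntimeError:
        return False


ok = check_tokens(torch.randn(2, expected_T, num_heads, head_dim))
bad = check_tokens(torch.randn(2, expected_T - 1, num_heads, head_dim))
```

A caller that builds sequences dynamically could compare `x.shape[1]` with the expected token count up front to get a clearer error than the broadcast failure.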
- extra_repr()#
Return a concise string summary of this module’s configuration.
Note
has_cls and scale are consequential hyperparameters that are not included in the output string. Use module.has_cls and module.scale to inspect them directly.
- Returns:
Comma-separated key=value pairs covering hidden_dim, num_heads, qk_norm, num_registers, patch grid size, rope_base, and reg_rope_base.
- Return type: