ViT5Attention

class ViT5Attention(
hidden_dim,
num_heads,
num_patches_h,
num_patches_w,
num_registers=4,
has_cls=True,
qk_norm=None,
rope_base=10000.0,
reg_rope_base=100.0,
attn_dropout=0.0,
proj_dropout=0.0,
qkv_bias=False,
out_proj_bias=False,
scale=None,
init_fn_qkv_proj=None,
init_fn_out_proj=None,
)

Bases: Module

ViT-5 multi-head self-attention with RMSNorm QK-Norm and register-aware RoPE.

This module is the primary sequence-mixing operator for the ViT-5 family of hierarchical vision transformers. It computes standard scaled dot-product attention:

\[\text{head}_i = \text{softmax}\!\left( \frac{Q_i K_i^\top}{\sqrt{d_k}} \right) V_i, \quad \text{out} = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) W_O\]

where \(d_k = C / H\) is the per-head dimension and \(W_O\) is the output projection. Q, K, V are obtained from a single fused linear projection \([Q, K, V] = x W_{QKV}\). The \(1/\sqrt{d_k}\) factor is the default; an explicit scale argument replaces it.

Token layout

Input shape: [B, T, C] where T = num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers. Token ordering within the sequence axis:

[ patch_0, patch_1, ..., patch_{H*W-1}, (CLS,) reg_0, ..., reg_{R-1} ]
  <----- H*W patch tokens --------->   <--1-->  <---- R registers ---->

This ordering must be consistent with the token layout produced by the network’s patchify + register-injection layers (see ViT5Classifier).

Positional encoding

Three distinct positional encodings are applied:

  • Patch tokens — 2D RoPE with base frequency rope_base (default 10000). The H×W grid is linearised in row-major (Y-then-X) order.

  • CLS token — identity rotation: cos=1, sin=0. No positional bias is imposed on the class token.

  • Register tokens — 2D RoPE with base reg_rope_base (default 100), treating the R registers as a sqrt(R) × sqrt(R) grid. A lower base value (reg_rope_base=100 vs rope_base=10000) yields higher rotation frequencies (theta decays more slowly across head-dim pairs), giving denser angular spacing for register positions. This reflects their role as global context carriers without fixed spatial meaning.

All three tables are concatenated into a single buffer pair (rope_cos, rope_sin) of shape [T, head_dim] and applied with a single broadcast multiply in forward().
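The table layout described above can be sketched as follows. This is a minimal illustration, not the module's actual code: the helper names build_2d_rope and build_vit5_rope_tables are hypothetical, and the per-axis frequency layout (first half of head_dim encodes the row index, second half the column index, each half stored as [angles, angles] to pair with a split-half rotation) is an assumption.

```python
import torch


def build_2d_rope(h: int, w: int, head_dim: int, base: float):
    """Hypothetical helper: 2D RoPE cos/sin tables of shape [h * w, head_dim].

    Assumed layout: first half of head_dim encodes the row (y) index, second
    half the column (x) index; each half is [angles, angles] so a split-half
    rotation pairs channel i with channel i + head_dim // 4.
    """
    assert head_dim % 4 == 0, "head_dim must be divisible by 4 for 2D split-half RoPE"
    axis_dim = head_dim // 2
    # One frequency per rotation pair: base ** (-2i / axis_dim)
    inv_freq = base ** (-torch.arange(0, axis_dim, 2, dtype=torch.float32) / axis_dim)
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=torch.float32),
        torch.arange(w, dtype=torch.float32),
        indexing="ij",
    )
    # Row-major linearisation: y is the outer index, x the inner index
    ys, xs = ys.reshape(-1), xs.reshape(-1)
    ang_y = ys[:, None] * inv_freq[None, :]  # [h*w, axis_dim // 2]
    ang_x = xs[:, None] * inv_freq[None, :]
    angles = torch.cat([ang_y, ang_y, ang_x, ang_x], dim=-1)  # [h*w, head_dim]
    return angles.cos(), angles.sin()


def build_vit5_rope_tables(num_patches_h, num_patches_w, num_registers, has_cls,
                           head_dim, rope_base=10000.0, reg_rope_base=100.0):
    """Hypothetical helper: [T, head_dim] tables in patch, CLS, register order."""
    cos_p, sin_p = build_2d_rope(num_patches_h, num_patches_w, head_dim, rope_base)
    cos_parts, sin_parts = [cos_p], [sin_p]
    if has_cls:
        # Identity rotation for the CLS token: cos = 1, sin = 0
        cos_parts.append(torch.ones(1, head_dim))
        sin_parts.append(torch.zeros(1, head_dim))
    if num_registers > 0:
        side = int(num_registers ** 0.5)
        assert side * side == num_registers, "num_registers should be a perfect square"
        cos_r, sin_r = build_2d_rope(side, side, head_dim, reg_rope_base)
        cos_parts.append(cos_r)
        sin_parts.append(sin_r)
    return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)


cos, sin = build_vit5_rope_tables(14, 14, num_registers=4, has_cls=True, head_dim=64)
print(cos.shape)  # torch.Size([201, 64])
```

With reg_rope_base=100, inv_freq decays more slowly across pairs than with 10000, so each register position step rotates most channel pairs by a larger angle.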

QK normalisation

When qk_norm is provided, two independent norm modules (q_norm, k_norm) are instantiated. They are applied to Q and K after qkv.unbind() produces tensors of shape [B, T, H, d_k], and before RoPE. The norm is expected to be a learnable RMSNorm or equivalent that accepts input of shape [B, T, H, d_k] and normalises along the last axis. Unlike the generic Attention module, which uses a fixed L2 (cosine) normalisation, the learnable per-head norm lets the model control the scale of the dot products.
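A minimal sketch of a norm that satisfies this contract, assuming an RMSNorm with a learnable per-channel gain of size d_k shared across heads. The class name HeadRMSNorm and eps=1e-6 are illustrative, not the library's; recent PyTorch releases also provide torch.nn.RMSNorm.

```python
import torch
import torch.nn as nn


class HeadRMSNorm(nn.Module):
    """Hypothetical per-head RMSNorm: normalises the last axis of [B, T, H, d_k]."""

    def __init__(self, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain over the head dimension, broadcast over B, T and H
        self.weight = nn.Parameter(torch.ones(head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight


B, T, H, d_k = 2, 201, 6, 64
q_norm, k_norm = HeadRMSNorm(d_k), HeadRMSNorm(d_k)  # two independent modules
q = torch.randn(B, T, H, d_k)
q_hat = q_norm(q)
print(q_hat.shape)  # torch.Size([2, 201, 6, 64])
```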

Note

Norm is applied before RoPE in this module (order: unbind → q_norm/k_norm → RoPE → SDPA), whereas the generic Attention applies RoPE before L2-norm. The order matters for checkpoint compatibility: swapping the two will change the effective positional encoding applied to normalised queries and keys.
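The ordering sensitivity comes from the learnable gain. RoPE rotates channel pairs, which leaves the RMS of each vector unchanged, so a gain-free RMSNorm commutes with it; a per-channel gain that differs within a rotated pair does not. A small numerical check, assuming a 1D split-half RoPE and a hand-written RMSNorm (all helper names are illustrative):

```python
import torch


def rotate_half(x):
    # Split-half rotation: pairs channel i with channel i + d/2
    a, b = x.chunk(2, dim=-1)
    return torch.cat([-b, a], dim=-1)


def apply_rope(x, cos, sin):
    return x * cos + rotate_half(x) * sin


def rms_norm(x, weight, eps=1e-6):
    return x / x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt() * weight


torch.manual_seed(0)
T, d = 8, 16
x = torch.randn(T, d)
# 1D RoPE angles; the [angles, angles] layout matches rotate_half
inv_freq = 10000.0 ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
ang = torch.arange(T, dtype=torch.float32)[:, None] * inv_freq[None, :]
ang = torch.cat([ang, ang], dim=-1)
cos, sin = ang.cos(), ang.sin()

unit = torch.ones(d)
learned = torch.rand(d) + 0.5  # non-uniform gain, as after training

same_unit = torch.allclose(
    apply_rope(rms_norm(x, unit), cos, sin), rms_norm(apply_rope(x, cos, sin), unit), atol=1e-5
)
norm_then_rope = apply_rope(rms_norm(x, learned), cos, sin)
rope_then_norm = rms_norm(apply_rope(x, cos, sin), learned)
print(same_unit)  # True
print(torch.allclose(norm_then_rope, rope_then_norm, atol=1e-5))  # False
```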

Differences vs. Attention

  • Self-contained QKV + output projections (the generic Attention relies on an outer QKVSequenceMixer for these).

  • RMSNorm QK-Norm instead of L2 normalisation.

  • Dual-base register-aware RoPE instead of single-base uniform RoPE.

  • Fixed [B, T, C] input — no multi-dimensional spatial support, no causal masking, no context-parallelism guard.

hidden_dim

Total channel dimension C.

Type:

int

num_heads

Number of attention heads H.

Type:

int

head_dim

Per-head dimension d_k = C / H.

Type:

int

scale

Attention logit scale, default head_dim ** -0.5.

Type:

float

num_patches_h

Height of the patch grid used for 2D RoPE.

Type:

int

num_patches_w

Width of the patch grid used for 2D RoPE.

Type:

int

num_registers

Number of register tokens R.

Type:

int

has_cls

Whether the token sequence includes a CLS token between the patch tokens and the register tokens.

Type:

bool

attn_dropout

Dropout probability applied to attention weights during training; the effective rate is 0.0 at inference (module.training is False).

Type:

float

qkv

Fused QKV projection: Linear(C, 3C, bias=qkv_bias).

Type:

nn.Linear

proj

Output projection: Linear(C, C, bias=out_proj_bias).

Type:

nn.Linear

proj_drop

Dropout on the projected output.

Type:

nn.Dropout | nn.Identity

q_norm

Per-head query normaliser. Present only when qk_norm is provided (i.e. self.qk_norm is True).

Type:

nn.Module

k_norm

Per-head key normaliser. Present only when qk_norm is provided.

Type:

nn.Module

qk_norm

Flag indicating whether QK normalisation is active.

Type:

bool

rope_base

Base frequency for patch-token RoPE.

Type:

float

reg_rope_base

Base frequency for register-token RoPE.

Type:

float

reg_rope_h

Height dimension of the register RoPE grid (int(num_registers ** 0.5)).

Type:

int

reg_rope_w

Width dimension of the register RoPE grid (int(num_registers ** 0.5)).

Type:

int

rope_cos

Non-persistent buffer of shape [T, head_dim] containing the concatenated patch + CLS + register cosine tables.

Type:

torch.Tensor

rope_sin

Non-persistent buffer of shape [T, head_dim] containing the concatenated patch + CLS + register sine tables.

Type:

torch.Tensor

Parameters:
  • hidden_dim (int) – Total hidden dimension C. Must be divisible by num_heads.

  • num_heads (int) – Number of attention heads H.

  • num_patches_h (int) – Height of the patch grid (number of patch rows). Used to build the patch 2D RoPE table.

  • num_patches_w (int) – Width of the patch grid (number of patch columns). Used to build the patch 2D RoPE table.

  • num_registers (int) – Number of register tokens R appended after the (optional) CLS token. Should be a perfect square when > 0 so that the register RoPE grid is exactly sqrt(R) × sqrt(R). If R is not a perfect square, reg_rope_h = reg_rope_w = int(R**0.5) silently truncates, producing only reg_rope_h * reg_rope_w < R RoPE rows and causing a torch.cat shape mismatch at init time. Defaults to 4.

  • has_cls (bool) – If True, the token sequence contains one CLS token immediately after the patch tokens. The CLS token receives identity RoPE (cos=1, sin=0). Defaults to True.

  • qk_norm (LazyConfig | None) – LazyConfig for the per-head QK normalisation module (e.g. RMSNorm(head_dim)). When None, QK normalisation is disabled. Defaults to None.

  • rope_base (float) – Base frequency \(\theta_0\) for the patch RoPE frequency schedule. Defaults to 10000.0.

  • reg_rope_base (float) – Base frequency for the register-token RoPE schedule. A lower base (higher frequency) gives denser angular spacing. Defaults to 100.0.

  • attn_dropout (float) – Dropout rate on attention weights, applied only during training (module.training is True). Defaults to 0.0.

  • proj_dropout (float) – Dropout rate on the output projection. When 0.0, proj_drop is an nn.Identity. Defaults to 0.0.

  • qkv_bias (bool) – Whether to include a bias term in the fused QKV projection. Defaults to False.

  • out_proj_bias (bool) – Whether to include a bias term in the output projection. Defaults to False.

  • scale (float | None) – Explicit attention logit scale. When None, the scale defaults to head_dim ** -0.5. Defaults to None.

  • init_fn_qkv_proj (Callable[[Tensor], None] | None) – Optional callable fn(weight: Tensor) -> None applied to self.qkv.weight after construction. The bias, if present, is zero-initialised. When None, PyTorch’s default nn.Linear initialisation (Kaiming uniform, not Xavier uniform) is used. Defaults to None.

  • init_fn_out_proj (Callable[[Tensor], None] | None) – Optional callable fn(weight: Tensor) -> None applied to self.proj.weight after construction. The bias, if present, is zero-initialised. When None, PyTorch’s default initialisation is used. Defaults to None.

Raises:

AssertionError – If hidden_dim % num_heads != 0.
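As documented above, divisibility is only enforced by an assertion, and a non-square num_registers is not rejected explicitly. A caller may want to validate a configuration up front; the helper below is hypothetical (not part of the library) and raises ValueError instead of relying on assertions, which are stripped under python -O.

```python
import math


def expected_seq_len(hidden_dim: int, num_heads: int, num_patches_h: int,
                     num_patches_w: int, num_registers: int = 4,
                     has_cls: bool = True) -> int:
    """Hypothetical pre-flight check; returns the sequence length T forward() expects."""
    if hidden_dim % num_heads != 0:
        raise ValueError(f"hidden_dim={hidden_dim} is not divisible by num_heads={num_heads}")
    # math.isqrt avoids float rounding in int(R ** 0.5) for large R
    side = math.isqrt(num_registers)
    if num_registers > 0 and side * side != num_registers:
        raise ValueError(f"num_registers={num_registers} is not a perfect square")
    # Layout: patches, optional CLS, registers
    return num_patches_h * num_patches_w + int(has_cls) + num_registers


print(expected_seq_len(384, 6, 14, 14))  # 201
```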

Example:

```python
import torch
from nvsubquadratic.modules.vit5_attention import ViT5Attention

# 2D patch grid of 14x14 with 4 register tokens and 1 CLS token, no QK norm
attn = ViT5Attention(
    hidden_dim=384,
    num_heads=6,
    num_patches_h=14,
    num_patches_w=14,
    num_registers=4,
    has_cls=True,
)
T = 14 * 14 + 1 + 4  # patches + CLS + registers = 201
x = torch.randn(2, T, 384)  # [B, T, C]
out = attn(x)               # [B, T, C]
assert out.shape == x.shape
# To enable QK-norm, pass a LazyConfig targeting any norm module
# that accepts [B, T, H, d_k] tensors and normalises along the last axis.
```
__init__(
hidden_dim,
num_heads,
num_patches_h,
num_patches_w,
num_registers=4,
has_cls=True,
qk_norm=None,
rope_base=10000.0,
reg_rope_base=100.0,
attn_dropout=0.0,
proj_dropout=0.0,
qkv_bias=False,
out_proj_bias=False,
scale=None,
init_fn_qkv_proj=None,
init_fn_out_proj=None,
)

Construct the module. All constructor arguments are documented under Parameters in the class description.
flop_count(num_tokens, inference=False)

Count FLOPs for multi-head self-attention on num_tokens tokens.

The inference flag is accepted for API consistency but does not change the count — attention has no cacheable precomputation analogous to SIREN kernels.

Let T = num_tokens, D = self.hidden_dim.

FLOPs breakdown:
  1. QKV projection (Linear(D, 3D)): 6 * T * D². The three projections are packed into one: 2 * T * D * 3D.

  2. QK-Norm (2x RMSNorm on Q and K): Delegated to self.q_norm / self.k_norm. Only counted when self.qk_norm is True; 0 otherwise. Each norm module must expose flop_count(num_tokens: int) -> int returning the cost for a sequence of num_tokens tokens across all heads (i.e. for the full [B, T, H, d_k] shaped input).

  3. RoPE on Q and K: 4 * T * D. For each of Q and K, x * cos + rotate(x) * sin costs 2 elementwise multiplies per element (the addition is not counted) over T * D elements. This assumes full RoPE (all head_dim dimensions rotated), which is the case here: the cos/sin buffers have shape [T, head_dim] and broadcast across all heads. For partial RoPE (only the first rope_dim of each head rotated, remainder passed through), the count would instead be 4 * T * num_heads * rope_dim.

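  4. SDPA (Q@K^T + attn@V): 4 * T² * D. Q@K^T costs 2 * T * T * D and attn@V costs 2 * T * T * D. The softmax (roughly 3 * T² * H, a few operations per attention score) is small next to these because H is much smaller than D, and is omitted.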

  5. Output projection (Linear(D, D)): 2 * T * D²

Total: 8 * T * D² + 4 * T² * D + 4 * T * D + qk_norm_flops.
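The breakdown can be re-derived term by term. The function below is a hypothetical stand-alone sketch of the documented formula for a single sequence, with the QK-norm cost passed in rather than queried from norm modules:

```python
def vit5_attention_flops(num_tokens: int, hidden_dim: int, qk_norm_flops: int = 0) -> int:
    """Hypothetical re-derivation of the documented FLOP formula (one sequence)."""
    T, D = num_tokens, hidden_dim
    qkv = 2 * T * D * (3 * D)              # fused Linear(D, 3D)
    rope = 2 * (2 * T * D)                 # Q and K, 2 multiplies per element each
    sdpa = 2 * T * T * D + 2 * T * T * D   # Q @ K^T and attn @ V
    out_proj = 2 * T * D * D               # Linear(D, D)
    return qkv + qk_norm_flops + rope + sdpa + out_proj


# ViT-S-like config from the example: T = 201, D = 384
print(vit5_attention_flops(201, 384))  # 299473920
```

At this size the 8 * T * D² projection terms dominate (about 79% of the total); the 4 * T² * D attention term overtakes them once T exceeds 2 * D.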

Note

num_tokens should equal num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers to match the actual sequence length seen during forward(). Passing a different value gives the cost for that other sequence length; because of the T² attention term, the estimate does not scale proportionally with num_tokens.

Parameters:
  • num_tokens (int) – Total sequence length T (patches + CLS + registers). Should equal num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers.

  • inference (bool) – Accepted for API consistency with other sequence-mixer modules (e.g. Hyena); does not affect the FLOP count.

Returns:

Total FLOPs as an integer.

Return type:

int

forward(x)

Apply ViT-5 multi-head self-attention to a token sequence.

Executes the following pipeline:

  1. QKV projection: x W_{QKV} is reshaped to [B, T, 3, H, d_k], then split into Q, K, V, each of shape [B, T, H, d_k].

  2. (Optional) QK normalisation: q_norm(Q) and k_norm(K) are applied independently along the last (head-dim) axis.

  3. RoPE: Q' = Q * cos + rotate(Q) * sin and K' = K * cos + rotate(K) * sin, where cos / sin are the precomputed [T, head_dim] buffers broadcast to [1, T, 1, head_dim] over the batch and head axes. Uses _rotate_half_per_axis() (split-half convention).

  4. Transpose for SDPA: rearrange to [B, H, T, d_k].

  5. Scaled dot-product attention: delegates to F.scaled_dot_product_attention, which selects a fused backend (e.g. FlashAttention, memory-efficient, or cuDNN attention) or the math fallback depending on hardware, dtype, and inputs. dropout_p is set to self.attn_dropout during training and 0.0 at inference.

  6. Merge heads: out.transpose(1, 2).reshape(B, T, C).

  7. Output projection + dropout: proj_drop(proj(out)).
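The seven steps can be sketched as a standalone function. This is an illustration under assumptions, not the module's source: vit5_attention_forward and rotate_half_per_axis are hypothetical names, the rotation assumes head_dim is split into a row half and a column half with a split-half rotation inside each, and proj dropout is omitted. SDPA's default scale is 1/sqrt(d_k), matching the module's default scale.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_half_per_axis(t: torch.Tensor) -> torch.Tensor:
    """Assumed split-half rotation, applied separately to the y-half and x-half of head_dim."""
    def rot(u):
        a, b = u.chunk(2, dim=-1)
        return torch.cat([-b, a], dim=-1)
    y_part, x_part = t.chunk(2, dim=-1)
    return torch.cat([rot(y_part), rot(x_part)], dim=-1)


def vit5_attention_forward(x, qkv, proj, cos, sin, num_heads,
                           q_norm=None, k_norm=None, attn_dropout=0.0, training=False):
    B, T, C = x.shape
    d_k = C // num_heads
    # 1. Fused projection -> [B, T, 3, H, d_k] -> three [B, T, H, d_k] tensors
    q, k, v = qkv(x).reshape(B, T, 3, num_heads, d_k).unbind(dim=2)
    # 2. Optional per-head QK norm, applied before RoPE
    if q_norm is not None:
        q, k = q_norm(q), k_norm(k)
    # 3. RoPE with [T, d_k] tables broadcast to [1, T, 1, d_k]
    cos_b, sin_b = cos[None, :, None, :], sin[None, :, None, :]
    q = q * cos_b + rotate_half_per_axis(q) * sin_b
    k = k * cos_b + rotate_half_per_axis(k) * sin_b
    # 4. [B, T, H, d_k] -> [B, H, T, d_k]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    # 5. SDPA; dropout only while training
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_dropout if training else 0.0)
    # 6. Merge heads back to [B, T, C]
    out = out.transpose(1, 2).reshape(B, T, C)
    # 7. Output projection (proj dropout omitted in this sketch)
    return proj(out)
```

With identity tables (cos = 1, sin = 0) this reduces to plain multi-head self-attention, which is a convenient sanity check.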

Parameters:

x (Tensor) –

Input token sequence of shape [B, T, C] where:

  • B — batch size,

  • T = num_patches_h * num_patches_w + (1 if has_cls else 0) + num_registers — total token count following the ViT-5 layout [patches, (CLS,) registers],

  • C = hidden_dim — channel dimension.

The spatial dimensions of the patch grid are baked into the precomputed rope_cos / rope_sin buffers; T must match rope_cos.shape[0] exactly.

Returns:

Output tensor of shape [B, T, C], the same shape as the input.

Return type:

torch.Tensor

Raises:

RuntimeError – If T does not match rope_cos.shape[0], causing a shape mismatch in the broadcast multiply q * cos. The expected value is num_patches_h * num_patches_w + int(has_cls) + num_registers as set at construction time.

extra_repr()

Return a concise string summary of this module’s configuration.

Note

has_cls and scale are consequential hyperparameters that are not included in the output string. Use module.has_cls and module.scale to inspect them directly.

Returns:

Comma-separated key=value pairs covering hidden_dim, num_heads, qk_norm, num_registers, patch grid size, rope_base, and reg_rope_base.

Return type:

str