PerHeadRMSNorm

class PerHeadRMSNorm(num_heads, head_dim, eps=1e-6, use_quack=True)

Bases: Module

RMSNorm applied independently to each attention head (QK-norm).

Accepts a flat hidden representation of shape [*leading, H·D], reshapes it to [*leading, H, D], RMS-normalises each head's D-dimensional slice independently (the root-mean-square statistic is computed per head, over D only), then flattens back to [*leading, H·D]. A single learnable scale γ ∈ ℝ^D is shared across all heads.
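Written out for a single token, and assuming the standard RMSNorm formulation with ε added inside the square root (the exact placement of ε in the QuACK and PyTorch paths is not stated in this reference), the per-head operation is:

```latex
y_{h,d} = \gamma_d \, \frac{x_{h,d}}{\sqrt{\dfrac{1}{D}\sum_{d'=1}^{D} x_{h,d'}^{2} + \varepsilon}},
\qquad h = 1,\dots,H,\quad d = 1,\dots,D
```

The denominator depends on h, so every head is normalised by its own RMS, while γ_d carries no head index because the scale is shared.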

This is the QK-norm technique used in ViT-5 / vit5_attention.py to stabilise attention logit magnitudes at large model sizes:

x  : [*leading, H*D]
x  → reshape → [*leading, H, D]
x  → RMSNorm per head → [*leading, H, D]
x  → flatten → [*leading, H*D]
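The reshape → normalise → flatten pipeline above can be sketched in plain PyTorch. This is a minimal sketch under stated assumptions, not the library's implementation: the class name `PerHeadRMSNormSketch` is hypothetical, it has no QuACK path, it stores the shared γ directly as `weight` instead of inside an inner `norm` module, it computes the statistic in float32 (an assumption), and its `extra_repr` string format is illustrative only.

```python
import torch
from torch import nn


class PerHeadRMSNormSketch(nn.Module):
    """Hypothetical pure-PyTorch sketch of per-head RMSNorm (no QuACK path)."""

    def __init__(self, num_heads: int, head_dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.eps = eps
        # One scale vector of shape (D,), shared by every head.
        self.weight = nn.Parameter(torch.ones(head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        *leading, hidden = x.shape
        if hidden != self.num_heads * self.head_dim:
            raise ValueError(f"expected last dim {self.num_heads * self.head_dim}, got {hidden}")
        # [*leading, H*D] -> [*leading, H, D]; statistics in float32 (assumption).
        xh = x.reshape(*leading, self.num_heads, self.head_dim).float()
        # Mean of squares over the last axis only, so each head is normalised on its own.
        inv_rms = torch.rsqrt(xh.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        # Apply the shared gamma (broadcast over heads), then cast back to the input dtype.
        yh = (xh * inv_rms * self.weight.float()).to(x.dtype)
        # [*leading, H, D] -> [*leading, H*D]
        return yh.reshape(*leading, hidden)

    def extra_repr(self) -> str:
        # Illustrative format only.
        return f"num_heads={self.num_heads}, head_dim={self.head_dim}, eps={self.eps}"


norm = PerHeadRMSNormSketch(num_heads=4, head_dim=8)
y = norm(torch.randn(2, 5, 32))
print(y.shape)  # torch.Size([2, 5, 32])
print(norm)     # PerHeadRMSNormSketch(num_heads=4, head_dim=8, eps=1e-06)
```

Because the statistic is taken over D only, rescaling one head's input leaves that head's output (and every other head's) essentially unchanged, which is the per-head independence the description refers to.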
num_heads

Number of attention heads H.

Type:

int

head_dim

Dimension per head D.

Type:

int

norm

Shared RMSNorm instance applied to each head slice (the weight γ has shape (D,)).

Type:

RMSNorm

Parameters:
  • num_heads (int) – Number of attention heads H.

  • head_dim (int) – Dimension of each head D.

  • eps (float) – Small constant for numerical stability. Default 1e-6.

  • use_quack (bool) – If True (default), use the QuACK fused kernel when available. Set to False to force the PyTorch path so that torch.compile can fuse the norm with surrounding ops.
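The "when available" behaviour of `use_quack` amounts to a guarded dispatch. The sketch below is hedged: the helper names `quack_available` and `select_rmsnorm_backend` are hypothetical, it assumes QuACK is importable as a top-level module named `quack`, and the real availability check may involve more than an import probe.

```python
import importlib.util


def quack_available() -> bool:
    # ASSUMPTION: QuACK kernels are importable as a top-level module named "quack".
    return importlib.util.find_spec("quack") is not None


def select_rmsnorm_backend(use_quack: bool = True) -> str:
    """Hypothetical helper mirroring the documented use_quack semantics."""
    if use_quack and quack_available():
        return "quack"
    # use_quack=False (or QuACK missing): plain PyTorch ops that torch.compile can fuse.
    return "torch"


print(select_rmsnorm_backend(use_quack=False))  # torch
```

Forcing the PyTorch path keeps the norm visible to torch.compile as ordinary tensor ops, which is the stated reason for exposing the flag.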

__init__(num_heads, head_dim, eps=1e-6, use_quack=True)

Initialise PerHeadRMSNorm.

Parameters:
  • num_heads (int) – Number of attention heads H.

  • head_dim (int) – Dimension per head D; the inner RMSNorm normalises over this dimension.

  • eps (float) – Stability constant. Default 1e-6.

  • use_quack (bool) – Enable QuACK kernel path. Default True.

flop_count(num_tokens)

Count FLOPs for per-head RMS normalization on num_tokens tokens.

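Each token is reshaped to [num_heads, head_dim] (the reshape itself costs no FLOPs) and RMSNorm is applied independently per head. Because the H slices of length D cover exactly the num_heads * head_dim elements of the token, the total cost equals that of a full RMSNorm over hidden_dim = num_heads * head_dim, at 3 FLOPs per element: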

Total: 3 * num_tokens * num_heads * head_dim.
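As a worked check of the formula with illustrative sizes (1024 tokens, 12 heads of dimension 64, not taken from any particular model), the hypothetical free function `per_head_rmsnorm_flops` below stands in for the method and reproduces the count:

```python
def per_head_rmsnorm_flops(num_tokens: int, num_heads: int, head_dim: int) -> int:
    # Documented convention: 3 FLOPs per element, identical to a full
    # RMSNorm over hidden_dim = num_heads * head_dim.
    return 3 * num_tokens * num_heads * head_dim


flops = per_head_rmsnorm_flops(num_tokens=1024, num_heads=12, head_dim=64)
print(flops)  # 2359296  (= 3 * 1024 * 768)
```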

Parameters:

num_tokens (int) – Number of token vectors being normalized.

Returns:

Total FLOPs as an integer.

Return type:

int

forward(x)

Apply per-head RMS normalisation.

Parameters:

x (Tensor) – Input tensor of shape [*leading, num_heads * head_dim].

Returns:

Normalised tensor of the same shape as x, where each head’s head_dim-slice has been RMS-normalised and scaled by the shared learnable γ.

Return type:

torch.Tensor
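In QK-norm, forward is applied to queries and keys before the attention product. The sketch below assumes separate scales for q and k (two instances), uses a hypothetical functional helper `per_head_rms_norm` in place of the class, and calls PyTorch's `F.scaled_dot_product_attention` (PyTorch 2.0+). It also checks why the technique bounds logit magnitudes: with γ = 1 each normalised head vector has L2 norm at most √D, so by Cauchy-Schwarz every scaled logit q·k/√D lies in [−√D, √D] regardless of how large the raw activations were.

```python
import torch
import torch.nn.functional as F


def per_head_rms_norm(x: torch.Tensor, weight: torch.Tensor, num_heads: int, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical functional form: x is [*leading, H*D], weight is the shared (D,) scale."""
    *leading, hidden = x.shape
    xh = x.reshape(*leading, num_heads, hidden // num_heads).float()
    xh = xh * torch.rsqrt(xh.pow(2).mean(dim=-1, keepdim=True) + eps)  # per-head RMS
    return (xh * weight.float()).to(x.dtype).reshape(*leading, hidden)


torch.manual_seed(0)
B, T, H, D = 2, 16, 4, 32
q, k, v = (torch.randn(B, T, H * D) * 10 for _ in range(3))  # deliberately large activations

# Assumption: queries and keys get separate scales (two QK-norm instances).
gamma_q, gamma_k = torch.ones(D), torch.ones(D)
q = per_head_rms_norm(q, gamma_q, H)
k = per_head_rms_norm(k, gamma_k, H)


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # [B, T, H*D] -> [B, H, T, D], the layout scaled_dot_product_attention expects
    return t.reshape(B, T, H, D).transpose(1, 2)


qh, kh, vh = split_heads(q), split_heads(k), split_heads(v)
out = F.scaled_dot_product_attention(qh, kh, vh)
print(out.shape)  # torch.Size([2, 4, 16, 32])

# Scaled logits as computed inside attention: bounded by sqrt(D) when gamma = 1.
logits = qh @ kh.transpose(-2, -1) / D ** 0.5
print(logits.abs().max().item() <= D ** 0.5 + 1e-3)  # True
```

With learned γ the bound scales with the magnitude of γ, so the logits remain controlled by a small set of parameters rather than by raw activation growth.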

extra_repr()

Return a string describing the head layout, used by repr().

Return type:

str