PerHeadRMSNorm
- class PerHeadRMSNorm(num_heads, head_dim, eps=1e-6, use_quack=True)
Bases: Module
RMSNorm applied independently to each attention head (QK-norm).
Accepts a flat hidden representation of shape [*leading, H·D], reshapes it to [*leading, H, D], applies an independent RMSNorm to each head's D-dimensional slice, then flattens back to [*leading, H·D]. Each head has its own learnable scale γ ∈ ℝ^D.
This is the QK-norm technique used in ViT-5 / vit5_attention.py to stabilise attention logit magnitudes at large model sizes:
x : [*leading, H*D]
x → reshape → [*leading, H, D]
x → RMSNorm per head → [*leading, H, D]
x → flatten → [*leading, H*D]
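The reshape, normalise, flatten pipeline above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the class name `PerHeadRMSNormSketch`, the parameter name `weight`, and its `[H, D]` layout are assumptions, and the QuACK fused path is left out entirely.

```python
import torch
from torch import nn


class PerHeadRMSNormSketch(nn.Module):
    """Illustrative per-head RMSNorm (hypothetical, pure-PyTorch path only)."""

    def __init__(self, num_heads: int, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.eps = eps
        # Assumed layout: one learnable scale vector of length D per head.
        self.weight = nn.Parameter(torch.ones(num_heads, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        leading = x.shape[:-1]
        # [*leading, H*D] -> [*leading, H, D]
        x = x.reshape(*leading, self.num_heads, self.head_dim)
        # Divide each head's D-dimensional slice by its root mean square.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x = x * inv_rms * self.weight
        # [*leading, H, D] -> [*leading, H*D]
        return x.reshape(*leading, self.num_heads * self.head_dim)


norm = PerHeadRMSNormSketch(num_heads=4, head_dim=8)
q = torch.randn(2, 5, 4 * 8)  # [batch, tokens, H*D]
out = norm(q)                 # same shape as q
```

With the scale initialised to ones, each head's slice of `out` has a root mean square close to 1, whatever the magnitude of the input.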
- Parameters:
num_heads (int) – Number of attention heads H.
head_dim (int) – Dimension of each head D.
eps (float) – Small constant for numerical stability. Default 1e-6.
use_quack (bool) – If True (default), use the QuACK fused kernel when available. Set to False to force the PyTorch path so that torch.compile can fuse the norm with surrounding ops.
- __init__(num_heads, head_dim, eps=1e-6, use_quack=True)
Initialise PerHeadRMSNorm.
- flop_count(num_tokens)
Count FLOPs for per-head RMS normalization on num_tokens tokens.
Each token is reshaped to [num_heads, head_dim] and RMSNorm is applied independently per head. Total cost is the same as a full RMSNorm over hidden_dim = num_heads * head_dim:
Total: 3 * num_tokens * num_heads * head_dim.
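As a quick sanity check of the formula, the count can be reproduced with a standalone helper. The function name `per_head_rmsnorm_flops` is hypothetical, and the factor of 3 per element is taken directly from the total stated above rather than derived here.

```python
def per_head_rmsnorm_flops(num_tokens: int, num_heads: int, head_dim: int) -> int:
    """FLOP estimate following the documented total (hypothetical helper)."""
    # Splitting hidden_dim into heads does not change the element count,
    # so the cost matches a full RMSNorm over num_heads * head_dim.
    hidden_dim = num_heads * head_dim
    return 3 * num_tokens * hidden_dim


# Example: 1024 tokens, 12 heads of dimension 64 (hidden_dim = 768).
flops = per_head_rmsnorm_flops(num_tokens=1024, num_heads=12, head_dim=64)
print(flops)  # 2359296
```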