RMSNorm
- class RMSNorm(dim, eps=1e-6, use_quack=True)
Bases: Module
Root Mean Square Layer Normalization (Zhang & Sennrich, arXiv:1910.07467).
Normalises a tensor over its last dimension using RMS statistics, then scales the result by a learned per-channel weight:
RMS(x) = sqrt( mean(x²) + ε )    # scalar per token
out    = (x / RMS(x)) * γ        # γ broadcast over leading dims
Accepts tensors of any shape [*leading, D]; only the last dimension is normalised. The learned weight γ has shape (D,) and is excluded from weight decay via the _no_weight_decay tag.
Backend selection (see module docstring for full details):
  - use_quack=True (default): QuACK fused kernel on SM ≥ 9.0 GPUs. Falls back to PyTorch automatically on older GPUs.
  - use_quack=False: Pure PyTorch; preferred under torch.compile.
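The formula above can be checked by hand with a dependency-free sketch. This is not the class's implementation (which runs on tensors through QuACK or PyTorch); the function names `rms_norm` and `rms_norm_batch` are hypothetical, and plain Python lists stand in for tensors of shape [*leading, D].

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """Reference RMSNorm for one token vector x of length D (plain lists)."""
    if len(x) != len(gamma):
        raise ValueError("x and gamma must have the same length D")
    mean_sq = sum(v * v for v in x) / len(x)   # mean(x²) over the last dimension
    rms = math.sqrt(mean_sq + eps)             # RMS(x) = sqrt(mean(x²) + ε)
    return [(v / rms) * g for v, g in zip(x, gamma)]  # (x / RMS(x)) * γ

def rms_norm_batch(rows, gamma, eps=1e-6):
    """Normalise each row independently; γ is shared across all rows."""
    return [rms_norm(row, gamma, eps) for row in rows]

tokens = [[3.0, 4.0], [1.0, 1.0]]   # two token vectors, D = 2
gamma = [1.0, 1.0]                  # ones-initialised, matching the class default
out = rms_norm_batch(tokens, gamma)
```

With γ set to ones, each output row has a mean square very close to 1, which is the property the normalisation is meant to enforce.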
- weight
Learnable scale γ of shape (dim,), ones-initialised. Tagged _no_weight_decay = True.
  - Type: nn.Parameter
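The source does not say how the _no_weight_decay tag is consumed. One plausible pattern, sketched here as an assumption, is for the optimizer setup to split parameters into two groups by reading the attribute. The helper `split_weight_decay` is hypothetical, and `SimpleNamespace` objects stand in for real parameters so the sketch runs without PyTorch.

```python
from types import SimpleNamespace

def split_weight_decay(named_params):
    """Split (name, param) pairs into decay / no-decay name lists via the tag."""
    decay, no_decay = [], []
    for name, param in named_params:
        # Assumption: parameters lacking the attribute are decayed as usual.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(name)
        else:
            decay.append(name)
    return decay, no_decay

# Stand-ins for parameters; a real model would pass model.named_parameters().
fake_params = [
    ("norm.weight", SimpleNamespace(_no_weight_decay=True)),
    ("proj.weight", SimpleNamespace()),
]
decay, no_decay = split_weight_decay(fake_params)
```

In a PyTorch training script, the two groups would typically become separate optimizer parameter groups, with weight decay set to zero for the tagged one.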
- Parameters:
  - dim (int) – Size of the last dimension D to normalise over.
  - eps (float) – Small positive constant for numerical stability. Default 1e-6.
  - use_quack (bool) – If True (default), use the QuACK fused kernel when available. Set to False to force the PyTorch path, which lets torch.compile fuse the norm with surrounding ops.
- __init__(dim, eps=1e-6, use_quack=True)
Initialise RMSNorm with ones-initialised weight.
- flop_count(num_tokens)
Count FLOPs for RMS normalization over num_tokens token vectors.
Operations per token (D = self.weight.shape[0]):
  - Square each element: D FLOPs
  - Mean over D + rsqrt: D FLOPs (amortized reduction + 1 rsqrt)
  - Multiply by learned scale: D FLOPs
Total: 3 * num_tokens * D.
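The accounting above reduces to a single multiplication. A minimal standalone sketch follows, assuming the same three per-token terms; `rms_norm_flops` is a hypothetical free function, not the method itself, and it takes D as an explicit argument instead of reading it from self.weight.

```python
def rms_norm_flops(num_tokens, dim):
    """FLOPs for RMSNorm under the three-term per-token accounting."""
    square = dim        # square each element
    reduce_rsqrt = dim  # mean over D plus one rsqrt, counted together as D
    scale = dim         # multiply by the learned scale
    return num_tokens * (square + reduce_rsqrt + scale)

# Example: 4096 tokens with hidden size D = 1024.
flops = rms_norm_flops(num_tokens=4096, dim=1024)  # 12_582_912
```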