RMSNorm

class RMSNorm(dim, eps=1e-6, use_quack=True)

Bases: Module

Root Mean Square Layer Normalization (Zhang & Sennrich, arXiv:1910.07467).

Normalises a tensor over its last dimension using RMS statistics, then scales the result by a learned per-channel weight:

RMS(x) = sqrt( mean(x²) + ε )      # scalar per token
out    = (x / RMS(x)) * γ           # γ broadcast over leading dims
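As a worked check of the two formulas, the sketch below normalises a single token in plain Python. The helper name rms_norm_token is hypothetical and not part of this module. eps=0 is used only so that the numbers come out exact.

```python
import math


def rms_norm_token(x, gamma, eps=1e-6):
    """Hypothetical helper: RMS-normalise one token vector x (a list of floats)."""
    # RMS(x) = sqrt(mean(x^2) + eps): a single scalar for the whole token
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Divide every element by the shared RMS, then scale channel-wise by gamma
    return [v / rms * g for v, g in zip(x, gamma)]


x = [3.0, 4.0, 0.0, 0.0]              # mean(x^2) = 25 / 4 = 6.25, RMS = 2.5
out = rms_norm_token(x, gamma=[1.0] * 4, eps=0.0)
print(out)                            # → [1.2, 1.6, 0.0, 0.0]
```

With γ = 1 and ε = 0 the output has a mean square of 1 (up to rounding), which is exactly what dividing by RMS(x) guarantees. A non-zero ε makes the mean square slightly less than 1.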

Accepts tensors of any shape [*leading, D]; only the last dimension is normalised. The learned weight γ has shape (D,) and is excluded from weight decay via the _no_weight_decay tag.
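A common way to consume such a tag is when building optimizer parameter groups. The sketch below assumes that _no_weight_decay is a plain Python attribute set on the nn.Parameter. build_param_groups is a hypothetical helper and not something this module provides.

```python
import torch
from torch import nn


def build_param_groups(model: nn.Module, weight_decay: float = 0.1) -> list:
    """Hypothetical helper: split parameters on the _no_weight_decay attribute."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue  # frozen parameters are not handed to the optimizer
        # Parameters tagged _no_weight_decay = True get weight_decay = 0
        if getattr(p, "_no_weight_decay", False):
            no_decay.append(p)
        else:
            decay.append(p)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    # Drop empty groups so the optimizer never sees an empty parameter list
    return [g for g in groups if g["params"]]


# Usage: one Linear layer (decayed) plus a tagged scale vector (not decayed)
model = nn.Sequential(nn.Linear(16, 16))
gamma = nn.Parameter(torch.ones(16))
gamma._no_weight_decay = True  # same kind of tag the RMSNorm weight carries
model.register_parameter("gamma", gamma)

groups = build_param_groups(model, weight_decay=0.1)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```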

Backend selection (see module docstring for full details):

  • use_quack=True (default): uses the QuACK fused kernel on GPUs with compute capability SM ≥ 9.0 (Hopper and newer) and falls back to PyTorch automatically on older GPUs.

  • use_quack=False: pure PyTorch path; preferred under torch.compile, which can then fuse the norm with surrounding ops.
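The selection rule above can be sketched as a small predicate. This illustrates the rule under stated assumptions and is not the module's actual dispatch code. should_use_quack is a hypothetical name, the importable package name quack is an assumption, and routing non-CUDA tensors to PyTorch is also an assumption that the text above does not state.

```python
import importlib.util

import torch


def should_use_quack(x: torch.Tensor, use_quack: bool = True) -> bool:
    """Hypothetical dispatch check mirroring the documented backend rule."""
    if not use_quack:
        return False  # caller forced the pure-PyTorch path
    if not x.is_cuda:
        return False  # assumption: non-CUDA tensors always take the PyTorch path
    if importlib.util.find_spec("quack") is None:
        return False  # assumption: QuACK imports as `quack`; if absent, fall back
    # QuACK path only on compute capability >= (9, 0), i.e. Hopper and newer
    return torch.cuda.get_device_capability(x.device) >= (9, 0)


x = torch.randn(4, 1024)  # CPU tensor, so the PyTorch path is chosen
print(should_use_quack(x))
```

Checking capability per call (rather than once at construction) keeps the decision correct if the module is later moved to a different device; a real implementation might cache the result.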

weight

Learnable scale γ of shape (dim,), ones-initialised. Tagged _no_weight_decay = True.

Type:

nn.Parameter

eps

Stability constant added inside the square root.

Type:

float

use_quack

Whether to attempt the QuACK kernel path.

Type:

bool

Parameters:
  • dim (int) – Size of the last dimension D to normalise over.

  • eps (float) – Small positive constant for numerical stability. Default 1e-6.

  • use_quack (bool) – If True (default), use the QuACK fused kernel when available. Set to False to force the PyTorch path, which lets torch.compile fuse the norm with surrounding ops.

__init__(dim, eps=1e-6, use_quack=True)

Initialise RMSNorm, creating weight with shape (dim,) filled with ones.

Parameters:
  • dim (int) – Channel dimension D; determines the shape of weight.

  • eps (float) – Stability constant added to mean(x²) inside the square root. Default 1e-6.

  • use_quack (bool) – Enable QuACK kernel path. Default True.
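A minimal pure-PyTorch sketch of the constructor and the use_quack=False path follows. It assumes that the weight tag is a Parameter attribute and that the norm is computed in float32 before casting back to the input dtype. Neither assumption is confirmed by this page. RMSNormSketch is a hypothetical name, and the QuACK branch is omitted entirely.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """Hypothetical pure-PyTorch sketch of RMSNorm (no QuACK path)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale gamma of shape (dim,), ones-initialised
        self.weight = nn.Parameter(torch.ones(dim))
        # Assumed form of the tag: a plain attribute read by optimizer setup code
        self.weight._no_weight_decay = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: compute in float32 for stability, then cast back
        xf = x.float()
        # 1 / sqrt(mean(x^2) + eps) over the last dimension, kept for broadcasting
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (xf * inv_rms * self.weight.float()).to(x.dtype)


norm = RMSNormSketch(dim=64)
y = norm(torch.randn(4, 10, 64))  # [batch, seq, D] -> same shape
```

Written this way, the forward pass consists only of standard tensor ops, which is what lets torch.compile fuse it with neighbouring operations.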

flop_count(num_tokens)

Count FLOPs for RMS normalization over num_tokens token vectors.

Operations per token (D = self.weight.shape[0]):
  1. Square each element: D FLOPs

  2. Mean over D, then rsqrt: D FLOPs (D - 1 additions plus one rsqrt per token, counted as D)

  3. Multiply by learned scale: D FLOPs

Total: 3 * num_tokens * D.

Parameters:

num_tokens (int) – Number of token vectors being normalized.

Returns:

Total FLOPs as an integer.

Return type:

int
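The documented estimate can be reproduced with a few lines of arithmetic. rmsnorm_flop_count is a hypothetical standalone helper that takes D explicitly instead of reading self.weight.shape[0]. The three listed steps do not separately count the element-wise multiply by 1/RMS(x). The sketch keeps the documented 3 · D convention rather than changing it.

```python
def rmsnorm_flop_count(num_tokens: int, dim: int) -> int:
    """Hypothetical helper reproducing the documented 3 * num_tokens * D count."""
    square = dim        # step 1: x_i ** 2 for every element
    reduce_rsqrt = dim  # step 2: reduction over D plus one rsqrt, counted as D
    scale = dim         # step 3: multiply by the learned weight
    return num_tokens * (square + reduce_rsqrt + scale)


# e.g. 8 sequences x 2048 tokens with hidden size D = 4096
print(rmsnorm_flop_count(8 * 2048, 4096))  # → 201326592
```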

forward(x)

Apply RMS normalisation over the last dimension.

Parameters:

x (Tensor) – Input tensor of shape [*leading, D]. Any number of leading dimensions is supported; only the last axis is normalised.

Returns:

Normalised and scaled tensor, same shape and dtype as x.

Return type:

torch.Tensor