RMSNormChannelFirst
- class RMSNormChannelFirst(dim, eps=1e-6, use_quack=True)
Bases: Module
Root Mean Square Layer Normalization for channel-first layouts (arXiv:1910.07467).
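Normalises a tensor along ``dim=1`` (the channel axis) and scales the result by a learned per-channel weight. Accepts tensors of shape ``[B, C, *spatial]``, e.g. ``[B, C, H, W]`` for 2-D or ``[B, C, L]`` for 1-D inputs. With ``s`` indexing the spatial positions, one RMS value is computed per sample and spatial position:

$$\mathrm{RMS}(x)_{b,s} = \sqrt{\frac{1}{C}\sum_{c=1}^{C} x_{b,c,s}^{2} + \varepsilon}$$

$$\mathrm{out}_{b,c,s} = \frac{x_{b,c,s}}{\mathrm{RMS}(x)_{b,s}}\,\gamma_c$$

where the weight ``γ`` of shape ``(C,)`` is broadcast as ``[C, 1, …, 1]``.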
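A minimal functional sketch of the formula above, written with plain PyTorch ops and assuming a floating-point input of shape ``[B, C, *spatial]``. The function name ``rms_norm_channel_first`` is hypothetical; this illustrates the math only and is not the module's actual code path (which may dispatch to QuACK).

```python
import torch


def rms_norm_channel_first(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical reference: RMSNorm over dim=1 of a [B, C, *spatial] tensor."""
    # Mean of squares over the channel axis only; keepdim=True leaves a
    # [B, 1, *spatial] tensor, i.e. one value per sample and spatial position.
    ms = x.pow(2).mean(dim=1, keepdim=True)
    # Reshape gamma from (C,) to (C, 1, ..., 1) so it broadcasts over spatial dims.
    gamma = weight.view(-1, *([1] * (x.dim() - 2)))
    # Divide by RMS (via rsqrt, with eps inside the root) and scale per channel.
    return x * torch.rsqrt(ms + eps) * gamma


# Usage: a 2-D feature map with unit weights.
x = torch.randn(2, 8, 4, 4)
y = rms_norm_channel_first(x, torch.ones(8))
print(y.shape)  # torch.Size([2, 8, 4, 4])
```

Reducing with ``keepdim=True`` keeps the channel axis as size 1, so the per-position RMS broadcasts back over ``C`` without transposing to a channel-last layout.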
Duck-typing sentinel: the class attribute ``channels_first = True`` allows callers (e.g. ``HyenaOperatorND``) to detect the layout without an ``isinstance`` check.
- weight
Learnable scale ``γ`` of shape ``(C,)``, ones-initialised. Tagged with ``_no_weight_decay = True`` so that optimiser set-up code can exclude it from weight decay.
- Type:
nn.Parameter
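The two tags above are meant to be read by calling code. The sketch below shows one way a caller might do so. ``ToyChannelFirstNorm``, ``is_channels_first`` and ``param_groups`` are hypothetical names, and sending tagged parameters to a zero-weight-decay optimiser group is an assumption based on the tag's name rather than documented behaviour.

```python
import torch
from torch import nn


class ToyChannelFirstNorm(nn.Module):
    """Hypothetical stand-in that only exposes the two documented tags."""

    channels_first = True  # class-level layout sentinel

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.weight._no_weight_decay = True  # per-parameter tag


def is_channels_first(module: nn.Module) -> bool:
    # Duck typing: read the sentinel instead of calling isinstance(...).
    return getattr(module, "channels_first", False)


def param_groups(model: nn.Module, weight_decay: float):
    # Assumed policy: tagged parameters go to a group with weight_decay=0.
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if getattr(p, "_no_weight_decay", False) else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


model = nn.Sequential(nn.Conv2d(3, 16, 3), ToyChannelFirstNorm(16))
print(is_channels_first(model[1]), is_channels_first(model[0]))  # True False
opt = torch.optim.AdamW(param_groups(model, weight_decay=0.05), lr=1e-3)
```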
- Parameters:
dim (int) – Number of channels ``C`` (the size of ``dim=1``).
eps (float) – Small positive constant added inside the square root for numerical stability. Default ``1e-6``.
use_quack (bool) – If ``True`` (default), use the QuACK fused kernel when available. Set to ``False`` to force the PyTorch path, which lets ``torch.compile`` fuse the norm with surrounding ops.
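A sketch of the selection rule ``use_quack`` describes: the fused kernel is taken only when it is both requested and importable; otherwise the PyTorch path is used. It assumes the QuACK kernels are importable as a module named ``quack``. Both helper names are hypothetical, and the real class may detect availability differently.

```python
import importlib.util


def quack_available(module_name: str = "quack") -> bool:
    # Assumption: the QuACK kernels are importable under the name "quack".
    return importlib.util.find_spec(module_name) is not None


def select_backend(use_quack: bool = True) -> str:
    """Hypothetical helper: report which path a use_quack-style flag selects."""
    if use_quack and quack_available():
        return "quack"  # fused kernel path
    return "torch"  # plain PyTorch ops, visible to torch.compile


print(select_backend(use_quack=False))  # torch
```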
- __init__(dim, eps=1e-6, use_quack=True)
Initialise RMSNormChannelFirst with ones-initialised weight.
- flop_count(num_tokens)
Return an approximate FLOP count, identical to that of channel-last RMSNorm.
The three steps (square, mean + rsqrt, scale) each touch every element once; counting one FLOP per element per step gives ``3 * num_tokens * C`` FLOPs regardless of the memory layout.
- Parameters:
num_tokens (int) – Total number of spatial positions in the batch (``B * prod(spatial_shape)``).
- Returns:
Integer FLOP estimate, ``3 * num_tokens * C``.
- Return type:
int
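The estimate is simple arithmetic over the input shape. A hypothetical helper reproducing it, with a worked example for a ``[2, 64, 32, 32]`` input (so ``num_tokens = 2 * 32 * 32 = 2048``):

```python
from math import prod


def rmsnorm_flops(batch: int, channels: int, spatial_shape: tuple) -> int:
    """Hypothetical helper mirroring the documented 3 * num_tokens * C estimate."""
    num_tokens = batch * prod(spatial_shape)  # B * prod(spatial_shape)
    return 3 * num_tokens * channels  # square, mean + rsqrt, scale


# [B, C, H, W] = [2, 64, 32, 32]: 3 * 2048 * 64
print(rmsnorm_flops(2, 64, (32, 32)))  # 393216
```

Because only ``B``, ``C`` and the spatial extent enter the product, a channel-last tensor with the same sizes gives the same count.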