RMSNormChannelFirst

class RMSNormChannelFirst(dim, eps=1e-6, use_quack=True)

Bases: Module

Root Mean Square Layer Normalization — channel-first layout (arXiv:1910.07467).

Normalises a tensor along dim=1 (the channel axis), independently at each spatial position, and scales the result by a learned per-channel weight. Accepts tensors of shape [B, C, *spatial], e.g. [B, C, H, W] for 2-D inputs or [B, C, L] for 1-D inputs.

RMS(x)[b, s] = sqrt( mean_c( x[b, c, s]² ) + ε )    # one value per (sample b, spatial position s)
out[b, c, s] = x[b, c, s] / RMS(x)[b, s]  *  γ[c]   # γ: (C,), broadcast as [C, 1, …, 1]
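The formula above maps directly onto a few PyTorch ops. The following is a minimal reference sketch of the PyTorch path, not the module's actual source; the function names `rms_norm_channel_first` and `rms_norm_channel_last` are hypothetical. It also checks that normalising over dim=1 gives the same result as a channel-last RMSNorm applied after moving C to the last axis, which is why each spatial position gets its own RMS.

```python
import torch


def rms_norm_channel_first(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical reference: RMS-normalise x of shape [B, C, *spatial] over dim=1."""
    # Mean of squares over channels only -> shape [B, 1, *spatial] (one RMS per position).
    rms = torch.sqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
    # Reshape gamma from (C,) to [C, 1, ..., 1] so it broadcasts over the spatial axes.
    gamma = weight.view(-1, *([1] * (x.dim() - 2)))
    return x / rms * gamma


def rms_norm_channel_last(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical reference: standard RMSNorm over the last axis."""
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight


torch.manual_seed(0)
x = torch.randn(2, 8, 5, 7)   # [B, C, H, W]
w = torch.rand(8) + 0.5       # gamma of shape (C,)
y_first = rms_norm_channel_first(x, w)
# Move C to the last axis, apply channel-last RMSNorm, then move it back.
y_last = rms_norm_channel_last(x.movedim(1, -1), w).movedim(-1, 1)
print(torch.allclose(y_first, y_last, atol=1e-6))
```

For a single position whose two channels are [3, 4], the mean of squares is (9 + 16) / 2 = 12.5, so with γ = 1 and ε ≈ 0 the output is [3, 4] / sqrt(12.5) ≈ [0.849, 1.131].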

Duck-typing sentinel: the class attribute channels_first = True allows callers (e.g. HyenaOperatorND) to detect the layout without an isinstance check.
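A caller can branch on the sentinel with getattr, treating any norm module that lacks the attribute as channel-last. This is a hedged sketch of the pattern, not HyenaOperatorND's actual code; `apply_norm` and the stand-in `ChannelFirstNorm` class are hypothetical.

```python
import torch
import torch.nn as nn


def apply_norm(norm: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Apply `norm` to x of shape [B, C, *spatial], whatever layout the norm expects."""
    if getattr(norm, "channels_first", False):
        # Channel-first norms consume [B, C, *spatial] directly.
        return norm(x)
    # Channel-last norms (e.g. nn.LayerNorm) need C on the last axis.
    return norm(x.movedim(1, -1)).movedim(-1, 1)


class ChannelFirstNorm(nn.Module):
    """Hypothetical stand-in that advertises the sentinel and normalises over dim=1."""

    channels_first = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x / x.pow(2).mean(dim=1, keepdim=True).add(1e-6).sqrt()


x = torch.randn(2, 16, 10)                # [B, C, L]
y1 = apply_norm(nn.LayerNorm(16), x)      # permuted internally
y2 = apply_norm(ChannelFirstNorm(), x)    # used as-is
print(y1.shape, y2.shape)
```

Because getattr falls back to False, third-party norms that know nothing about the sentinel still work; only modules that opt in skip the permutes.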

weight

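Learnable scale γ of shape (C,), initialised to ones. The parameter carries the attribute _no_weight_decay = True, which optimizer-setup code can check to exclude it from weight decay.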

Type:

nn.Parameter
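If the training code honours the tag, it might split parameters into decay and no-decay groups as below. This sketch assumes the optimizer setup checks the attribute with getattr; `build_param_groups` and `Block` are hypothetical names, and only the standard torch.optim per-parameter-group options are used.

```python
import torch
import torch.nn as nn


def build_param_groups(model: nn.Module, weight_decay: float) -> list:
    """Put parameters tagged _no_weight_decay=True into a group with zero decay."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if getattr(p, "_no_weight_decay", False) else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class Block(nn.Module):
    """Hypothetical module with one projection and one tagged norm scale."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm_weight = nn.Parameter(torch.ones(dim))
        self.norm_weight._no_weight_decay = True  # same tag as RMSNormChannelFirst.weight


block = Block(8)
groups = build_param_groups(block, weight_decay=0.1)
opt = torch.optim.AdamW(groups, lr=1e-3)
```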

eps

Stability constant added inside the square root.

Type:

float
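Because ε sits inside the square root, the denominator is at least sqrt(ε) even when every channel at a position is zero, so the output there is 0 rather than NaN. A minimal check using only plain PyTorch ops:

```python
import torch

x = torch.zeros(1, 4, 3)                        # [B, C, L], all channels zero at every position
mean_sq = x.pow(2).mean(dim=1, keepdim=True)    # 0 at every position

with_eps = x / torch.sqrt(mean_sq + 1e-6)       # eps inside the sqrt: 0 / 1e-3 = 0
without_eps = x / torch.sqrt(mean_sq)           # no eps: 0 / 0 = nan
print(with_eps.isnan().any().item(), without_eps.isnan().any().item())
```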

use_quack

Whether to attempt the QuACK kernel path.

Type:

bool

Parameters:
  • dim (int) – Number of channels C (size of dim=1).

  • eps (float) – Small positive constant for numerical stability. Default 1e-6.

  • use_quack (bool) – If True (default), use the QuACK fused kernel when available. Set to False to force the PyTorch path, which lets torch.compile fuse the norm with surrounding ops.
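The flag implies a dispatch with a fallback: try the fused kernel, else run plain PyTorch ops. The sketch below illustrates that pattern only. The real QuACK entry point is not shown; `fused_fn` is a hypothetical callable that would be bound only when the optional dependency imports successfully, `rms_norm_dispatch` is a hypothetical name, and the `x.is_cuda` condition is an assumption that the fused kernel is GPU-only.

```python
from typing import Callable, Optional

import torch

# Hypothetical signature for a fused kernel: (x, weight, eps) -> out.
FusedFn = Callable[[torch.Tensor, torch.Tensor, float], torch.Tensor]


def rms_norm_dispatch(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    use_quack: bool,
    fused_fn: Optional[FusedFn] = None,
) -> torch.Tensor:
    if use_quack and fused_fn is not None and x.is_cuda:
        # Opaque kernel call: torch.compile cannot fuse it with neighbouring ops.
        return fused_fn(x, weight, eps)
    # Plain PyTorch ops: traceable, so torch.compile can fuse them with surrounding code.
    rms = torch.sqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)
    return x / rms * weight.view(-1, *([1] * (x.dim() - 2)))
```

Passing use_quack=False short-circuits the first branch, which is how the PyTorch path can be forced even when the kernel is installed.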

channels_first: bool = True
__init__(dim, eps=1e-6, use_quack=True)

Initialise the module, creating weight as a ones-initialised parameter of shape (dim,).

Parameters:
  • dim (int) – Number of channels C; determines the shape of weight.

  • eps (float) – Stability constant. Default 1e-6.

  • use_quack (bool) – Enable QuACK kernel path. Default True.
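Combining the constructor description with the attribute descriptions above, a stand-in module might look as follows. This is a hedged sketch: the class name `RMSNormChannelFirstSketch` is hypothetical, only the PyTorch path is implemented, and the float32 upcast in forward is an assumption about how the same-dtype guarantee could be met, not something this page states.

```python
import torch
import torch.nn as nn


class RMSNormChannelFirstSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented constructor (PyTorch path only)."""

    channels_first = True  # duck-typing sentinel

    def __init__(self, dim: int, eps: float = 1e-6, use_quack: bool = True) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # gamma, shape (C,)
        self.weight._no_weight_decay = True          # opt out of weight decay
        self.eps = eps
        self.use_quack = use_quack                   # stored but unused in this sketch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: reduce in float32, then cast back so the output dtype matches x.
        xf = x.float()
        rms = torch.sqrt(xf.pow(2).mean(dim=1, keepdim=True) + self.eps)
        gamma = self.weight.float().view(-1, *([1] * (x.dim() - 2)))
        return (xf / rms * gamma).to(x.dtype)


norm = RMSNormChannelFirstSketch(32)
y = norm(torch.randn(4, 32, 16, 16).to(torch.bfloat16))
print(y.shape, y.dtype)
```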

flop_count(num_tokens)

Return an approximate FLOP count; the cost is identical to that of channel-last RMSNorm.

The three operations (square, mean + rsqrt, and scale) each touch every element once, giving 3 * num_tokens * C FLOPs regardless of memory layout.

Parameters:

num_tokens (int) – Total number of spatial positions in the batch (B * prod(spatial_shape)).

Returns:

Integer FLOP estimate, equal to 3 * num_tokens * C.

Return type:

int
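As a worked example, a [B, C, H, W] = [2, 64, 32, 32] input has num_tokens = 2 * 32 * 32 = 2048 spatial positions. The sketch below applies the documented formula. It passes C explicitly, whereas the method takes only num_tokens and reads C from the module, so `flop_count` here is a hypothetical free function.

```python
def flop_count(num_tokens: int, channels: int) -> int:
    """Hypothetical free-function version of the documented estimate: 3 * num_tokens * C."""
    return 3 * num_tokens * channels


B, C, H, W = 2, 64, 32, 32
num_tokens = B * H * W            # every spatial position in the batch is one token: 2048
print(flop_count(num_tokens, C))  # 3 * 2048 * 64 = 393216
```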

forward(x)

Normalise over the channel dimension (dim=1) and scale by weight.

Parameters:

x (Tensor) – Input tensor of shape [B, C, *spatial].

Returns:

Normalised and scaled tensor of the same shape and dtype as x.

Return type:

torch.Tensor