L2Norm#

class L2Norm(dim=-1, eps=1e-12)#

Bases: Module

L2 normalisation layer with no learnable parameters; LazyConfig-friendly.

Wraps F.normalize(x, p=2, dim=self.dim) as an nn.Module so it can be used as a norm_cfg target in instantiate() wherever a plain normalisation module is expected (e.g. as the QK-norm in ViT5Attention).
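The wrapping described above can be sketched roughly as follows. This is an illustrative stand-in, not the library's source: the class name L2NormSketch is hypothetical, and forwarding eps to F.normalize is an assumption based on the documented eps attribute.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class L2NormSketch(nn.Module):
    """Illustrative stand-in for L2Norm; not the library's actual source."""

    def __init__(self, dim: int = -1, eps: float = 1e-12) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    @property
    def channels_first(self) -> bool:
        # True only when normalising over axis 1.
        return self.dim == 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: eps is forwarded to F.normalize, which divides by
        # max(norm, eps) rather than by the raw norm.
        return F.normalize(x, p=2, dim=self.dim, eps=self.eps)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"


# Hypothetical usage: normalise query vectors laid out as
# (batch, heads, tokens, head_dim) along the last axis.
q = torch.randn(2, 4, 16, 32)
q_unit = L2NormSketch(dim=-1)(q)
```

Because the module holds no parameters or buffers, it adds nothing to the state dict and can be swapped for another normalisation module without affecting checkpoints.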

Duck-typing

The channels_first property returns True when dim == 1, matching the convention used by RMSNormChannelFirst so callers can detect the memory layout without an isinstance check.

dim#

Axis to normalise along.

Type:

int

eps#

Stability constant for the L2 norm denominator.

Type:

float

Parameters:
  • dim (int) – Dimension to normalise over. Default -1 (last axis).

  • eps (float) – Small positive constant that keeps the L2 norm denominator away from zero. Default 1e-12.

__init__(dim=-1, eps=1e-12)#

Initialise L2Norm.

Parameters:
  • dim (int) – Axis to normalise over. Default -1.

  • eps (float) – Stability constant. Default 1e-12.

property channels_first: bool#

True when normalising over dim=1 (channel-first layout).

forward(x)#

L2-normalise x along self.dim.

Parameters:

x (Tensor) – Input tensor of any shape.

Returns:

Unit-norm tensor along self.dim, same shape and dtype as x.

Return type:

torch.Tensor
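For reference, F.normalize with p=2 computes the following for each vector v taken along self.dim, assuming eps is forwarded to it (the symbols v and y are introduced here for illustration):

$$
y = \frac{v}{\max\left(\lVert v \rVert_2,\ \text{eps}\right)}
$$

Under that assumption, a vector whose norm falls below eps is scaled by 1/eps rather than to unit length, and an all-zero vector stays zero.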

extra_repr()#

Return dim and eps for repr().

Return type:

str