GlobalResponseNorm

class GlobalResponseNorm(dim, eps=1e-6)

Bases: Module

Global Response Normalisation (GRN) layer (Woo et al., arXiv:2301.00808).

Computes a per-channel global L2 norm across all spatial positions, normalises each channel norm by the cross-channel mean norm, then rescales the input with learned γ / β parameters plus a residual connection:

gx  = ||x||_{spatial, L2}             # [B, 1, ..., 1, C]
nx  = gx / (mean_C(gx) + eps)         # [B, 1, ..., 1, C]
out = γ * (x * nx)  +  β  +  x        # [B, *spatial, C]

γ and β are 1-D tensors of length C that broadcast over the batch dimension and all spatial dimensions. They are zero-initialised so the layer starts as an identity (out = x) and the network can learn to activate the normalisation only where it is beneficial.
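The three formulas above can be sketched in PyTorch as follows. This is a minimal illustration, not the module's actual source: the class name `GRNSketch` is hypothetical, and the use of `torch.linalg.vector_norm` for the spatial reduction is an assumption about one reasonable way to compute gx.

```python
import torch
from torch import nn


class GRNSketch(nn.Module):
    """Illustrative re-implementation of the documented formulas (hypothetical name)."""

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Zero-initialised so that out == x before any training step.
        self.gamma = nn.Parameter(torch.zeros(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce over every axis except batch (0) and channel (-1).
        spatial_dims = tuple(range(1, x.ndim - 1))
        gx = torch.linalg.vector_norm(x, ord=2, dim=spatial_dims, keepdim=True)
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


x = torch.randn(2, 7, 7, 16)  # [B, H, W, C]
layer = GRNSketch(dim=16)
out = layer(x)
# With gamma = beta = 0 the layer should reproduce its input exactly.
identity_at_init = torch.equal(out, x)
```

Under these assumptions, `identity_at_init` is expected to be true for a freshly constructed layer, which is the identity-at-initialisation property described above.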

Inter-channel competition

The divisive step gx / mean_C(gx) produces values > 1 for channels whose global L2 norm exceeds the cross-channel average, and < 1 for weaker channels. Multiplying the input by nx therefore amplifies dominant channels and suppresses weak ones, enforcing a form of lateral inhibition across the channel dimension. This is the key mechanism by which GRN combats feature collapse (see module docstring).
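A small hand-checkable example may make the competition effect concrete. The input values below are made up for illustration; the computation follows the gx and nx formulas given earlier.

```python
import torch

eps = 1e-6

# One sample, four spatial positions, three channels: [B, L, C] = [1, 4, 3].
# Channel 0 is strong, channel 1 is average, channel 2 is silent.
x = torch.tensor([[[4.0, 2.0, 0.0],
                   [4.0, 2.0, 0.0],
                   [4.0, 2.0, 0.0],
                   [4.0, 2.0, 0.0]]])

# Per-channel L2 norm over the single spatial axis: sqrt(4 * 16), sqrt(4 * 4), 0.
gx = torch.linalg.vector_norm(x, ord=2, dim=1, keepdim=True)   # [1, 1, 3]

# Divide by the cross-channel mean norm, (8 + 4 + 0) / 3 = 4.
nx = gx / (gx.mean(dim=-1, keepdim=True) + eps)                # [1, 1, 3]

# nx is approximately [2, 1, 0]: above 1 for the strong channel,
# about 1 for the average channel, and 0 for the silent one.
weights = nx.flatten()
```

Here the strong channel receives a factor of roughly 2 and the silent channel a factor of 0, so `x * nx` widens the gap between them before γ is applied.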

Unlike LayerNorm — which normalises each token’s feature vector and discards inter-channel magnitude differences — GRN preserves and amplifies these differences, making it particularly effective inside gated MLPs (GLU / SwiGLU) where per-channel activation strength carries semantic weight.

Broadcast semantics

keepdim=True in the spatial reduction produces gx of shape [B, 1, ..., 1, C] (one singleton per spatial axis). The mean is then taken along the channel axis (dim=-1, keepdim=True), giving a tensor of shape [B, 1, ..., 1, 1]; dividing gx by it (plus eps) yields nx with the same shape as gx. The subsequent multiplication x * nx broadcasts over all spatial positions without an explicit tile, so GRN avoids materialising a full-size copy of nx and is agnostic to the number of spatial dimensions (1-D sequences, 2-D images, 3-D volumes, etc.).
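The shape bookkeeping can be checked with a short sketch. The helper `grn_stat_shapes` is a hypothetical name introduced only for this illustration; it repeats the reduction steps described above and reports the intermediate shapes for 1-D, 2-D and 3-D inputs.

```python
import torch


def grn_stat_shapes(x: torch.Tensor, eps: float = 1e-6):
    """Return the shapes of gx, mean_C(gx), nx and x * nx (illustrative helper)."""
    spatial_dims = tuple(range(1, x.ndim - 1))          # all axes between B and C
    gx = torch.linalg.vector_norm(x, ord=2, dim=spatial_dims, keepdim=True)
    mean_gx = gx.mean(dim=-1, keepdim=True)
    nx = gx / (mean_gx + eps)
    scaled = x * nx                                     # broadcast, no explicit tile
    return tuple(gx.shape), tuple(mean_gx.shape), tuple(nx.shape), tuple(scaled.shape)


shapes_1d = grn_stat_shapes(torch.randn(2, 10, 8))        # [B, L, C]
shapes_2d = grn_stat_shapes(torch.randn(2, 5, 5, 8))      # [B, H, W, C]
shapes_3d = grn_stat_shapes(torch.randn(2, 3, 5, 5, 8))   # [B, D, H, W, C]
```

In each case gx and nx keep one singleton per spatial axis, and only the final product has the full input shape.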

Attributes:
dim

Number of channels C; must equal x.shape[-1] at every forward call.

Type:

int

gamma

Learnable per-channel scale; shape (C,), zero-initialised.

Type:

nn.Parameter

beta

Learnable per-channel bias; shape (C,), zero-initialised.

Type:

nn.Parameter

eps

Small positive constant added to mean_C(gx) in the denominator to prevent division by zero.

Type:

float

Reference:

Woo et al., “ConvNeXt V2”, arXiv:2301.00808 (CVPR 2023), Sec. 3 “Global Response Normalization”.

__init__(dim, eps=1e-6)

Initialise GRN with zero-initialised gamma and beta.

Parameters:
  • dim (int) – Number of channels C. Must match the size of the last dimension of every input tensor passed to forward. Determines the shape of gamma and beta.

  • eps (float) – Small positive constant added to mean_C(gx) in the denominator for numerical stability. Defaults to 1e-6.

flop_count(num_tokens)

Return the approximate FLOP count for one forward pass.

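Let T = num_tokens (total number of spatial positions summed over the batch, i.e. B * prod(spatial_shape)) and C = self.dim. The cost is dominated by element-wise operations over the T × C activation grid; the per-channel terms below (of order C) are incurred once per sample and are negligible by comparison: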

  • Squared L2 norm per channel — element-wise square + reduction sum over T positions per channel → 2 · T · C FLOPs.

  • Square root per channel — C FLOPs (negligible vs T · C, included for completeness).

  • Cross-channel mean — C additions → C FLOPs.

  • Division gx / (mean_C(gx) + eps) — C FLOPs.

  • Broadcast multiply x * nx — T · C FLOPs.

  • Scale γ * (x * nx) — T · C FLOPs.

  • Add beta and residual — 2 · T · C FLOPs.

Total: approximately 6 · T · C FLOPs.
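The tally above can be reproduced with plain arithmetic. The function below is a hypothetical helper, not the documented `flop_count` method; whether the real method returns the rounded 6 · T · C figure or includes the small O(C) terms is not stated here, so both are computed. The channel count of 96 is an arbitrary example value.

```python
def grn_flop_breakdown(num_tokens: int, dim: int) -> dict:
    """Itemise the FLOP terms listed above (illustrative helper)."""
    t, c = num_tokens, dim
    return {
        "squared_norm": 2 * t * c,       # square + sum over positions
        "sqrt": c,                       # O(C) term, counted once here
        "channel_mean": c,               # O(C) term, counted once here
        "divide": c,                     # O(C) term, counted once here
        "broadcast_multiply": t * c,     # x * nx
        "gamma_scale": t * c,            # gamma * (x * nx)
        "beta_and_residual": 2 * t * c,  # + beta, + x
    }


num_tokens = 8 * 32 * 32                 # batch of 8 images of size 32x32
breakdown = grn_flop_breakdown(num_tokens=num_tokens, dim=96)
itemised_total = sum(breakdown.values())   # 6*T*C + 3*C
approx_total = 6 * num_tokens * 96         # the documented approximation
```

For this example the itemised sum is 4,718,880 and the 6 · T · C approximation is 4,718,592. The two differ by 3 · C = 288, which shows why the per-channel terms can be dropped.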

Parameters:

num_tokens (int) – Total number of spatial positions in the batch, i.e. B * prod(spatial_shape). Note this includes the batch dimension: for a batch of 8 images of size 32×32 the value is 8 * 32 * 32 = 8192.

Returns:

Estimated integer FLOP count for one forward pass.

Return type:

int

forward(x)

Apply Global Response Normalisation to the input tensor.

The input must be in channels-last layout: the channel axis is the last dimension and all intermediate dimensions are spatial. The batch axis is always the first dimension.

Parameters:

x (Tensor) –

Input activation tensor of shape [B, *spatial, C], where

  • B — batch size,

  • *spatial — any number (≥ 1) of spatial dimensions (e.g. (L,) for 1-D sequences, (H, W) for 2-D images, (D, H, W) for 3-D volumes),

  • C — number of channels; must equal self.dim.

Returns:

Output tensor of shape [B, *spatial, C], with the same dtype and device as x, computed as γ * (x * nx) + β + x.

Return type:

torch.Tensor

Raises:

RuntimeError – If x.shape[-1] does not equal self.dim (raised implicitly when broadcasting self.gamma / self.beta against a mismatched channel dimension).
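The implicit error comes from PyTorch's broadcasting rules, which the sketch below demonstrates with a stand-in tensor for `self.gamma`. It assumes, as the description above implies, that no explicit shape check is performed. One consequence of relying on broadcasting is that an input whose last dimension has size 1 does not raise, because a size-1 axis broadcasts against any channel count.

```python
import torch

gamma = torch.zeros(8)            # stands in for self.gamma with dim=8

# Mismatched channel count (16 vs 8): broadcasting fails.
x_bad = torch.randn(2, 4, 16)
try:
    _ = gamma * x_bad
    raised = False
except RuntimeError:
    raised = True

# A size-1 channel axis broadcasts silently instead of raising.
x_one = torch.randn(2, 4, 1)
silent_shape = tuple((gamma * x_one).shape)
```

If silent broadcasting of this kind is a concern, callers may want to compare `x.shape[-1]` with `dim` themselves before calling forward.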