GlobalResponseNorm
- class GlobalResponseNorm(dim, eps=1e-6)
Bases: Module
Global Response Normalisation (GRN) layer (Woo et al., arXiv:2301.00808).
Computes a per-channel global L2 norm across all spatial positions, normalises each channel norm by the cross-channel mean norm, then rescales the input with learned γ/β parameters plus a residual connection:

    gx  = ||x||_{spatial, L2}        # [B, 1, ..., 1, C]
    nx  = gx / (mean_C(gx) + eps)    # [B, 1, ..., 1, C]
    out = γ * (x * nx) + β + x       # [B, *spatial, C]

γ and β are 1-D tensors of length C that broadcast over the batch dimension and all spatial dimensions. They are zero-initialised so the layer starts as an identity (out = x) and the network can learn to activate the normalisation only where it is beneficial.
Inter-channel competition
The divisive step gx / mean_C(gx) produces values > 1 for channels whose global L2 norm exceeds the cross-channel average, and < 1 for weaker channels. Multiplying the input by nx therefore amplifies dominant channels and suppresses weak ones, enforcing a form of lateral inhibition across the channel dimension. This is the key mechanism by which GRN combats feature collapse (see module docstring).
Unlike LayerNorm, which normalises each token's feature vector and discards inter-channel magnitude differences, GRN preserves and amplifies these differences, making it particularly effective inside gated MLPs (GLU / SwiGLU) where per-channel activation strength carries semantic weight.
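The competition effect can be checked by hand on a toy input. The following is a minimal, dependency-free sketch rather than the layer's actual code: the helper name grn_channel_weights and the nested-list input are assumptions made for illustration, and only a single sample is handled.

```python
import math


def grn_channel_weights(x, eps=1e-6):
    """Return (gx, nx) for one sample.

    x is a list of spatial positions, each a list of C channel values
    (channels-last), mirroring the [*spatial, C] part of the layout.
    """
    num_channels = len(x[0])
    # gx[c]: L2 norm of channel c taken over all spatial positions.
    gx = [math.sqrt(sum(pos[c] ** 2 for pos in x)) for c in range(num_channels)]
    # Divide each channel norm by the cross-channel mean norm (plus eps).
    mean_gx = sum(gx) / num_channels
    nx = [g / (mean_gx + eps) for g in gx]
    return gx, nx


# Hypothetical toy input: 2 spatial positions, 3 channels.
x = [
    [3.0, 1.0, 0.0],
    [4.0, 0.0, 0.0],
]
gx, nx = grn_channel_weights(x)
print(gx)  # channel norms 5.0, 1.0, 0.0 (mean norm 2.0)
print(nx)  # roughly 2.5, 0.5, 0.0
```

The channel whose norm exceeds the mean receives a weight above 1, while the weaker channels receive weights below 1, which is the amplify/suppress behaviour described above.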
Broadcast semantics
keepdim=True in the spatial reduction produces gx of shape [B, 1, ..., 1, C] (one singleton per spatial axis). The mean is then taken along the channel axis (dim=-1, keepdim=True), giving a tensor of shape [B, 1, ..., 1, 1]; dividing gx by this mean (plus eps) yields nx with the same shape as gx. The subsequent multiplication x * nx broadcasts over all spatial positions without an explicit tile, so GRN is memory-efficient and agnostic to the number of spatial dimensions (1-D sequences, 2-D images, 3-D volumes, etc.).
- gamma
Learnable per-channel scale; shape (C,), zero-initialised.
- Type: nn.Parameter
- beta
Learnable per-channel bias; shape (C,), zero-initialised.
- Type: nn.Parameter
- eps
Small positive constant added to mean_C(gx) in the denominator to prevent division by zero.
- Type: float
- Reference:
Woo et al., “ConvNeXt V2”, arXiv:2301.00808 (CVPR 2023), Sec. 3 “Global Response Normalization”.
- __init__(dim, eps=1e-6)
Initialise GRN with zero-initialised gamma and beta.
- flop_count(num_tokens)
Return the approximate FLOP count for one forward pass.
Let T = num_tokens (total number of spatial positions summed over the batch, i.e. B * prod(spatial_shape)) and C = self.dim. The cost is dominated by element-wise operations over the T × C activation grid:
1. Squared L2 norm per channel: element-wise square + reduction sum over T positions per channel → 2 · T · C FLOPs.
2. Square root per channel: C FLOPs (negligible vs T · C, included for completeness).
3. Cross-channel mean: C additions → C FLOPs.
4. Division gx / (mean_C(gx) + eps): C FLOPs.
5. Broadcast multiply x * nx: T · C FLOPs.
6. Scale γ * (x * nx): T · C FLOPs.
7. Add beta and residual: 2 · T · C FLOPs.
Total: approximately 6 · T · C FLOPs.
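The breakdown above can be tallied term by term to confirm the 6 · T · C approximation. This is a sketch only: grn_flop_count is a hypothetical standalone function that takes dim explicitly, whereas the documented method reads it from self.dim, and the sizes used are illustrative.

```python
def grn_flop_count(num_tokens: int, dim: int) -> int:
    """Approximate forward FLOPs, summing the per-step costs listed above."""
    t, c = num_tokens, dim
    squared_norm = 2 * t * c       # element-wise square + sum over positions
    sqrt = c                       # one square root per channel
    channel_mean = c               # cross-channel mean
    divide = c                     # gx / (mean_C(gx) + eps)
    broadcast_mul = t * c          # x * nx
    scale = t * c                  # gamma * (x * nx)
    add_beta_residual = 2 * t * c  # + beta, + x
    return (squared_norm + sqrt + channel_mean + divide
            + broadcast_mul + scale + add_beta_residual)


# Hypothetical sizes: batch 8, 14 x 14 feature map, 384 channels.
t, c = 8 * 14 * 14, 384
exact = grn_flop_count(t, c)
approx = 6 * t * c
print(exact, approx, exact - approx)  # the gap is the 3 * C per-channel terms
```

The per-channel terms contribute only 3 · C FLOPs in this tally, so the T · C terms dominate for any realistic number of tokens.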
- forward(x)
Apply Global Response Normalisation to the input tensor.
The input must be in channels-last layout: the channel axis is the last dimension and all intermediate dimensions are spatial. The batch axis is always the first dimension.
- Parameters:
x (Tensor) – Input activation tensor of shape [B, *spatial, C], where
  B: batch size,
  *spatial: any number (≥ 1) of spatial dimensions (e.g. (L,) for 1-D sequences, (H, W) for 2-D images, (D, H, W) for 3-D volumes),
  C: number of channels; must equal self.dim.
- Returns:
Output tensor of shape [B, *spatial, C], with the same dtype and device as x, with GRN applied: γ * (x * nx) + β + x.
- Return type:
Tensor
- Raises:
RuntimeError – If x.shape[-1] does not equal self.dim (raised implicitly when broadcasting self.gamma / self.beta against a mismatched channel dimension).
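The documented behaviour can be approximated with a short PyTorch sketch. This is not the library's source: the class name GRNSketch and the use of torch.linalg.vector_norm are assumptions, chosen only to reproduce the formulas and shapes described in this section.

```python
import torch
from torch import nn


class GRNSketch(nn.Module):
    """Hypothetical re-implementation, for illustration only."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Zero-initialised so the layer starts as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reduce over every axis between batch (0) and channels (-1).
        spatial_dims = tuple(range(1, x.dim() - 1))
        gx = torch.linalg.vector_norm(x, ord=2, dim=spatial_dims, keepdim=True)
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


grn = GRNSketch(dim=16)
# 1-D sequence, 2-D image and 3-D volume inputs, all channels-last.
for shape in [(2, 10, 16), (2, 8, 8, 16), (2, 4, 4, 4, 16)]:
    x = torch.randn(shape)
    out = grn(x)
    # gamma = beta = 0 at initialisation, so the output equals the input.
    assert out.shape == x.shape
    assert torch.allclose(out, x)

# A mismatched channel count fails when gamma is broadcast against x.
try:
    grn(torch.randn(2, 10, 12))
except RuntimeError as err:
    print("channel mismatch:", type(err).__name__)
```

Because the check relies on broadcasting alone, an input whose last dimension is 1 would broadcast against gamma and beta instead of raising in this sketch; an explicit shape check would be needed to reject it.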