LayerScale
- class LayerScale(dim, init_value=1e-4)
Bases: Module
Learnable per-channel scalar gate for residual branch outputs.
Operation
Given an input tensor x of arbitrary leading batch / spatial dimensions followed by a channel dimension C, LayerScale computes
output = x * γ
where γ ∈ ℝ^C is broadcast element-wise along all axes except the last one. Concretely, if x has shape (B, T, C) (the ViT-5 layout) then γ has shape (C,) and PyTorch broadcasts it automatically as if it had shape (1, 1, C). The same broadcast rule applies to any channels-last layout: (B, H, W, C), (B, T, H, W, C), etc.
Training dynamics
The init_value argument controls the initial magnitude of every element of γ. A small value (e.g. 1e-4) means the residual update is almost entirely suppressed at the start of training, which:
- Prevents gradient explosion in very deep networks (depth ≥ 24).
- Lets the skip connections carry most of the signal early on, and allows the residual branches to activate gradually once they have learned useful features.
Using a larger init_value (e.g. 1.0) is appropriate when fine-tuning from a checkpoint where the residual branches are already well-trained and suppressing them would slow convergence.
The parameter γ is tagged with _no_weight_decay = True so that optimiser weight-decay regularisation (L2) is not applied to it. This is standard practice and matches the original CaiT training recipe.
How it differs from a plain nn.Linear / scalar gate
Unlike an nn.Linear(C, C) projection (which mixes channels), LayerScale applies an independent scalar per channel. This is equivalent to a diagonal linear map diag(γ) and requires only C parameters rather than the C² weights of the dense projection. Unlike a single scalar gate, it can scale each channel differently.
- gamma
Learnable scale vector of shape (dim,), initialised to init_value. Tagged _no_weight_decay = True to exclude it from L2 weight-decay in the optimiser.
- Type:
nn.Parameter
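The _no_weight_decay tag only has an effect if the optimiser setup reads it. Below is a minimal sketch of how parameter groups might be built from the tag, assuming the attribute is set directly on the nn.Parameter. The helper split_weight_decay_groups and the stand-in module TaggedScale are hypothetical names used for illustration, not part of this codebase.

```python
import torch
import torch.nn as nn


def split_weight_decay_groups(model: nn.Module, weight_decay: float):
    """Hypothetical helper: build optimiser param groups from the tag."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        # Parameters tagged `_no_weight_decay = True` go in the decay-free group.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class TaggedScale(nn.Module):
    """Stand-in for LayerScale, used only to make this sketch runnable."""

    def __init__(self, dim: int, init_value: float = 1e-4):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_value))
        self.gamma._no_weight_decay = True  # assumed tagging mechanism

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


model = nn.Sequential(nn.Linear(8, 8), TaggedScale(8))
groups = split_weight_decay_groups(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```

Here the nn.Linear weight and bias land in the decayed group, while the tagged scale vector is optimised with weight_decay=0.0.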
Example:
import torch

ls = LayerScale(dim=768, init_value=1e-4)
x = torch.randn(2, 196, 768)  # [B, T, C]
out = ls(x)                   # [B, T, C], same shape as x
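For orientation, the behaviour described for this class can be reproduced in a few lines. This is a sketch under stated assumptions rather than the class's actual source: the name LayerScaleSketch is hypothetical, and the comparison against torch.diag is only there to illustrate the diagonal-map equivalence.

```python
import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    """Hypothetical re-implementation of the documented behaviour."""

    def __init__(self, dim: int, init_value: float = 1e-4):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # gamma has shape (C,); broadcasting aligns it with the last axis of x.
        return x * self.gamma


sketch = LayerScaleSketch(dim=16)
with torch.no_grad():
    # Give each channel a distinct scale so the check is not trivial.
    sketch.gamma.copy_(torch.linspace(0.1, 1.6, 16))

# The same module handles any channels-last layout.
shapes = [(2, 5, 16), (2, 4, 4, 16), (2, 3, 4, 4, 16)]
outputs = [sketch(torch.randn(*shape)) for shape in shapes]

# Equivalence with the diagonal linear map diag(gamma).
x = torch.randn(2, 5, 16)
via_diag = x @ torch.diag(sketch.gamma)
```

The dense diag(γ) matrix is built here only for the comparison; the element-wise form never materialises the C × C matrix.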
- __init__(dim, init_value=1e-4)
Initialise the learnable scale vector.
- Parameters:
dim (int) – Channel dimension C. Determines the length of the gamma parameter vector. Must match the channel (last) dimension of tensors passed to forward().
init_value (float) – Initial value for every element of gamma. All elements are set to this scalar at construction time. Typical choices:
- 1e-4: recommended for training from scratch on deep networks; effectively suppresses residual branches initially.
- 1e-5: used in the original CaiT paper for depth-24 models; its deeper variants use 1e-6.
- 1.0: effectively disables the gating at initialisation; useful when fine-tuning from a strong pre-trained checkpoint.
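To make the effect of init_value concrete, the sketch below compares the size of a residual update x + γ * branch(x) for the smallest and largest choices listed above. The branch is a stand-in nn.Linear, and all names here are illustrative assumptions rather than part of this API.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 32
x = torch.randn(4, 10, dim)
branch = nn.Linear(dim, dim)  # stand-in for an attention or MLP branch


def residual_change(init_value: float) -> float:
    """Relative size of the residual update for a given gamma init."""
    gamma = torch.full((dim,), init_value)
    with torch.no_grad():
        out = x + gamma * branch(x)
    return ((out - x).norm() / x.norm()).item()


small = residual_change(1e-4)  # block is close to the identity at init
full = residual_change(1.0)    # branch contributes at full strength
```

With the small initial value the block starts out very close to the identity map, which is the suppression described above; with 1.0 the branch output is added unscaled.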
- flop_count(num_tokens)
Count floating-point multiply operations for one forward pass.
Each element of x is multiplied by the corresponding element of gamma, so the total number of scalar multiplications is
FLOPs = num_tokens × dim
where dim = self.gamma.shape[0]. The broadcast of gamma is free (no arithmetic), and the element-wise multiply is counted as one FLOP per output element (following the convention used throughout this codebase of counting multiply-only, not multiply-add pairs).
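As a worked instance of the formula under the multiply-only convention, here is a standalone sketch. The function name layerscale_flops is hypothetical and stands in for the method; the token and channel counts are borrowed from the class-level example.

```python
def layerscale_flops(num_tokens: int, dim: int) -> int:
    """One multiply per output element: num_tokens * dim."""
    return num_tokens * dim


# ViT-style example: 196 tokens with 768 channels.
flops_per_image = layerscale_flops(num_tokens=196, dim=768)
print(flops_per_image)  # 150528
```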
- forward(x)
Apply per-channel scaling to the input tensor.
Multiplies every channel slice of x by the corresponding scalar in gamma. PyTorch broadcasts gamma of shape (C,) across all leading dimensions of x automatically.
- Parameters:
x (Tensor) – Input tensor of shape (*leading_dims, C) where the last dimension must equal the dim passed to the constructor. Common shapes:
- (B, T, C): ViT-5 / transformer token sequences.
- (B, H, W, C): channels-last 2-D feature maps.
- (B, T, H, W, C): channels-last 3-D (video) tensors.
- Returns:
Scaled tensor with the same shape and dtype as x, where output[..., c] = x[..., c] * gamma[c] for each channel index c.
- Return type: