MLP
- class MLP(dim, activation, dropout_cfg, expansion_factor=2.0, bias=False, backend='torch', init_method_in=None, init_method_out=None)
Bases: Module

Point-wise two-layer MLP: the channel-mixing branch of each residual block.

Acts on each spatial position independently (no cross-position interaction), expanding the channel dimension by expansion_factor, applying a non-linearity, and projecting back.

Plain MLP (activation in {"relu", "gelu", "silu"}):

    y = W₂( act( W₁(x) ) )

where W₁ : C → H, W₂ : H → C, and H = floor(expansion_factor × C).

Gated variants (activation in {"glu", "swiglu"}):

    [a, b] = W₁(x)           # W₁ : C → 2H, split at midpoint
    y = W₂( gate(a) ⊙ b )    # W₂ : H → C

- GLU ("glu"): gate(a) = sigmoid(a)
- SwiGLU ("swiglu"): gate(a) = SiLU(a)

Because the first projection W₁ must produce 2H outputs for the gate, layer1 has shape (2H, C) when a gated activation is used (see glu_factor in __init__), while layer2 remains (C, H). This means the parameter count differs from the plain variant:

- Plain: C*H + H*C = 2CH parameters in the two linear layers.
- Gated: C*2H + H*C = 3CH parameters in the two linear layers.

At expansion_factor=2:

- Plain (GELU): H = 2C, params ≈ 4C².
- SwiGLU: H = 2C, params ≈ 6C².

To keep parameter counts comparable, gated configs often use a smaller expansion_factor (e.g. 4/3 rather than 2).

Dropout is inserted between the two layers (applied after the activation and gate product).
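
The parameter arithmetic above can be checked with a short sketch. The helper name mlp_linear_params is hypothetical (it is not part of this class), and it assumes bias=False so that only the two weight matrices are counted.

```python
import math


def mlp_linear_params(dim: int, expansion_factor: float, gated: bool) -> int:
    """Weight count of the two linear layers, assuming bias=False (hypothetical helper)."""
    hidden = math.floor(expansion_factor * dim)  # H = floor(expansion_factor * C)
    glu_factor = 2 if gated else 1               # gated layer1 emits 2H outputs
    layer1 = dim * hidden * glu_factor           # C*H (plain) or C*2H (gated)
    layer2 = hidden * dim                        # H*C in both cases
    return layer1 + layer2


C = 512
plain = mlp_linear_params(C, 2.0, gated=False)          # 2CH with H = 2C
gated = mlp_linear_params(C, 2.0, gated=True)           # 3CH with H = 2C
gated_small = mlp_linear_params(C, 4 / 3, gated=True)   # 3CH with H = floor(4C/3)

print(plain, gated, gated_small)
```

With C = 512 the gated count at expansion_factor=4/3 lands within a fraction of a percent of the plain count at expansion_factor=2; the small gap comes from the floor in H.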
The backend argument selects the compute kernel family:
- "torch" (default): Standard nn.Linear + PyTorch activation. Works everywhere.
- "quack" (experimental, currently blocked): QuACK fused GEMM + activation kernels targeting Hopper / Blackwell GPUs. Requires quack-kernels >= 0.3.0, bias=False, channel dimensions divisible by 8, and a supported activation. Currently raises NotImplementedError at init time pending backward validation.
- Parameters:
- hidden_dim
Expanded hidden channel dimension H = floor(expansion_factor * dim). For gated variants this is the post-gate width; the actual layer1 output width is 2 * hidden_dim.
- Type:
- activation
Name of the activation function in use. One of "relu", "gelu", "silu", "glu", "swiglu".
- Type:
- is_glu_variant
True when activation is "glu" or "swiglu", indicating that layer1 produces 2 * hidden_dim outputs and the forward path applies a gating product.
- Type:
- layer1
First linear projection. Shape: (hidden_dim * glu_factor, dim), where glu_factor is 2 for gated activations, 1 otherwise.
- Type: nn.Linear
- dropout
Dropout layer instantiated from dropout_cfg. Applied between the activation (or gate product) and layer2.
- Type: nn.Module
- layer2
Second linear projection. Shape: (dim, hidden_dim). Projects back to the input channel dimension C.
- Type: nn.Linear
- __init__(dim, activation, dropout_cfg, expansion_factor=2.0, bias=False, backend='torch', init_method_in=None, init_method_out=None)
Initialise the MLP.
- Parameters:
- dim (int) – Input and output channel dimension C. Both layer1 and layer2 preserve this outer dimension (the network's residual stream width).
- activation (Literal['relu', 'gelu', 'silu', 'glu', 'swiglu']) – Non-linearity to apply between the two linear layers. Controls whether the MLP is plain or gated:
  - "relu", "gelu", "silu": plain MLP; layer1 maps C → H.
  - "glu": sigmoid-gated MLP; layer1 maps C → 2H and is split into gate + value halves.
  - "swiglu": SiLU-gated MLP (recommended for modern configs); same gating structure as "glu" but with SiLU in place of sigmoid.
- dropout_cfg (LazyConfig) – LazyConfig specifying the dropout module inserted between the activation and layer2. Use torch.nn.Identity (or a LazyConfig that targets it) for no dropout.
- expansion_factor (float) – Multiplier that sets the hidden dimension H = floor(expansion_factor * dim). Defaults to 2.0. For gated activations, the actual layer1 output width is 2 * H; to keep total parameter count similar to a plain MLP with expansion_factor=2, use expansion_factor ≈ 4/3 for SwiGLU / GLU.
- bias (bool) – Whether to include additive bias terms in layer1 and layer2. Defaults to False. Must be False when backend="quack".
- backend (Literal['torch', 'quack']) – Compute kernel family. "torch" (default) uses nn.Linear + PyTorch activation and runs on any hardware. "quack" enables fused GEMM+activation kernels but is currently experimental (raises NotImplementedError at init).
- init_method_in (Callable[[Tensor], Tensor] | None) – Optional weight initialiser for layer1. Expected signature: init_method_in(out_features)(weight), that is, a curried callable that first takes the number of output features and returns an in-place initialiser. Bias is always zero-initialised when present, regardless of this argument. Pass None to use PyTorch's default Kaiming uniform init.
- init_method_out (Callable[[Tensor], Tensor] | None) – Optional weight initialiser for layer2. Same curried signature as init_method_in, applied to layer2.weight. Pass None for PyTorch default init.
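
The curried calling convention for init_method_in and init_method_out, init_method(out_features)(weight), can be illustrated with a minimal sketch. Everything here is an assumption made for illustration: the name scaled_normal is hypothetical, the 1/sqrt(out_features) scale is an arbitrary choice, and a nested list stands in for the weight tensor so the example runs without PyTorch.

```python
import math
import random


def scaled_normal(out_features):
    """Hypothetical curried initialiser: bind out_features, return an in-place initialiser."""
    std = 1.0 / math.sqrt(out_features)  # illustrative scale, not taken from the library

    def init_(weight):
        # `weight` is a list of rows standing in for a (out_features, in_features) tensor.
        for row in weight:
            for j in range(len(row)):
                row[j] = random.gauss(0.0, std)
        return weight  # return the same object, mirroring in-place initialisers

    return init_


random.seed(0)
weight = [[0.0] * 4 for _ in range(8)]  # stand-in for a weight of shape (8, 4)
result = scaled_normal(8)(weight)       # first call binds out_features, second fills the weight
```

With a real tensor, the inner function would presumably call an in-place initialiser from torch.nn.init on the weight instead of looping over rows.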
- Raises:
- NotImplementedError – If backend="quack" (experimental; use backend="torch" for now).
- ValueError – If backend="quack" and any QuACK constraint is violated (unsupported activation, bias=True, or dimension not divisible by 8). Currently unreachable due to the NotImplementedError raised first.
- flop_count(num_tokens)
Count FLOPs for a two-layer MLP applied to num_tokens tokens.
Structure: Linear(dim, hidden_dim * glu_factor) -> activation -> Linear(hidden_dim, dim).
- FLOPs breakdown (T = num_tokens):
  - layer1 (Linear): 2 * T * self.layer1.in_features * self.layer1.out_features. For GLU/SwiGLU, out_features = 2 * hidden_dim (gate doubles width).
  - Activation: T * self.hidden_dim (elementwise). For GLU variants, an additional T * self.hidden_dim for the gate multiply (gate non-linearity on one half, then elementwise product).
  - layer2 (Linear): 2 * T * self.layer2.in_features * self.layer2.out_features.
- Convention: 1 MAC = 2 FLOPs for linear layers. Activations count as 1 FLOP per element.
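
The breakdown above can be reproduced with a standalone sketch. The function mlp_flops is hypothetical and takes the dimensions as plain integers instead of reading them from layer1 and layer2; it follows the stated convention (1 MAC = 2 FLOPs, 1 FLOP per activation element).

```python
def mlp_flops(num_tokens: int, dim: int, hidden_dim: int, gated: bool) -> int:
    """FLOPs for Linear -> activation -> Linear, per the documented convention (hypothetical helper)."""
    glu_factor = 2 if gated else 1
    # layer1: 2 * T * in_features * out_features, with out_features = hidden_dim * glu_factor
    layer1 = 2 * num_tokens * dim * (hidden_dim * glu_factor)
    # activation: one FLOP per hidden element
    activation = num_tokens * hidden_dim
    if gated:
        # gated variants add one more FLOP per hidden element for the gate multiply
        activation += num_tokens * hidden_dim
    # layer2: 2 * T * in_features * out_features, with in_features = hidden_dim
    layer2 = 2 * num_tokens * hidden_dim * dim
    return layer1 + activation + layer2


plain_flops = mlp_flops(num_tokens=10, dim=4, hidden_dim=8, gated=False)  # 640 + 80 + 640
gated_flops = mlp_flops(num_tokens=10, dim=4, hidden_dim=8, gated=True)   # 1280 + 160 + 640
print(plain_flops, gated_flops)
```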
- forward(x)
Apply the two-layer MLP to an ND feature tensor.
Operates identically on every spatial position; no cross-position interaction occurs in this module. The QuACK path bypasses dropout (QuACK kernels fuse both linear layers and the activation into a single kernel; dropout must be inserted by the caller if needed).
- Parameters:
- x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C) in channels-last layout, where B is the batch size, spatial_dims are one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C is the channel dimension (must equal the dim passed at construction time).
- Returns:
Output tensor of shape (B, *spatial_dims, C), the same shape as x.
- Return type:
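
For the gated variants, the step between layer1 and layer2 splits the layer1 output at its midpoint and multiplies the gated first half by the second half. The sketch below shows that step for a single position, using a plain list in place of a tensor; the function name gated_activation is hypothetical and the code is not the module's implementation.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def silu(v: float) -> float:
    return v * sigmoid(v)


def gated_activation(h, activation):
    """Split h = [a, b] at the midpoint and return gate(a) * b elementwise (hypothetical helper)."""
    if len(h) % 2 != 0:
        raise ValueError("layer1 output width must be even (2 * hidden_dim)")
    mid = len(h) // 2
    a, b = h[:mid], h[mid:]                           # a: gate half, b: value half
    gate = {"glu": sigmoid, "swiglu": silu}[activation]
    return [gate(ai) * bi for ai, bi in zip(a, b)]


h = [0.0, 0.0, 3.0, 5.0]                    # stand-in for one position's layer1 output, 2H = 4
glu_out = gated_activation(h, "glu")        # sigmoid(0) = 0.5, so [1.5, 2.5]
swiglu_out = gated_activation(h, "swiglu")  # SiLU(0) = 0, so [0.0, 0.0]
```

The output has width H, half the layer1 output width, which is why layer2 keeps the shape (dim, hidden_dim) in both the plain and gated cases.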