MLP

class MLP(
dim,
activation,
dropout_cfg,
expansion_factor=2.0,
bias=False,
backend='torch',
init_method_in=None,
init_method_out=None,
)

Bases: Module

Point-wise two-layer MLP — the channel-mixing branch of each residual block.

Acts on each spatial position independently (no cross-position interaction), expanding the channel dimension by expansion_factor, applying a non-linearity, and projecting back:

Plain MLP (activation in {"relu", "gelu", "silu"}):

y = W₂( act( W₁(x) ) )

where W₁ : C → H, W₂ : H → C, and H = floor(expansion_factor × C).

Gated variants (activation in {"glu", "swiglu"}):

[a, b] = W₁(x)          # W₁ : C → 2H, split at midpoint
y = W₂( gate(a) ⊙ b )  # W₂ : H → C
  • GLU ("glu"): gate(a) = sigmoid(a)

  • SwiGLU ("swiglu"): gate(a) = SiLU(a)

Because the first projection W₁ must produce 2H outputs for the gate, layer1 has shape (2H, C) when a gated activation is used (see glu_factor in __init__), while layer2 remains (C, H). This means the parameter count differs from the plain variant:

  • Plain: C*H + H*C = 2CH parameters in the two linear layers.

  • Gated: C*2H + H*C = 3CH parameters in the two linear layers.

At expansion_factor=2:

  • Plain (GELU): H = 2C, params ≈ 4C².

  • SwiGLU: H = 2C, params ≈ 6C².

To keep parameter counts comparable, gated configs often use a smaller expansion_factor (e.g. 4/3 rather than 2).
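The arithmetic above can be checked with a few lines of plain Python. This is only a sketch: `mlp_linear_params` is a hypothetical helper, not part of the module, and it counts weights only (bias=False, the default).

```python
import math


def mlp_linear_params(dim, expansion_factor, gated):
    """Weight count of the two linear layers (no bias), per the formulas above."""
    hidden = math.floor(expansion_factor * dim)   # H
    glu_factor = 2 if gated else 1                # gated layer1 emits 2H outputs
    layer1 = dim * hidden * glu_factor            # C*H or C*2H
    layer2 = hidden * dim                         # H*C
    return layer1 + layer2


C = 384
plain = mlp_linear_params(C, 2.0, gated=False)      # 2CH with H = 2C
gated = mlp_linear_params(C, 2.0, gated=True)       # 3CH with H = 2C
slim = mlp_linear_params(C, 4 / 3, gated=True)      # 3CH with H close to 4C/3
print(plain, gated, slim)
```

With H = 2C the plain count is 4C² and the gated count is 6C²; shrinking expansion_factor to about 4/3 brings the gated count back to roughly 4C².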

Dropout is inserted between the two layers (applied after the activation and gate product).
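The plain and gated paths described above can be sketched as follows. This is an illustrative stand-in for the torch backend, not the module's actual source: `SketchMLP` and its `p_drop` argument are assumptions (the real class builds its dropout layer from dropout_cfg), and the QuACK backend and custom initialisers are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMLP(nn.Module):
    """Hypothetical stand-in for the documented MLP (torch backend only)."""

    def __init__(self, dim, activation, expansion_factor=2.0, p_drop=0.0):
        super().__init__()
        self.activation = activation
        self.hidden_dim = int(expansion_factor * dim)  # floor for positive values
        self.is_glu_variant = activation in ("glu", "swiglu")
        glu_factor = 2 if self.is_glu_variant else 1
        self.layer1 = nn.Linear(dim, self.hidden_dim * glu_factor, bias=False)
        self.dropout = nn.Dropout(p_drop)  # assumption: real class uses dropout_cfg
        self.layer2 = nn.Linear(self.hidden_dim, dim, bias=False)

    def forward(self, x):
        h = self.layer1(x)
        if self.is_glu_variant:
            a, b = h.chunk(2, dim=-1)  # split at the midpoint of the channel axis
            gate = torch.sigmoid(a) if self.activation == "glu" else F.silu(a)
            h = gate * b               # gate(a) ⊙ b
        else:
            h = {"relu": F.relu, "gelu": F.gelu, "silu": F.silu}[self.activation](h)
        # Dropout sits after the activation (or gate product), before layer2.
        return self.layer2(self.dropout(h))


x = torch.randn(2, 4, 4, 8)        # (B, H, W, C), channels-last
y = SketchMLP(8, "swiglu")(x)      # output has the same shape as x
```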

The backend argument selects the compute kernel family:

  • "torch" (default): Standard nn.Linear + PyTorch activation. Works everywhere.

  • "quack" (experimental, currently blocked): QuACK fused GEMM + activation kernels targeting Hopper / Blackwell GPUs. Requires quack-kernels >= 0.3.0, bias=False, channel dimensions divisible by 8, and a supported activation. Currently raises NotImplementedError at init time pending backward validation.

Attributes:
hidden_dim

Expanded hidden channel dimension H = floor(expansion_factor * dim). For gated variants this is the post-gate width — the actual layer1 output width is 2 * hidden_dim.

Type:

int

activation

Name of the activation function in use. One of "relu", "gelu", "silu", "glu", "swiglu".

Type:

str

backend

Compute backend, either "torch" or "quack".

Type:

str

is_glu_variant

True when activation is "glu" or "swiglu", indicating that layer1 produces 2 * hidden_dim outputs and the forward path applies a gating product.

Type:

bool

layer1

First linear projection. Shape: (hidden_dim * glu_factor, dim) where glu_factor is 2 for gated activations, 1 otherwise.

Type:

nn.Linear

dropout

Dropout layer instantiated from dropout_cfg. Applied between the activation (or gate product) and layer2.

Type:

nn.Module

layer2

Second linear projection. Shape: (dim, hidden_dim). Projects back to the input channel dimension C.

Type:

nn.Linear

__init__(
dim,
activation,
dropout_cfg,
expansion_factor=2.0,
bias=False,
backend='torch',
init_method_in=None,
init_method_out=None,
)

Initialise the MLP.

Parameters:
  • dim (int) – Input and output channel dimension C (the network’s residual stream width). layer1 takes inputs of width C and layer2 projects back to width C.

  • activation (Literal['relu', 'gelu', 'silu', 'glu', 'swiglu']) –

    Non-linearity to apply between the two linear layers. Controls whether the MLP is plain or gated:

    • "relu", "gelu", "silu": plain MLP; layer1 maps C → H.

    • "glu": sigmoid-gated MLP; layer1 maps C → 2H and its output is split into gate and value halves.

    • "swiglu": SiLU-gated MLP (recommended for modern configs); same gating structure as "glu" but with SiLU in place of sigmoid.

  • dropout_cfg (LazyConfig) – LazyConfig specifying the dropout module inserted between the activation and layer2. Use torch.nn.Identity (or a LazyConfig that targets it) for no dropout.

  • expansion_factor (float) – Multiplier that sets the hidden dimension H = floor(expansion_factor * dim). Defaults to 2.0. For gated activations, the actual layer1 output width is 2 * H; to keep the total parameter count similar to a plain MLP with expansion_factor=2, use an expansion_factor of about 4/3 for SwiGLU / GLU.

  • bias (bool) – Whether to include additive bias terms in layer1 and layer2. Defaults to False. Must be False when backend="quack".

  • backend (Literal['torch', 'quack']) – Compute kernel family. "torch" (default) uses nn.Linear + PyTorch activation and runs on any hardware. "quack" enables fused GEMM+activation kernels but is currently experimental (raises NotImplementedError at init).

  • init_method_in (Callable[[Tensor], Tensor] | None) – Optional weight initialiser for layer1. Expected signature: init_method_in(out_features)(weight) — a curried callable that first takes the number of output features and returns an in-place initialiser. Bias is always zero-initialised when present, regardless of this argument. Pass None to use PyTorch’s default Kaiming uniform init.

  • init_method_out (Callable[[Tensor], Tensor] | None) – Optional weight initialiser for layer2. Same curried signature as init_method_in, applied to layer2.weight. Pass None for PyTorch default init.
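A curried initialiser of the kind described for init_method_in and init_method_out might look like the sketch below. It follows the prose description (out_features first, then the weight tensor), not the type annotation. `fan_out_normal` and its 1/sqrt(out_features) scaling are hypothetical choices made for illustration.

```python
import math

import torch
import torch.nn as nn


def fan_out_normal(out_features):
    """Hypothetical curried initialiser: takes out_features, returns an in-place init."""
    std = 1.0 / math.sqrt(out_features)

    def init_(weight):
        # nn.init.normal_ fills the tensor in place and returns it.
        return nn.init.normal_(weight, mean=0.0, std=std)

    return init_


# Applied by hand to a stand-alone layer, mirroring init_method_in(out_features)(weight).
layer = nn.Linear(8, 16, bias=False)
fan_out_normal(layer.out_features)(layer.weight)
```

Under this reading, the function itself (`fan_out_normal`, not its result) would be what gets passed as init_method_in or init_method_out.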

Raises:
  • NotImplementedError – If backend="quack" (experimental; use backend="torch" for now).

  • ValueError – If backend="quack" and any QuACK constraint is violated (unsupported activation, bias=True, or dimension not divisible by 8). Currently unreachable due to the NotImplementedError raised first.

flop_count(num_tokens)

Count FLOPs for a two-layer MLP applied to num_tokens tokens.

Structure: Linear(dim, hidden_dim * glu_factor) -> activation -> Linear(hidden_dim, dim).

FLOPs breakdown (T = num_tokens):
  1. layer1 (Linear): 2 * T * self.layer1.in_features * self.layer1.out_features. For GLU/SwiGLU, out_features = 2 * hidden_dim (the gate doubles the width).

  2. Activation: T * self.hidden_dim (elementwise). For GLU variants, an additional T * self.hidden_dim for the gate multiply (the gate non-linearity, sigmoid or SiLU, on one half plus the elementwise product with the other half).

  3. layer2 (Linear): 2 * T * self.layer2.in_features * self.layer2.out_features

Convention: 1 MAC = 2 FLOPs for linear layers.

Activations count as 1 FLOP per element.
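The breakdown above can be reproduced in plain Python. `mlp_flops` is a hypothetical free function that mirrors the three terms; it is a sketch of the stated convention, not the method's actual source.

```python
def mlp_flops(num_tokens, dim, hidden_dim, gated):
    """FLOPs for Linear -> activation -> Linear, using 1 MAC = 2 FLOPs."""
    glu_factor = 2 if gated else 1
    # 1. layer1: in_features = dim, out_features = hidden_dim * glu_factor
    layer1 = 2 * num_tokens * dim * (hidden_dim * glu_factor)
    # 2. activation: 1 FLOP per element, plus the gate multiply for GLU variants
    activation = num_tokens * hidden_dim
    if gated:
        activation += num_tokens * hidden_dim
    # 3. layer2: in_features = hidden_dim, out_features = dim
    layer2 = 2 * num_tokens * hidden_dim * dim
    return layer1 + activation + layer2


plain_flops = mlp_flops(num_tokens=10, dim=4, hidden_dim=8, gated=False)
gated_flops = mlp_flops(num_tokens=10, dim=4, hidden_dim=8, gated=True)
```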

Parameters:

num_tokens (int) – Number of tokens (positions) the MLP is applied to. Equal to the product of all spatial dimensions, e.g. height * width for 2-D images or the sequence length for 1-D sequences.

Returns:

Total FLOPs as an integer.

Return type:

int
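For a channels-last tensor, the num_tokens argument can be derived from the shape as sketched below. `tokens_per_sample` is a hypothetical helper. It assumes the (B, *spatial_dims, C) layout used by forward and counts positions for a single sample, leaving out the batch axis, as the parameter description states.

```python
import math


def tokens_per_sample(shape):
    """Product of the spatial axes of a (B, *spatial_dims, C) shape."""
    return math.prod(shape[1:-1])


image_tokens = tokens_per_sample((2, 32, 32, 64))   # 2-D image: 32 * 32 positions
sequence_tokens = tokens_per_sample((2, 128, 64))   # 1-D sequence: 128 positions
```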

forward(x)

Apply the two-layer MLP to an ND feature tensor.

Operates identically on every spatial position — no cross-position interaction occurs in this module. The QuACK path bypasses dropout (QuACK kernels fuse both linear layers and the activation into a single kernel; dropout must be inserted by the caller if needed).

Parameters:

x (Tensor) – Input feature tensor of shape (B, *spatial_dims, C) in channels-last layout, where B is the batch size, spatial_dims are one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C is the channel dimension (must equal the dim passed at construction time).

Returns:

Output tensor of shape (B, *spatial_dims, C), the same shape as x.

Return type:

torch.Tensor
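The claim that no cross-position interaction occurs can be illustrated with stock PyTorch layers. The sketch below uses an nn.Sequential of two nn.Linear layers as a stand-in for the plain (non-gated) path; it is not the MLP class itself. Because nn.Linear acts on the last axis only, running one position on its own should give the same result as slicing that position out of the full output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for the plain path: C -> H -> C with C = 8, H = 16.
pointwise = nn.Sequential(
    nn.Linear(8, 16, bias=False),
    nn.GELU(),
    nn.Linear(16, 8, bias=False),
)

x = torch.randn(2, 5, 7, 8)           # (B, H, W, C), channels-last
y = pointwise(x)                      # same shape as x

# Feed a single spatial position through the same layers.
single = pointwise(x[0, 3, 4])
matches = torch.allclose(y[0, 3, 4], single, atol=1e-5)
```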