QKVConditionMixer
- class QKVConditionMixer(hidden_dim, mixer_cfg, init_method_in=None, init_method_out=None)
Bases: Module
Cross-attention condition mixer that routes a conditioning signal into the feature map.
This module implements the condition mixer branch of ResidualBlock. It injects an external conditioning signal c (e.g. a timestep embedding, class label, or physics parameter vector) into the residual stream x via learned Q, K, V projections and a pluggable inner mixing operator:

    x ─[q_proj]──────────────────────────► Q ─┐
    c ─[kv_proj]──► split ──► K, V ────────────► inner_mixer(Q, K, V) ─[out_proj]──► y

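The diagram leaves inner_mixer abstract. The sketch below shows one hypothetical operator that fits it. It flattens all spatial axes into token axes and then applies single-head scaled dot-product attention through torch.nn.functional.scaled_dot_product_attention, which needs PyTorch 2.0 or later. The class name FlattenedCrossAttention is illustrative only and is not part of this library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlattenedCrossAttention(nn.Module):
    """Hypothetical inner mixer: single-head dot-product cross-attention.

    q: (B, *spatial, C) queries from the feature map.
    k, v: (B, *spatial_cond, C) keys/values from the condition; the spatial
    extent may differ from q's.
    """

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        b, c = q.shape[0], q.shape[-1]
        # Flatten spatial axes into one token axis: (B, N, C) and (B, M, C).
        q_tok = q.reshape(b, -1, c)
        k_tok = k.reshape(b, -1, c)
        v_tok = v.reshape(b, -1, c)
        # Each of the N query tokens attends over all M conditioning tokens.
        out = F.scaled_dot_product_attention(q_tok, k_tok, v_tok)
        # Restore q's spatial layout so the output shape matches q exactly.
        return out.reshape(q.shape)


mixer = FlattenedCrossAttention()
q = torch.randn(2, 8, 8, 16)   # (B, H, W, C) queries
k = torch.randn(2, 5, 16)      # (B, M, C) conditioning keys
v = torch.randn(2, 5, 16)      # (B, M, C) conditioning values
y = mixer(q, k, v)
print(y.shape)                 # torch.Size([2, 8, 8, 16])
```

With this particular mixer, a single conditioning token (M = 1) always receives a softmax weight of exactly 1. Every position therefore gets the same value vector regardless of Q. Spatially varying attention only appears once there are several conditioning tokens.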
Queries are derived from the current feature map x, so each spatial position can attend selectively to the conditioning tokens. Keys and values are derived from the conditioning signal c. This makes it a form of cross-attention conditioning. It is more expressive than FiLM, which applies a uniform per-channel affine transform regardless of spatial content. It is also more efficient than full self-attention over the concatenated feature and conditioning tokens: with a dot-product inner mixer, it scores N·M query-key pairs instead of (N+M)² for N feature positions and M conditioning tokens.
The inner mixing computation is delegated to mixer_cfg, following the same operator-agnostic dispatch pattern as QKVSequenceMixer. Any module can serve as the inner mixer if its forward(q, k, v) accepts channels-last tensors of shape (B, *spatial, C) and returns a tensor of the same shape as q. If the inner mixer returns a different channel dimension, out_proj will raise a shape error. A different spatial shape is not caught by out_proj.
Note
In practice, the x arriving at forward() has already been passed through condition_mixer_norm by the enclosing ResidualBlock. This module is not used with AdaLNZeroResidualBlock, which has no condition-mixer branch.
Weight initialisation
The optional curried initialisers init_method_in and init_method_out let the caller supply separate weight initialisation schemes for the input-side projections (q_proj, kv_proj) and the output projection (out_proj). A typical example is depth-scaled Gaussian init as used in GPT / Megatron. Both follow the signature fn(dim: int) -> fn(tensor: Tensor) -> None. When they are omitted, PyTorch's default Kaiming-uniform init is used.
Weight decay
No weight-decay tags are set on the projections. The caller is responsible for any per-parameter weight-decay grouping (see the analogous logic in KernelFiLMGenerator).
See also
QKVSequenceMixer: Structurally parallel module for self-attention / sequence mixing (forward(x) rather than forward(x, condition)).
ResidualBlock: The enclosing block that calls this module's forward after applying condition_mixer_norm to the residual stream.
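ResidualBlock itself is not reproduced here. The sketch below assumes that ResidualBlock adds this module's output back onto the un-normalised stream, which matches the residual-stream wording above. Under that assumption, the call order looks as follows. condition_branch and BroadcastStandIn are hypothetical names. nn.LayerNorm is only a placeholder for whatever condition_mixer_norm actually is. The stand-in replaces QKVConditionMixer to keep the example short.

```python
import torch
import torch.nn as nn


def condition_branch(
    x: torch.Tensor,
    condition: torch.Tensor,
    condition_mixer_norm: nn.Module,
    condition_mixer: nn.Module,
) -> torch.Tensor:
    """Hypothetical residual update around a condition mixer."""
    # The mixer only sees the normalised stream...
    update = condition_mixer(condition_mixer_norm(x), condition)
    # ...and its output is added back onto the un-normalised residual stream.
    return x + update


class BroadcastStandIn(nn.Module):
    """Stand-in for QKVConditionMixer: returns the (B, C) condition at every position."""

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        return condition[:, None, None, :].expand_as(x)


x = torch.randn(2, 4, 4, 8)   # (B, H, W, C) residual stream
c = torch.randn(2, 8)         # (B, C) global condition
out = condition_branch(x, c, nn.LayerNorm(8), BroadcastStandIn())
print(out.shape)              # torch.Size([2, 4, 4, 8])
```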
- Attributes:
- mixer
The instantiated inner mixing operator. Its forward(q, k, v) method receives channels-last tensors and must return a tensor with the same spatial shape and channel dimension as q.
- kv_proj
Combined K+V projection (no bias) that maps the conditioning signal from C to 2·C. Weight shape: (2·hidden_dim, hidden_dim). The input channel dimension hidden_dim can be recovered via self.kv_proj.in_features.
- q_proj
Query projection (no bias) that maps the feature map from C to C. Weight shape: (hidden_dim, hidden_dim).
- out_proj
Output projection (no bias) that maps the mixer output back to C. Weight shape: (hidden_dim, hidden_dim).
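The weight shapes above follow the (out_features, in_features) convention of torch.nn.Linear, and the sketch below assumes the projections are nn.Linear layers. Because the module sets no weight-decay tags, grouping is left to the caller. The rule used here decays weight matrices and exempts 1-D parameters. That is a common convention, not necessarily the logic in KernelFiLMGenerator. ProjectionStandIn and weight_decay_groups are hypothetical names.

```python
import torch
import torch.nn as nn


class ProjectionStandIn(nn.Module):
    """Hypothetical stand-in holding the three bias-free projections."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.kv_proj = nn.Linear(hidden_dim, 2 * hidden_dim, bias=False)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)


def weight_decay_groups(module: nn.Module, weight_decay: float) -> list:
    """Caller-side grouping: decay weight matrices, exempt 1-D parameters."""
    decay, no_decay = [], []
    for param in module.parameters():
        if not param.requires_grad:
            continue
        (decay if param.ndim >= 2 else no_decay).append(param)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    # Drop empty groups (here there are no biases, so nothing is exempt).
    return [g for g in groups if g["params"]]


m = ProjectionStandIn(hidden_dim=32)
print(tuple(m.kv_proj.weight.shape))  # (64, 32)
print(m.kv_proj.in_features)          # 32, i.e. hidden_dim
optimizer = torch.optim.AdamW(weight_decay_groups(m, weight_decay=0.1), lr=1e-3)
```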
- __init__(hidden_dim, mixer_cfg, init_method_in=None, init_method_out=None)
Initialise the QKVConditionMixer.
- Parameters:
hidden_dim (int) – Channel dimension C shared by the feature map x and the conditioning signal c. All linear projections (q_proj, kv_proj, out_proj) are sized using this value. The conditioning tensor c must have its last dimension equal to hidden_dim. This cannot be checked at construction time, because c is not available there. A mismatch only surfaces during the forward pass, when kv_proj raises a shape error.
mixer_cfg (LazyConfig) – LazyConfig for the inner mixing operator. The instantiated module's forward must accept (q, k, v) as positional arguments and return a tensor of the same shape as q. Any attention-compatible module (e.g. a dot-product attention layer) can be used here.
init_method_in (Callable[[int], Callable[[Tensor], None]] | None) – Optional curried weight initialiser applied to both q_proj.weight and kv_proj.weight. It must have the signature fn(dim: int) -> fn(tensor: Tensor) -> None. fn(hidden_dim) is called, and the returned callable is applied in place to each weight matrix. Pass None to keep PyTorch's default Kaiming-uniform init.
init_method_out (Callable[[int], Callable[[Tensor], None]] | None) – Same curried signature as init_method_in, applied to out_proj.weight. A common choice is a depth-scaled Gaussian (GPT / Megatron style), which controls the residual branch variance at initialisation. Pass None to keep the default.
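The curried fn(dim) -> fn(tensor) contract can be sketched as below. normal_init and scaled_normal_init are hypothetical helpers, not part of this library. The 1/sqrt(dim) base standard deviation is an assumption, chosen so that the dim argument is actually used. The extra 1/sqrt(2 * num_layers) factor follows the common GPT-2 / Megatron-style depth scaling for output projections on the residual branch.

```python
import math
from typing import Callable

import torch
import torch.nn as nn

TensorInit = Callable[[torch.Tensor], None]
CurriedInit = Callable[[int], TensorInit]


def normal_init(dim: int) -> TensorInit:
    """fn(dim) -> fn(tensor): zero-mean Gaussian with std = 1/sqrt(dim)."""
    std = 1.0 / math.sqrt(dim)

    def apply(tensor: torch.Tensor) -> None:
        nn.init.normal_(tensor, mean=0.0, std=std)  # in place, without autograd tracking

    return apply


def scaled_normal_init(num_layers: int) -> CurriedInit:
    """Depth-scaled variant: std additionally divided by sqrt(2 * num_layers)."""

    def for_dim(dim: int) -> TensorInit:
        std = 1.0 / math.sqrt(dim) / math.sqrt(2.0 * num_layers)

        def apply(tensor: torch.Tensor) -> None:
            nn.init.normal_(tensor, mean=0.0, std=std)

        return apply

    return for_dim


torch.manual_seed(0)
hidden_dim, num_layers = 256, 12
q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

init_method_in: CurriedInit = normal_init
init_method_out: CurriedInit = scaled_normal_init(num_layers)
# Same call pattern as documented: fn(hidden_dim) is called, then applied in place.
init_method_in(hidden_dim)(q_proj.weight)
init_method_out(hidden_dim)(out_proj.weight)
```

Passing scaled_normal_init(num_layers) rather than scaled_normal_init itself is what keeps init_method_out to the single-argument fn(dim) shape. The depth is fixed by the caller before the module ever sees it.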
- forward(x, condition)
Inject the conditioning signal into the feature map via cross-attention.
Computes queries from the (pre-normalised) feature map x and keys/values from the conditioning signal condition. It then mixes them with the inner mixer and projects the result back to hidden_dim.
Note
In the standard ResidualBlock usage, x has already been passed through condition_mixer_norm before this method is called.
The signal flow is:

    Q = q_proj(x)                        # (B, *spatial_dims, C)
    K, V = split(kv_proj(condition))     # each (B, *spatial_dims_cond, C)
    y = out_proj(mixer(Q, K, V))         # (B, *spatial_dims, C)

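The sketch below is a runnable version of this signal flow, with the projections as plain bias-free torch.nn.Linear layers. Two details are assumptions rather than documented behaviour. First, split is taken to be an even torch.chunk along the channel axis, with K first. Second, the inner mixer is replaced by a flatten-then-attend call to scaled_dot_product_attention. Both accepted condition shapes are exercised. flow is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, H, W, C = 2, 6, 6, 32
q_proj = nn.Linear(C, C, bias=False)
kv_proj = nn.Linear(C, 2 * C, bias=False)
out_proj = nn.Linear(C, C, bias=False)


def flow(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    """Hypothetical functional version of the documented signal flow."""
    b, ch = x.shape[0], x.shape[-1]
    if condition.ndim == 2:
        condition = condition.unsqueeze(1)          # (B, C) -> (B, 1, C)
    Q = q_proj(x)                                   # (B, *spatial_dims, C)
    K, V = kv_proj(condition).chunk(2, dim=-1)      # each (B, *spatial_dims_cond, C)
    # Stand-in inner mixer: flatten spatial axes, attend, restore Q's shape.
    mixed = F.scaled_dot_product_attention(
        Q.reshape(b, -1, ch), K.reshape(b, -1, ch), V.reshape(b, -1, ch)
    ).reshape(Q.shape)
    return out_proj(mixed)                          # (B, *spatial_dims, C)


x = torch.randn(B, H, W, C)
y_global = flow(x, torch.randn(B, C))               # one conditioning token
y_spatial = flow(x, torch.randn(B, 3, 5, C))        # 15 tokens, extent differs from x
print(y_global.shape, y_spatial.shape)  # torch.Size([2, 6, 6, 32]) torch.Size([2, 6, 6, 32])
```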
The inner mixer must return a tensor of the same shape as Q (i.e. (B, *spatial_dims, C)). If its channel dimension differs, out_proj will raise a shape error.
A global (non-spatial) conditioning vector of shape (B, C) is automatically unsqueezed to (B, 1, C) before the K/V projection, so the inner mixer sees a single conditioning token per sample.
- Parameters:
x (Tensor) – Feature map tensor of shape (B, *spatial_dims, C). Here B is the batch size, spatial_dims is one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C = hidden_dim. Must have at least three dimensions (batch + one spatial axis + channel).
condition (Tensor) – Conditioning signal tensor. Two shapes are accepted:
- (B, C): global conditioning vector (e.g. a timestep or class embedding). Unsqueezed internally to (B, 1, C) before projection.
- (B, *spatial_dims_cond, C): spatially distributed conditioning tokens (e.g. encoder output). Must have the same number of dimensions as x. The spatial extent spatial_dims_cond need not match spatial_dims of x.
The channel dimension C must equal hidden_dim (self.kv_proj.in_features). This is not validated explicitly; a mismatch surfaces as a shape error raised inside kv_proj.
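Because kv_proj is sized from hidden_dim, a condition with the wrong channel count is rejected by the projection itself rather than by an explicit check. The sketch below assumes kv_proj is a bias-free torch.nn.Linear. It shows that failure, and also a hypothetical caller-side guard, check_condition_channels, that fails earlier with a clearer message.

```python
import torch
import torch.nn as nn

hidden_dim = 16
kv_proj = nn.Linear(hidden_dim, 2 * hidden_dim, bias=False)  # stand-in for self.kv_proj


def check_condition_channels(condition: torch.Tensor, kv_proj: nn.Linear) -> None:
    """Hypothetical guard: compare the last axis against kv_proj.in_features."""
    if condition.shape[-1] != kv_proj.in_features:
        raise ValueError(
            f"condition has {condition.shape[-1]} channels, "
            f"expected hidden_dim={kv_proj.in_features}"
        )


check_condition_channels(torch.randn(2, 16), kv_proj)  # OK: 16 == hidden_dim

bad = torch.randn(2, 12)
try:
    kv_proj(bad)  # the matrix multiply itself rejects the mismatched shape
except RuntimeError as err:
    print(f"kv_proj raised RuntimeError: {err}")
```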
- Returns:
Output tensor of shape (B, *spatial_dims, C). It has the same spatial layout as x, with the conditioning signal blended in via cross-attention.
- Raises:
ValueError – If x has fewer than three dimensions (i.e. is missing at least one spatial axis).
ValueError – If condition.ndim is neither 2 (global vector) nor equal to x.ndim (matching spatial rank).
- Return type:
Tensor
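The sketch below combines the two ValueError conditions listed under Raises with the (B, C) unsqueeze described for condition, as one standalone check. prepare_condition is a hypothetical helper. The actual module may order or word its checks differently.

```python
import torch


def prepare_condition(x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
    """Hypothetical validation mirroring the documented Raises section."""
    if x.ndim < 3:
        raise ValueError(
            f"x must have shape (B, *spatial_dims, C) with ndim >= 3, got ndim={x.ndim}"
        )
    if condition.ndim == 2:
        return condition.unsqueeze(1)  # global (B, C) -> one token (B, 1, C)
    if condition.ndim == x.ndim:
        return condition  # spatial tokens; extents may differ from x
    raise ValueError(
        f"condition.ndim must be 2 or x.ndim={x.ndim}, got {condition.ndim}"
    )


x = torch.randn(2, 8, 8, 16)
print(tuple(prepare_condition(x, torch.randn(2, 16)).shape))        # (2, 1, 16)
print(tuple(prepare_condition(x, torch.randn(2, 3, 5, 16)).shape))  # (2, 3, 5, 16)
for bad_x, bad_c in [
    (torch.randn(2, 16), torch.randn(2, 16)),  # x lacks a spatial axis
    (x, torch.randn(2, 5, 16)),                # condition rank matches neither rule
]:
    try:
        prepare_condition(bad_x, bad_c)
    except ValueError as err:
        print("ValueError:", err)
```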