QKVConditionMixer

class QKVConditionMixer(hidden_dim, mixer_cfg, init_method_in=None, init_method_out=None)

Bases: Module

Cross-attention condition mixer that routes a conditioning signal into the feature map.

This module implements the condition mixer branch of ResidualBlock. It injects an external conditioning signal c (e.g. a timestep embedding, class label, or physics parameter vector) into the residual stream x via learned Q, K, V projections and a pluggable inner mixing operator:

x  ─[q_proj]──────────────────────────► Q ─┐
c  ─[kv_proj]──► split ──► K, V ────────────► inner_mixer(Q, K, V) ─[out_proj]──► y

Queries are derived from the current feature map x, so each spatial position can attend selectively to the conditioning tokens; keys and values are derived from the conditioning signal c. This is a form of cross-attention conditioning. It is more expressive than FiLM, which applies the same per-channel affine transform at every spatial position regardless of local content, and cheaper than full self-attention over the concatenated feature and conditioning tokens.

The inner mixing computation is delegated to the operator built from mixer_cfg, following the same operator-agnostic dispatch pattern as QKVSequenceMixer. Any module whose forward(q, k, v) accepts channels-last tensors of shape (B, *spatial, C) and returns a tensor of the same shape as q can be used as the inner mixer. If the inner mixer returns a different channel dimension, out_proj will raise a shape error.
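The inner-mixer contract described above can be illustrated with a minimal sketch. DotProductMixer is a hypothetical name, not part of this library; it assumes single-head scaled dot-product attention and flattens all spatial axes into one token axis before attending.

```python
import math

import torch
from torch import nn


class DotProductMixer(nn.Module):
    """Hypothetical inner mixer: single-head scaled dot-product attention."""

    def forward(self, q, k, v):
        # Flatten spatial axes into a token axis: (B, *spatial, C) -> (B, N, C).
        b, c = q.shape[0], q.shape[-1]
        q_tok = q.reshape(b, -1, c)
        k_tok = k.reshape(b, -1, c)
        v_tok = v.reshape(b, -1, c)
        # Attention weights over the conditioning tokens: (B, N_q, N_kv).
        attn = torch.softmax(q_tok @ k_tok.transpose(1, 2) / math.sqrt(c), dim=-1)
        # Restore the layout of q so the output has the same shape as q.
        return (attn @ v_tok).reshape(q.shape)


q = torch.randn(2, 8, 8, 16)  # feature-map queries, (B, H, W, C)
k = torch.randn(2, 4, 16)     # four conditioning tokens, (B, N_cond, C)
v = torch.randn(2, 4, 16)
out = DotProductMixer()(q, k, v)
print(out.shape)  # torch.Size([2, 8, 8, 16])
```

Because the output is reshaped to q.shape, the spatial extent of k and v is free to differ from that of q, which is the property the contract relies on.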

Note

In practice x arriving at forward() has already been passed through condition_mixer_norm by the enclosing ResidualBlock. This module is not used with AdaLNZeroResidualBlock, which has no condition-mixer branch.

Weight initialisation

Optional curried initialisers init_method_in and init_method_out allow the caller to supply per-projection weight schedules (e.g. depth-scaled Gaussian init from GPT / Megatron). Both follow the signature fn(dim: int) -> fn(tensor: Tensor) -> None. When omitted, PyTorch’s default Kaiming-uniform init is used.
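A curried initialiser with the signature fn(dim) -> fn(tensor) -> None might look like the sketch below. The helper name make_init, the 1/sqrt(dim) base scale, and the 1/sqrt(2 * num_layers) depth factor are illustrative assumptions, not values prescribed by this module.

```python
import math

import torch
from torch import nn


def make_init(num_layers=None):
    """Hypothetical factory for a curried initialiser: fn(dim) -> fn(tensor) -> None."""

    def init_method(dim):
        # Assumed base scale: 1 / sqrt(dim).
        std = dim ** -0.5
        if num_layers is not None:
            # Assumed depth scaling in the GPT / Megatron spirit.
            std /= math.sqrt(2.0 * num_layers)

        def init_(tensor):
            # In-place Gaussian init; returns None as the signature requires.
            nn.init.normal_(tensor, mean=0.0, std=std)

        return init_

    return init_method


init_method_in = make_init()                # candidate for q_proj / kv_proj
init_method_out = make_init(num_layers=12)  # candidate for out_proj

# Applying the curried initialiser by hand to a stand-in projection.
proj = nn.Linear(64, 64, bias=False)
init_method_out(64)(proj.weight)
```

The two-stage call mirrors how the text describes the module using these arguments: the outer call receives hidden_dim once, and the returned callable is applied to each weight matrix.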

Weight decay

No weight-decay tags are set on the projections; the caller is responsible for any per-parameter weight-decay grouping (see the analogous logic in KernelFiLMGenerator).
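Since grouping is left to the caller, one possible approach is to build optimizer parameter groups by hand. The sketch below uses a stand-in ModuleDict with the three bias-free projections rather than the real class, and the policy (decay every parameter with two or more dimensions, exempt the rest) and the hyperparameter values are assumptions chosen for illustration.

```python
import torch
from torch import nn

# Stand-in for the module's three bias-free projections (not the real class).
hidden_dim = 32
mixer = nn.ModuleDict({
    "q_proj": nn.Linear(hidden_dim, hidden_dim, bias=False),
    "kv_proj": nn.Linear(hidden_dim, 2 * hidden_dim, bias=False),
    "out_proj": nn.Linear(hidden_dim, hidden_dim, bias=False),
})

decay, no_decay = [], []
for name, param in mixer.named_parameters():
    # Assumed policy: decay weight matrices, exempt 1-D parameters.
    (decay if param.ndim >= 2 else no_decay).append(param)

param_groups = [{"params": decay, "weight_decay": 0.01}]
if no_decay:
    # Only add the exempt group when it is non-empty.
    param_groups.append({"params": no_decay, "weight_decay": 0.0})

optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```

With bias-free projections every parameter here is a matrix, so all three land in the decayed group; the exempt group matters once the enclosing model contributes norms or biases.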

See also

QKVSequenceMixer:

Structurally parallel module for self-attention / sequence mixing (forward(x) rather than forward(x, condition)).

ResidualBlock:

The enclosing block that calls this module’s forward after applying condition_mixer_norm to the residual stream.

Attributes:
mixer

The instantiated inner mixing operator. Its forward(q, k, v) method receives channels-last tensors and must return a tensor of the same spatial shape and channel dimension as q.

Type:

torch.nn.Module

kv_proj

Combined K+V projection (no bias) that maps the conditioning signal from C to 2·C. Weight shape: (2·hidden_dim, hidden_dim). The input channel dimension hidden_dim can be recovered via self.kv_proj.in_features.

Type:

torch.nn.Linear

q_proj

Query projection (no bias) that maps the feature map from C to C. Weight shape: (hidden_dim, hidden_dim).

Type:

torch.nn.Linear

out_proj

Output projection (no bias) that maps the mixer output back to C. Weight shape: (hidden_dim, hidden_dim).

Type:

torch.nn.Linear

__init__(hidden_dim, mixer_cfg, init_method_in=None, init_method_out=None)

Initialise the QKVConditionMixer.

Parameters:
  • hidden_dim (int) – Channel dimension C shared by the feature map x and the conditioning signal c. All linear projections (q_proj, kv_proj, out_proj) are sized using this value. The conditioning tensor c must have its last dimension equal to hidden_dim; this is not checked at init time, and a mismatch surfaces only during the forward pass, when kv_proj (a torch.nn.Linear) rejects the input with a shape error.

  • mixer_cfg (LazyConfig) – LazyConfig for the inner mixing operator. The instantiated module’s forward must accept (q, k, v) as positional arguments and return a tensor of the same shape as q. Any attention-compatible module (e.g. a dot-product attention layer) can be used here.

  • init_method_in (Callable[[int], Callable[[Tensor], None]] | None) – Optional curried weight initialiser applied to both q_proj.weight and kv_proj.weight. Must have the signature fn(dim: int) -> fn(tensor: Tensor) -> None. fn(hidden_dim) is called and the returned callable is applied in-place to each weight matrix. Pass None to keep PyTorch’s default Kaiming-uniform init.

  • init_method_out (Callable[[int], Callable[[Tensor], None]] | None) – Same curried signature as init_method_in, applied to out_proj.weight. A common choice is a depth-scaled Gaussian (GPT / Megatron style) to control the residual branch variance at initialisation. Pass None to keep the default.

forward(x, condition)

Inject the conditioning signal into the feature map via cross-attention.

Computes queries from the (pre-normalised) feature map x and keys/values from the conditioning signal condition, then mixes them with the inner mixer and projects the result back to hidden_dim.

Note

In the standard ResidualBlock usage, x has already been passed through condition_mixer_norm before this method is called.

The signal flow is:

Q = q_proj(x)                     # (B, *spatial_dims, C)
K, V = split(kv_proj(condition))  # each (B, *spatial_dims_cond, C)
y = out_proj(mixer(Q, K, V))      # (B, *spatial_dims, C)

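The inner mixer must return a tensor of the same shape as Q, i.e. (B, *spatial_dims, C). If its channel dimension differs from hidden_dim, out_proj will raise a shape error. A mismatch confined to the spatial dimensions is not caught by out_proj, because torch.nn.Linear only constrains the last dimension.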

A global (non-spatial) conditioning vector of shape (B, C) is automatically unsqueezed to (B, 1, C) before the K/V projection so that the inner mixer sees a single conditioning token per sample.
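The signal flow and the global-vector handling can be sketched as a standalone function. This is an illustration under stated assumptions, not the module's actual source: condition_mixer_sketch and dot_product_mixer are hypothetical names, the K/V split is assumed to be an even chunk along the last axis with K first, and the example uses a 1-D sequence so no spatial flattening is needed.

```python
import torch
from torch import nn


def dot_product_mixer(q, k, v):
    # Stand-in inner mixer (assumption): single-head scaled dot-product attention.
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return attn @ v


def condition_mixer_sketch(x, condition, q_proj, kv_proj, out_proj, mixer):
    if condition.ndim == 2:
        # Global vector: (B, C) -> (B, 1, C), one conditioning token per sample.
        condition = condition.unsqueeze(1)
    q = q_proj(x)                                # (B, T, C)
    k, v = kv_proj(condition).chunk(2, dim=-1)   # each (B, T_cond, C)
    return out_proj(mixer(q, k, v))              # (B, T, C)


hidden_dim = 16
q_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
kv_proj = nn.Linear(hidden_dim, 2 * hidden_dim, bias=False)
out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

x = torch.randn(2, 10, hidden_dim)  # (B, T, C) feature sequence
c = torch.randn(2, hidden_dim)      # (B, C) global conditioning vector
y = condition_mixer_sketch(x, c, q_proj, kv_proj, out_proj, dot_product_mixer)
print(y.shape)  # torch.Size([2, 10, 16])
```

With a single conditioning token and this softmax-based stand-in, the attention weights are all 1, so every position receives the same projected value vector; spatially varying output requires more than one conditioning token or a different inner mixer.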

Parameters:
  • x (Tensor) – Feature map tensor of shape (B, *spatial_dims, C), where B is the batch size, spatial_dims is one or more spatial axes (e.g. (H, W) for 2-D images or (T,) for 1-D sequences), and C = hidden_dim. Must have at least three dimensions (batch + one spatial axis + channel).

  • condition (Tensor) –

    Conditioning signal tensor. Two shapes are accepted:

    • (B, C) — global conditioning vector (e.g. a timestep or class embedding). Unsqueezed internally to (B, 1, C) before projection.

    • (B, *spatial_dims_cond, C) — spatially distributed conditioning tokens (e.g. encoder output). Must have the same number of dimensions as x. The spatial extent spatial_dims_cond need not match spatial_dims of x.

    The channel dimension C must equal hidden_dim (self.kv_proj.in_features). forward() does not check this explicitly; a mismatch is reported only by kv_proj as a shape error.

Returns:

Output tensor of shape (B, *spatial_dims, C) — same spatial layout as x, with the conditioning signal blended in via cross-attention.

Raises:
  • ValueError – If x has fewer than three dimensions (i.e. it has no spatial axis).

  • ValueError – If condition.ndim is neither 2 (global vector) nor equal to x.ndim (matching spatial rank).

Return type:

Tensor
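The two documented ValueError conditions amount to a pair of rank checks. The helper below is a hypothetical reconstruction for use in caller-side tests; the function name and the error messages are assumptions, and only the conditions themselves come from the description above.

```python
import torch


def check_ranks(x, condition):
    """Hypothetical pre-flight check mirroring the documented ValueError cases."""
    if x.ndim < 3:
        # x needs batch + at least one spatial axis + channel.
        raise ValueError(f"x must have at least 3 dimensions, got {x.ndim}")
    if condition.ndim not in (2, x.ndim):
        # condition is either a global (B, C) vector or matches the rank of x.
        raise ValueError(
            f"condition.ndim must be 2 or {x.ndim}, got {condition.ndim}"
        )


x = torch.randn(2, 8, 8, 16)

check_ranks(x, torch.randn(2, 16))        # global vector: accepted
check_ranks(x, torch.randn(2, 4, 4, 16))  # same rank as x: accepted

try:
    check_ranks(x, torch.randn(2, 5, 16))  # rank 3 against rank-4 x: rejected
except ValueError as err:
    print(f"rejected: {err}")
```

Note that these checks concern rank only; the channel dimension is not covered by either documented error.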