QKVSequenceMixer
- class QKVSequenceMixer(hidden_dim, mixer_cfg, qkv_bias=False, out_proj_bias=False, init_method_in=None, init_method_out=None)
Bases: Module

Operator-agnostic sequence mixer with shared QKV and output projections.

QKVSequenceMixer mirrors the structure of Attention (and ViT5Attention) so that any inner mixer (Hyena, attention, CKConv, Mamba) can be dropped in without changing the surrounding residual block.

The forward pass is:

    x ─[Linear(C → 3C, + bias?)]──► split ──► Q, K, V
                                                │
                            inner_mixer(Q, K, V, cp_group, **kwargs)
                                                │
                                [Linear(C → C, + bias?)]──► y

The QKV projection packs all three projections into a single Linear(C, 3·C) call for efficiency; the output projection maps back to C. Both projections optionally include a bias term (disabled by default; see qkv_bias and out_proj_bias).

- Parameters:
- qkv_proj
  Combined Q+K+V input projection; maps C → 3·C (weight shape (3C, C)).
  - Type:
- out_proj
  Output projection; maps C → C (weight shape (C, C)).
  - Type:
Example:

    import torch

    from nvsubquadratic.lazy_config import LazyConfig
    from nvsubquadratic.modules.hyena_nd import Hyena

    mixer_cfg = LazyConfig(Hyena)(
        global_conv_cfg=...,
        short_conv_cfg=...,
        gate_nonlinear_cfg=...,
        pixelhyena_norm_cfg=...,
        qk_norm_cfg=None,
    )
    block = QKVSequenceMixer(hidden_dim=256, mixer_cfg=mixer_cfg)
    x = torch.randn(2, 32, 32, 256)  # [B, H, W, C]
    y = block(x)                     # [B, H, W, C]
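The data flow described for this class (one fused projection, a three-way split, the inner mixer, then an output projection) can be sketched in plain PyTorch. This is an illustrative stand-in, not the library's implementation: `TinyQKVMixer` and `IdentityMixer` are hypothetical names, and splitting with `chunk` in Q, K, V order is an assumption about how the fused output is divided.

```python
import torch
from torch import nn


class IdentityMixer(nn.Module):
    """Hypothetical inner mixer used only for illustration; returns V unchanged."""

    def forward(self, q, k, v, cp_group=None, **kwargs):
        return v


class TinyQKVMixer(nn.Module):
    """Illustrative sketch of the documented data flow (not the library class)."""

    def __init__(self, hidden_dim, mixer, qkv_bias=False, out_proj_bias=False):
        super().__init__()
        # Fused projection C -> 3C, weight shape (3C, C).
        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim, bias=qkv_bias)
        # Output projection C -> C, weight shape (C, C).
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=out_proj_bias)
        self.mixer = mixer

    def forward(self, x, cp_group=None, **mixer_kwargs):
        # A single matmul yields Q, K and V; split the channel axis into three
        # C-sized parts (assumed order: Q, K, V).
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        # cp_group is passed as the fourth positional argument.
        y = self.mixer(q, k, v, cp_group, **mixer_kwargs)
        return self.out_proj(y)


block = TinyQKVMixer(hidden_dim=16, mixer=IdentityMixer())
x = torch.randn(2, 8, 8, 16)  # [B, H, W, C]
y = block(x)                  # [B, H, W, C]
```

Because both projections act only on the last axis, the same sketch works unchanged for 1D, 2D, or 3D spatial layouts.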
- __init__(hidden_dim, mixer_cfg, qkv_bias=False, out_proj_bias=False, init_method_in=None, init_method_out=None)
Initialise the QKV sequence mixer.
- Parameters:
  - hidden_dim (int) – Channel dimension C of the input / output tensor. Both qkv_proj and out_proj are sized using this value.
  - mixer_cfg (LazyConfig) – LazyConfig for the inner sequence-mixing operator. The target class’s forward method must accept (q, k, v, cp_group, **kwargs), where cp_group is the fourth positional argument. Supported targets include Hyena, Attention, CKConvND, and Mamba.
  - qkv_bias (bool) – If True, adds a learnable bias to the combined QKV projection. The bias is zero-initialised when init_method_in is provided. Defaults to False.
  - out_proj_bias (bool) – If True, adds a learnable bias to the output projection. The bias is zero-initialised when init_method_out is provided. Defaults to False.
  - init_method_in (Callable[[int], Callable[[Tensor], Tensor]] | None) – Optional curried weight initialiser for qkv_proj. Must have the signature fn(dim: int) -> fn(tensor: Tensor) -> None. When provided, fn(hidden_dim) is called and the returned callable is applied to qkv_proj.weight.data. If qkv_bias is also True, the bias is zero-initialised. Pass None to use PyTorch’s default (Kaiming uniform).
  - init_method_out (Callable[[int], Callable[[Tensor], Tensor]] | None) – Same as init_method_in but applied to out_proj.weight.data. Typically a scaled initialiser that controls residual-branch variance (GPT/Megatron style), e.g.:

        import math

        import torch

        # num_layers must be defined by the caller.
        init_method_out = (
            lambda dim: lambda w: torch.nn.init.normal_(
                w, std=1 / math.sqrt(num_layers)
            )
        )
- Raises:
  Exception – Propagated from instantiate() if the target class cannot be constructed (e.g. missing required arguments or an invalid mixer_cfg). The exact exception type depends on the LazyConfig backend (typically an omegaconf.errors.InstantiationException or similar). Check mixer_cfg._target_ and its keyword arguments if this is raised.
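The curried initialiser contract (call `fn(hidden_dim)`, apply the result to the weight, zero the bias if present) can be illustrated on a bare `nn.Linear`. This is a sketch under assumptions: `scaled_normal` is a hypothetical helper, `num_layers=12` is an arbitrary value, and the last two lines imitate the documented behaviour rather than reproduce the class's internal code.

```python
import math

import torch
from torch import nn


def scaled_normal(num_layers):
    """Hypothetical helper returning a curried initialiser: fn(dim) -> fn(tensor)."""

    def for_dim(dim):
        # Same scaling as in the init_method_out description; dim is unused here.
        std = 1.0 / math.sqrt(num_layers)

        def init_(tensor):
            return nn.init.normal_(tensor, mean=0.0, std=std)

        return init_

    return for_dim


hidden_dim = 32
out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

init_method_out = scaled_normal(num_layers=12)
# Imitates the documented steps: build the initialiser for this width,
# apply it to the weight, then zero the bias.
init_method_out(hidden_dim)(out_proj.weight.data)
nn.init.zeros_(out_proj.bias.data)
```

A named factory like this is equivalent to the nested lambda form and tends to be easier to read in tracebacks.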
- flop_count(spatial_dims, inference=False)
Count FLOPs for QKV projections + inner mixer + output projection.
Uses the convention in which one multiply-accumulate counts as two FLOPs (one multiply plus one add), so the matrix-vector product y = Wx over T tokens costs 2 · T · in_dim · out_dim FLOPs. Bias additions are excluded, following the standard ML FLOP-counting convention.

FLOPs breakdown (D = hidden_dim, T = prod(spatial_dims)):

  - QKV projection Linear(D, 3D): 2 · T · D · 3D = 6 · T · D²
  - Inner mixer (e.g. Hyena, attention): Delegated to self.mixer.flop_count(spatial_dims, inference). For Hyena this is dominated by the FFT convolution, O(T log T · D); for attention it is O(T² · D).
  - Output projection Linear(D, D): 2 · T · D²

Total (excluding inner mixer): 8 · T · D².

- Parameters:
  - spatial_dims (tuple[int, ...]) – Spatial extents of the input signal, e.g. (H, W) for images or (T,) for 1D sequences. Linear projections treat the flattened token count T = prod(spatial_dims) as the sequence length.
  - inference (bool) – Forwarded to self.mixer.flop_count. Some mixers (e.g. autoregressive Mamba) have different inference-time costs.
- Returns:
Total FLOPs as a non-negative integer.
- Raises:
  AttributeError – If the inner mixer does not implement flop_count.
- Return type:
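The projection part of the FLOP breakdown can be checked with a few lines of arithmetic. The sketch below covers only the two linear projections; `projection_flops` is a hypothetical helper, not part of the library, and the inner mixer's contribution would be added on top.

```python
import math


def projection_flops(spatial_dims, hidden_dim):
    """FLOPs of the QKV and output projections only (inner mixer excluded)."""
    tokens = math.prod(spatial_dims)                  # T = prod(spatial_dims)
    qkv = 2 * tokens * hidden_dim * 3 * hidden_dim    # Linear(D, 3D): 6 * T * D^2
    out = 2 * tokens * hidden_dim * hidden_dim        # Linear(D, D):  2 * T * D^2
    return qkv + out                                  # 8 * T * D^2


flops = projection_flops((32, 32), 256)
print(flops)  # 536870912
```

For a 32 × 32 grid with D = 256, T = 1024 and the projections cost 8 · 1024 · 256² = 536,870,912 FLOPs. The formula as documented depends only on T and D, so it carries no batch factor.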
- forward(x, cp_group=None, **mixer_kwargs)
Run the QKV-project → mix → output-project forward pass.
- Parameters:
  - x (Tensor) – Input tensor of shape (B, *spatial, C), where B is the batch size, spatial is one or more spatial axes (e.g. (T,) for 1D, (H, W) for 2D, (D, H, W) for 3D), and C = hidden_dim.
  - cp_group (ProcessGroup | None) – Optional context-parallel process group (torch.distributed.ProcessGroup). When provided, the input is assumed to be already split across ranks along the spatial axis, and the inner mixer is responsible for the cross-rank communication (e.g. AllToAll for Hyena, ring attention for Attention). Pass None (default) for single-GPU / non-distributed runs.
  - **mixer_kwargs – Additional keyword arguments forwarded verbatim to self.mixer.forward. Mixers that do not recognise a key must accept and ignore it via their own **kwargs. Common keys:
    - conditioning (torch.Tensor, shape (B, cond_dim)): FiLM conditioning vector consumed by Hyena when a condition_mixer is attached. Ignored by mixers that do not have a condition_mixer (it passes through **mixer_kwargs and is discarded).
- Returns:
  Output tensor of shape (B, *spatial, C), with the same layout as the input.
- Return type:
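The keyword-forwarding contract for inner mixers can be shown with two toy classes. Everything here is hypothetical (`ConditionedMixer`, `PlainMixer`, and `run` are illustration-only names), and the return values are placeholders standing in for tensors; the point is only that `cp_group` arrives as the fourth positional argument and that unrecognised keys are absorbed by `**kwargs`.

```python
class ConditionedMixer:
    """Hypothetical mixer that consumes the `conditioning` key."""

    def forward(self, q, k, v, cp_group=None, conditioning=None, **kwargs):
        return ("conditioned", conditioning is not None)


class PlainMixer:
    """Hypothetical mixer with no conditioning path; extra keys are ignored."""

    def forward(self, q, k, v, cp_group=None, **kwargs):
        # `conditioning` lands in kwargs here and is simply not used.
        return ("plain", sorted(kwargs))


def run(mixer, q, k, v, cp_group=None, **mixer_kwargs):
    # cp_group is passed positionally (fourth argument); the rest verbatim.
    return mixer.forward(q, k, v, cp_group, **mixer_kwargs)


a = run(ConditionedMixer(), 1, 2, 3, conditioning=[0.5])
b = run(PlainMixer(), 1, 2, 3, conditioning=[0.5])
```

A mixer whose `forward` lacks `**kwargs` would raise a `TypeError` on the second call, which is why the documentation requires every mixer to accept and ignore unknown keys.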