QKVSequenceMixer

class QKVSequenceMixer(
hidden_dim,
mixer_cfg,
qkv_bias=False,
out_proj_bias=False,
init_method_in=None,
init_method_out=None,
)

Bases: Module

Operator-agnostic sequence mixer with shared QKV and output projections.

QKVSequenceMixer mirrors the projection structure of Attention (and ViT5Attention), so any inner mixer (Hyena, attention, CKConv, Mamba) can be dropped in without changing the surrounding residual block.

The forward pass is:

x  ─[Linear(C → 3C, + bias?)]──► split ──► Q, K, V
                                                 │
                         inner_mixer(Q, K, V, cp_group, **kwargs)
                                                 │
                         [Linear(C → C, + bias?)]──► y

The QKV projection packs all three projections into a single Linear(C, 3·C) call for efficiency; the output projection maps back to C. Both projections optionally include a bias term (disabled by default; see qkv_bias and out_proj_bias).
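To illustrate why packing is only an efficiency choice, the sketch below (plain PyTorch; the sizes are arbitrary, and splitting the packed output with chunk in Q, K, V order is an assumption that may not match the library's internal split) checks that one Linear(C, 3·C) produces the same Q, K, V as three Linear(C, C) projections built from its weight slices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 16
x = torch.randn(2, 8, C)  # [B, T, C]

# Packed projection: a single GEMM produces Q, K and V together.
qkv_proj = nn.Linear(C, 3 * C, bias=False)
q, k, v = qkv_proj(x).chunk(3, dim=-1)  # assumed split order: Q, K, V

# Unpacked equivalent: each (C, C) block of rows acts as its own projection.
w_q, w_k, w_v = qkv_proj.weight.chunk(3, dim=0)
q_ref = x @ w_q.T
k_ref = x @ w_k.T
v_ref = x @ w_v.T

print(qkv_proj.weight.shape)  # torch.Size([48, 16]), i.e. (3C, C)
```

Because nn.Linear computes x @ W.T, output features 0..C-1 depend only on the first C rows of the weight, which is what makes the packed and unpacked forms interchangeable.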

Attributes:
mixer

The instantiated inner sequence-mixing operator (e.g. Hyena).

Type:

torch.nn.Module

qkv_proj

Combined Q+K+V input projection; maps C → 3·C (weight shape (3C, C)).

Type:

torch.nn.Linear

out_proj

Output projection; maps C → C (weight shape (C, C)).

Type:

torch.nn.Linear

Example:

import torch

from nvsubquadratic.lazy_config import LazyConfig
from nvsubquadratic.modules.hyena_nd import Hyena

mixer_cfg = LazyConfig(Hyena)(
    global_conv_cfg=...,
    short_conv_cfg=...,
    gate_nonlinear_cfg=...,
    pixelhyena_norm_cfg=...,
    qk_norm_cfg=None,
)
block = QKVSequenceMixer(hidden_dim=256, mixer_cfg=mixer_cfg)

x = torch.randn(2, 32, 32, 256)   # [B, H, W, C]
y = block(x)                       # [B, H, W, C]
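Since the Hyena configuration in the example elides its arguments, here is a self-contained sketch of the drop-in contract instead. ToyAttentionMixer and ToyQKVBlock are hypothetical stand-ins, not library classes: the toy mixer accepts (q, k, v, cp_group, **kwargs), flattens the spatial axes and runs global softmax attention, and the toy block reproduces the project → mix → project pipeline. F.scaled_dot_product_attention requires PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyAttentionMixer(nn.Module):
    """Hypothetical inner mixer: global softmax attention over all spatial positions."""

    def forward(self, q, k, v, cp_group=None, **kwargs):
        # cp_group and extra kwargs are accepted but unused in this single-GPU toy.
        b, *spatial, c = q.shape
        # Flatten (B, *spatial, C) -> (B, T, C) so attention sees one token axis.
        q, k, v = (t.reshape(b, -1, c) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v)
        return y.reshape(b, *spatial, c)


class ToyQKVBlock(nn.Module):
    """Hypothetical stand-in for the wrapper: QKV projection -> mixer -> output projection."""

    def __init__(self, hidden_dim, mixer):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim, bias=False)
        self.mixer = mixer
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def forward(self, x, cp_group=None, **mixer_kwargs):
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        y = self.mixer(q, k, v, cp_group, **mixer_kwargs)
        return self.out_proj(y)


block = ToyQKVBlock(hidden_dim=64, mixer=ToyAttentionMixer())
x = torch.randn(2, 8, 8, 64)  # [B, H, W, C]
y = block(x)                  # [B, H, W, C]
```

Swapping ToyAttentionMixer for any other module with the same forward signature leaves ToyQKVBlock unchanged, which is the property the wrapper is designed around.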

__init__(
hidden_dim,
mixer_cfg,
qkv_bias=False,
out_proj_bias=False,
init_method_in=None,
init_method_out=None,
)

Initialise the QKV sequence mixer.

Parameters:
  • hidden_dim (int) – Channel dimension C of the input / output tensor. Both qkv_proj and out_proj are sized using this value.

  • mixer_cfg (LazyConfig) – LazyConfig for the inner sequence-mixing operator. The target class’s forward method must accept (q, k, v, cp_group, **kwargs) where cp_group is the fourth positional argument. Supported targets include Hyena, Attention, CKConvND, and Mamba.

  • qkv_bias (bool) – If True, adds a learnable bias to the combined QKV projection. The bias is zero-initialised when init_method_in is provided. Defaults to False.

  • out_proj_bias (bool) – If True, adds a learnable bias to the output projection. Zero-initialised when init_method_out is provided. Defaults to False.

  • init_method_in (Callable[[int], Callable[[Tensor], Tensor]] | None) – Optional curried weight initialiser for qkv_proj, of the form fn(dim: int) -> fn(tensor: Tensor) -> Tensor, where the inner function initialises the tensor in place (as the torch.nn.init functions do). When provided, fn(hidden_dim) is called and the returned callable is applied to qkv_proj.weight.data. If qkv_bias is also True, the bias is zero-initialised. Pass None to keep PyTorch's default nn.Linear initialisation (Kaiming uniform).

  • init_method_out (Callable[[int], Callable[[Tensor], Tensor]] | None) –

    Same as init_method_in but applied to out_proj.weight.data. Typically a scaled initialiser that controls residual-branch variance (GPT/Megatron style), e.g.:

    import math
    import torch

    num_layers = 24  # example depth; substitute your model's layer count
    init_method_out = (
        lambda dim: lambda w: torch.nn.init.normal_(
            w, std=1 / math.sqrt(num_layers)
        )
    )
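The curried contract can be sketched as follows. apply_init is a hypothetical helper written to mirror the behaviour described above (call fn(dim), apply the result to weight.data, zero the bias if present); it is not the module's actual code, and the 1/sqrt(dim) standard deviation is just an example choice.

```python
import torch
import torch.nn as nn


def apply_init(linear, init_method, dim):
    """Hypothetical helper mirroring the documented init behaviour for one projection."""
    if init_method is None:
        return  # keep PyTorch's default nn.Linear initialisation
    init_method(dim)(linear.weight.data)  # fn(dim) -> fn(tensor), applied in place
    if linear.bias is not None:
        linear.bias.data.zero_()


hidden_dim = 64
init_method_in = lambda dim: lambda w: nn.init.normal_(w, std=dim ** -0.5)

torch.manual_seed(0)
qkv_proj = nn.Linear(hidden_dim, 3 * hidden_dim, bias=True)
apply_init(qkv_proj, init_method_in, hidden_dim)
```

Currying lets the caller fix the initialiser family once in a config while the module supplies the dimension it actually uses.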
    

Raises:

Exception – Propagated from instantiate() if the target class cannot be constructed (e.g. missing required arguments or an invalid mixer_cfg). The exact exception type depends on the LazyConfig backend: a Hydra-style instantiate, for instance, wraps errors in hydra.errors.InstantiationException, while a direct constructor call surfaces the original error (such as a TypeError). If this is raised, check mixer_cfg._target_ and its keyword arguments.

flop_count(spatial_dims, inference=False)

Count FLOPs for QKV projections + inner mixer + output projection.

Uses the standard convention in which one multiply-accumulate (MAC) counts as two FLOPs, one multiply plus one add; the matrix-vector product y = Wx over T tokens therefore costs 2 · T · in_dim · out_dim FLOPs. Bias additions are excluded, following the usual ML FLOP-counting convention.
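A small pure-Python check of this convention: the hypothetical naive_linear below evaluates y = Wx token by token, counting every multiply and every add. Because the accumulator starts at zero, each weight contributes exactly one multiply and one add, which gives the 2 · T · in_dim · out_dim total (a tree reduction would use in_dim − 1 adds per output; the convention ignores that difference).

```python
def naive_linear(w, x_tokens):
    """y = W x for each token, counting every multiply and add (no bias)."""
    mults = adds = 0
    out = []
    for x in x_tokens:                      # T tokens
        y = []
        for row in w:                       # out_dim output features
            acc = 0.0
            for w_ji, x_i in zip(row, x):   # in_dim multiply-accumulates
                acc += w_ji * x_i
                mults += 1
                adds += 1
            y.append(acc)
        out.append(y)
    return out, mults + adds


T, in_dim, out_dim = 3, 4, 5
w = [[1.0] * in_dim for _ in range(out_dim)]
x_tokens = [[1.0] * in_dim for _ in range(T)]
_, flops = naive_linear(w, x_tokens)
print(flops, 2 * T * in_dim * out_dim)  # 120 120
```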

FLOPs breakdown (D = hidden_dim, T = prod(spatial_dims)):

  1. QKV projection Linear(D, 3D): 2 · T · D · 3D = 6 · T · D²

  2. Inner mixer (e.g. Hyena, attention): Delegated to self.mixer.flop_count(spatial_dims, inference). For Hyena this is dominated by the FFT convolution O(T log T · D); for attention it is O(T² · D).

  3. Output projection Linear(D, D): 2 · T · D · D = 2 · T · D²

Total (excluding inner mixer): 8 · T · D².
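As a worked instance of the projection terms, take the shapes from the class example (spatial_dims = (32, 32), hidden_dim = 256, so T = 1024). projection_flops is a hypothetical helper that covers only the two linear layers, not the inner mixer.

```python
import math


def projection_flops(spatial_dims, hidden_dim):
    """FLOPs of qkv_proj + out_proj only (inner mixer excluded, biases ignored)."""
    T = math.prod(spatial_dims)
    D = hidden_dim
    qkv = 2 * T * D * (3 * D)  # Linear(D, 3D): 6·T·D²
    out = 2 * T * D * D        # Linear(D, D):  2·T·D²
    return qkv + out           # 8·T·D²


print(projection_flops((32, 32), 256))  # 536870912
```

That is 8 · 1024 · 256² ≈ 0.54 GFLOPs per sample for the projections alone.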

Parameters:
  • spatial_dims (tuple[int, ...]) – Spatial extents of the input signal, e.g. (H, W) for images or (T,) for 1D sequences. Linear projections treat the flattened token count T = prod(spatial_dims) as the sequence length.

  • inference (bool) – Forwarded to self.mixer.flop_count. Some mixers (e.g. autoregressive Mamba) have different inference-time costs.

Returns:

Total FLOPs as a non-negative integer.

Raises:

AttributeError – If the inner mixer does not implement flop_count.

Return type:

int
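The delegation pattern can be sketched with duck-typed stand-ins. ToyMixer, NoFlopMixer and total_flops are hypothetical, and ToyMixer's 10 · T cost is a placeholder rather than any real operator's formula; the point is that the projection count is added to whatever the mixer reports, and that a mixer without flop_count fails at attribute lookup.

```python
import math


class ToyMixer:
    """Hypothetical inner mixer exposing flop_count(spatial_dims, inference)."""

    def flop_count(self, spatial_dims, inference=False):
        return 10 * math.prod(spatial_dims)  # placeholder cost


class NoFlopMixer:
    """Hypothetical mixer that does not implement flop_count."""


def total_flops(mixer, spatial_dims, hidden_dim, inference=False):
    T = math.prod(spatial_dims)
    projections = 8 * T * hidden_dim ** 2
    # Raises AttributeError if the mixer has no flop_count method.
    return projections + mixer.flop_count(spatial_dims, inference)


print(total_flops(ToyMixer(), (4, 4), 8))  # 8352

try:
    total_flops(NoFlopMixer(), (4, 4), 8)
except AttributeError as err:
    print("AttributeError:", err)
```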

forward(x, cp_group=None, **mixer_kwargs)

Run the QKV-project → mix → output-project forward pass.

Parameters:
  • x (Tensor) – Input tensor of shape (B, *spatial, C) where B is batch size, spatial is one or more spatial axes (e.g. (T,) for 1D, (H, W) for 2D, (D, H, W) for 3D), and C = hidden_dim.

  • cp_group (ProcessGroup | None) – Optional context-parallel process group (torch.distributed.ProcessGroup). When provided, the input is assumed to be already split across ranks along the spatial axis, and the inner mixer is responsible for the cross-rank communication (e.g. AllToAll for Hyena, ring-attention for Attention). Pass None (default) for single-GPU / non-distributed runs.

  • **mixer_kwargs

    Additional keyword arguments forwarded verbatim to self.mixer.forward. Mixers that do not recognise a key must accept and ignore it via their own **kwargs. Common keys:

    • conditioning (torch.Tensor, shape (B, cond_dim)): FiLM conditioning vector consumed by Hyena when a condition_mixer is attached. Ignored by mixers that do not have a condition_mixer (it passes through **mixer_kwargs and is discarded).
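To show how **mixer_kwargs reach the inner mixer, the sketch below uses two hypothetical toy mixers: one consumes conditioning as a FiLM-style per-channel scale (assuming cond_dim equals C, purely for simplicity), and one silently absorbs it through **kwargs. The hypothetical run helper applies the projections to the last axis only, so the same code handles 1D, 2D and 3D spatial layouts.

```python
import torch
import torch.nn as nn


class ConditionedToyMixer(nn.Module):
    """Hypothetical mixer that uses `conditioning` as a per-channel scale (assumes cond_dim == C)."""

    def forward(self, q, k, v, cp_group=None, conditioning=None, **kwargs):
        y = q * k.sigmoid() + v  # toy shape-preserving mixing
        if conditioning is not None:
            # (B, C) -> (B, 1, ..., 1, C) so it broadcasts over every spatial axis.
            shape = (y.shape[0],) + (1,) * (y.dim() - 2) + (y.shape[-1],)
            y = y * (1 + conditioning.reshape(shape))
        return y


class PlainToyMixer(nn.Module):
    """Hypothetical mixer that ignores unknown keys via **kwargs."""

    def forward(self, q, k, v, cp_group=None, **kwargs):
        return q + k + v


def run(mixer, x, **mixer_kwargs):
    C = x.shape[-1]
    qkv_proj, out_proj = nn.Linear(C, 3 * C), nn.Linear(C, C)
    q, k, v = qkv_proj(x).chunk(3, dim=-1)  # nn.Linear acts on the last axis only
    return out_proj(mixer(q, k, v, None, **mixer_kwargs))


x = torch.randn(2, 8, 8, 32)  # [B, H, W, C]
cond = torch.randn(2, 32)     # [B, cond_dim], here cond_dim == C
y_cond = run(ConditionedToyMixer(), x, conditioning=cond)
y_plain = run(PlainToyMixer(), x, conditioning=cond)  # key accepted and discarded
```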

Returns:

Output tensor of shape (B, *spatial, C) — same layout as the input.

Return type:

Tensor