Attention#
class Attention(hidden_dim, num_heads, apply_qk_norm, use_rope, is_causal=False, attn_dropout=0.0, rope_base=10000.0, rope_spatial_dims=None)
Bases: torch.nn.Module
Multi-head scaled dot-product self-attention for 1D/2D/3D spatial inputs.
Computes standard multi-head attention:
\[ \text{head}_i = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i, \qquad \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\ldots,\text{head}_H) \]
where \(d_k = C / H\) is the per-head dimension, \(H\) is the number of heads (not the image height H that appears in the 2D/3D layouts below), and \(C\) is the hidden (channel) dimension.
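The formula can be checked numerically by computing each head explicitly and comparing against a single fused call to torch.nn.functional.scaled_dot_product_attention. This is a minimal sketch, not the module's code: it assumes head i owns the contiguous channel slice [i*d_k, (i+1)*d_k) and works on already-flattened [B, L, C] inputs; the helper names are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def multihead_reference(q, k, v, num_heads):
    """Explicit MultiHead(Q, K, V) for q, k, v of shape [B, L, C]."""
    C = q.shape[-1]
    d_k = C // num_heads
    heads = []
    for i in range(num_heads):
        cols = slice(i * d_k, (i + 1) * d_k)  # head i: contiguous channel chunk (assumed)
        q_i, k_i, v_i = q[..., cols], k[..., cols], v[..., cols]  # [B, L, d_k]
        logits = q_i @ k_i.transpose(-2, -1) / math.sqrt(d_k)     # [B, L, L]
        heads.append(torch.softmax(logits, dim=-1) @ v_i)         # head_i
    return torch.cat(heads, dim=-1)                               # Concat -> [B, L, C]


def multihead_sdpa(q, k, v, num_heads):
    """Same computation via F.scaled_dot_product_attention (default scale 1/sqrt(d_k))."""
    B, L, C = q.shape
    d_k = C // num_heads

    def to_heads(x):  # [B, L, C] -> [B, H, L, d_k]
        return x.view(B, L, num_heads, d_k).transpose(1, 2)

    out = F.scaled_dot_product_attention(to_heads(q), to_heads(k), to_heads(v))
    return out.transpose(1, 2).reshape(B, L, C)                   # back to [B, L, C]


torch.manual_seed(0)
q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
print(torch.allclose(multihead_reference(q, k, v, 4), multihead_sdpa(q, k, v, 4), atol=1e-5))
```

SDPA's default scale is 1/sqrt of the query's last dimension, which after head splitting is exactly d_k.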
Spatial layout#
Inputs and outputs use channels-last layout:
- 1D sequences: [B, T, C]
- 2D images: [B, H, W, C]
- 3D volumes: [B, D, H, W, C]
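All three layouts can be handled uniformly by collapsing every axis between the batch and channel axes into one sequence axis. A minimal sketch using plain PyTorch reshapes; the helper names flatten_spatial and unflatten_spatial are hypothetical, and the row-major (raster) ordering of the flattened axis is what torch.flatten produces, assumed here to match the module.

```python
import math

import torch


def flatten_spatial(x):
    """[B, *spatial, C] -> ([B, L, C], spatial) with L = prod(spatial), row-major order."""
    spatial = tuple(x.shape[1:-1])
    return x.flatten(1, -2), spatial


def unflatten_spatial(x, spatial):
    """[B, L, C] -> [B, *spatial, C]; inverse of flatten_spatial."""
    return x.unflatten(1, spatial)


for shape in [(2, 128, 64), (2, 32, 32, 64), (2, 8, 16, 16, 64)]:  # 1D, 2D, 3D
    x = torch.randn(shape)
    seq, spatial = flatten_spatial(x)
    assert seq.shape == (shape[0], math.prod(spatial), shape[-1])
    assert torch.equal(unflatten_spatial(seq, spatial), x)
    print(tuple(x.shape), "->", tuple(seq.shape))
```

For a 2D input, position (h, w) ends up at sequence index h * W + w.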
Internally, the spatial dimensions are flattened into a single sequence axis of length L = prod(spatial_dims) for the SDPA kernel and unflattened again on output.
Multi-head splitting#
The channel axis C is split into H heads of size d_k = C / H. Internally the module works with the merged batch-head axis (B * H, L, d_k) before the SDPA call and re-merges the heads afterwards.
QK normalisation (cosine attention)#
When apply_qk_norm=True, queries and keys are L2-normalised per head along the last dimension before the attention logits are formed. The 1/sqrt(d_k) scaling is then replaced by a fixed scale of 1.0: the logits are already cosine similarities in [-1, 1], and dividing them by sqrt(d_k) as well would compress them further and flatten the softmax towards uniform.
Rotary Positional Embeddings (RoPE)#
RoPE is applied to Q and K before QK normalisation and before the SDPA call. The cos/sin tables are precomputed once in __init__ from rope_spatial_dims and stored as non-persistent registered buffers (persistent=False), so they are rebuilt from the __init__ arguments and never serialised to checkpoints. Each spatial axis gets an equal share of head_dim, and that share must be even because RoPE rotates channel pairs, which gives the following head-dim divisibility requirements:
- 1D: head_dim divisible by 2
- 2D: head_dim divisible by 4 (two half-dim RoPE tables, one per axis)
- 3D: head_dim divisible by 6 (three one-third-dim RoPE tables)
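The 2D case can be sketched as an axial RoPE: half of head_dim is rotated by the row index and the other half by the column index, with cos/sin tables registered as non-persistent buffers. This is an illustration under stated assumptions, not the module's implementation: the class and helper names are hypothetical, and the pairing convention (channel i paired with i + d/2) and the assignment of the first half to rows are assumptions.

```python
import torch
import torch.nn as nn


def rope_tables(n_pos, dim, base=10000.0):
    """cos/sin tables of shape [n_pos, dim // 2] for a RoPE over `dim` channels."""
    assert dim % 2 == 0, "RoPE rotates channel pairs, so dim must be even"
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float64) / dim)  # [dim/2]
    angles = torch.arange(n_pos, dtype=torch.float64)[:, None] * inv_freq     # [n_pos, dim/2]
    return angles.cos().float(), angles.sin().float()


def rotate(x, cos, sin):
    """Rotate channel pairs (i, i + d/2) of x by the tabulated angles (half-split pairing)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)


class AxialRoPE2D(nn.Module):
    """Hypothetical 2D axial RoPE: first half of head_dim encodes the row, second the column."""

    def __init__(self, head_dim, spatial_dims, base=10000.0):
        super().__init__()
        assert head_dim % 4 == 0, "2D RoPE needs head_dim divisible by 4"
        h, w = spatial_dims
        cos_y, sin_y = rope_tables(h, head_dim // 2, base)  # [h, head_dim/4]
        cos_x, sin_x = rope_tables(w, head_dim // 2, base)  # [w, head_dim/4]
        # persistent=False: rebuilt from __init__ args and excluded from state_dict().
        self.register_buffer("cos_y", cos_y[:, None, :], persistent=False)  # [h, 1, d/4]
        self.register_buffer("sin_y", sin_y[:, None, :], persistent=False)
        self.register_buffer("cos_x", cos_x[None, :, :], persistent=False)  # [1, w, d/4]
        self.register_buffer("sin_x", sin_x[None, :, :], persistent=False)

    def forward(self, x):
        """x: [N, h, w, head_dim], e.g. N = B*H after head splitting."""
        x_row, x_col = x.chunk(2, dim=-1)
        x_row = rotate(x_row, self.cos_y, self.sin_y)  # rotated by row position
        x_col = rotate(x_col, self.cos_x, self.sin_x)  # rotated by column position
        return torch.cat((x_row, x_col), dim=-1)


rope = AxialRoPE2D(head_dim=16, spatial_dims=(4, 6))
print(len(rope.state_dict()))  # 0: the cos/sin buffers are not serialised
```

Under the same scheme, 1D would apply rope_tables(L, head_dim) to the whole head (head_dim even), and 3D would split head_dim into three even thirds, which is where the 2/4/6 divisibility rule comes from.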
Context parallelism (CP)#
Not yet functional. Passing a cp_group with size() > 1 to forward immediately raises ValueError("Context parallelism must be revisited."). The zigzag all-gather/split code following the raise is dead code, retained as a sketch for a future ring-attention implementation. Pass cp_group=None (the default) for all current use cases.
Backend selection#
Attention is computed with torch.nn.functional.scaled_dot_product_attention, which automatically selects a kernel, for example FlashAttention (A100), cuDNN SDPA (H100), or a memory-efficient fallback, based on device capability and on the dtype and shapes of the inputs.
- hidden_dim#
Total channel dimension C.
- Type: int
- num_heads#
Number of attention heads H. In the current implementation all heads are computed on every rank (there is no head-parallel CP split). A # TODO(@farhad) comment in forward notes that local_num_heads is always equal to num_heads, which may need revisiting for tensor-parallel training.
- Type: int
- attn_dropout#
Dropout probability applied to attention weights during training. At inference (when the module is not in training mode) a probability of 0.0 is used regardless of this value.
- Type: float
- _rope_ndim#
Spatial rank for which RoPE was initialised (1, 2, or 3). Present only when use_rope=True; the attribute is not defined otherwise. Used in forward to dispatch to the correct RoPE apply function.
- Type: int
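The attn_dropout behaviour described above corresponds to gating the dropout_p argument of SDPA on the module's training flag. A minimal sketch; the wrapper class GatedDropoutSDPA is hypothetical and only illustrates the pattern, it is not the Attention module itself.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedDropoutSDPA(nn.Module):
    """Hypothetical wrapper: attention-weight dropout is active only in training mode."""

    def __init__(self, attn_dropout=0.0):
        super().__init__()
        self.attn_dropout = attn_dropout

    def forward(self, q, k, v):
        # self.training is set by .train() / .eval(); inference always passes p = 0.0.
        p = self.attn_dropout if self.training else 0.0
        return F.scaled_dot_product_attention(q, k, v, dropout_p=p)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 8, 16) for _ in range(3))  # [B, H, L, d_k]
m = GatedDropoutSDPA(attn_dropout=0.5)
train_out = m.train()(q, k, v)  # attention weights randomly dropped and rescaled
eval_out = m.eval()(q, k, v)    # same as dropout-free attention
```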
- param hidden_dim:
Total hidden-state dimension C. Must be divisible by num_heads.
- type hidden_dim:
int
- param num_heads:
Number of parallel attention heads H.
- type num_heads:
int
- param apply_qk_norm:
If True, L2-normalise Q and K per head along the last dimension (cosine attention).
- type apply_qk_norm:
bool
- param use_rope:
If True, apply Rotary Positional Embeddings to Q and K before the attention logits are formed.
- type use_rope:
bool
- param is_causal:
If True, apply a causal attention mask so each position attends only to itself and earlier positions (for 2D/3D inputs, "earlier" refers to the order of the flattened sequence). Defaults to False.
- type is_causal:
bool
- param attn_dropout:
Dropout rate on attention weights (active only during training). Defaults to 0.0.
- type attn_dropout:
float
- param rope_base:
Base frequency for RoPE; controls how quickly the rotation frequency decays across head-dim pairs. Defaults to 10000.0.
- type rope_base:
float
- param rope_spatial_dims:
Spatial grid shape used to precompute the RoPE tables. Required when use_rope=True. Examples: (4096,) for 1D, (64, 64) for 2D, (8, 64, 64) for 3D. Must match the spatial shape seen during forward.
- type rope_spatial_dims:
tuple[int, ...] | None
Example:

```python
import torch

from nvsubquadratic.modules.attention import Attention

# 2D image attention with 8 heads, RoPE, and cosine-attention QK norm
attn = Attention(
    hidden_dim=256,
    num_heads=8,
    apply_qk_norm=True,
    use_rope=True,
    rope_spatial_dims=(32, 32),
)
q = k = v = torch.randn(2, 32, 32, 256)  # [B, H, W, C]
out = attn(q, k, v)                      # [B, H, W, C]
assert out.shape == q.shape
```
__init__(hidden_dim, num_heads, apply_qk_norm, use_rope, is_causal=False, attn_dropout=0.0, rope_base=10000.0, rope_spatial_dims=None)
Initialise the Attention module and precompute RoPE buffers.
- Parameters:
hidden_dim (int) – Total channel dimension C. Must be divisible by num_heads.
num_heads (int) – Number of attention heads H.
apply_qk_norm (bool) – Whether to L2-normalise Q and K per head.
use_rope (bool) – Whether to apply Rotary Positional Embeddings.
is_causal (bool) – Whether to use a causal attention mask. Defaults to False.
attn_dropout (float) – Attention-weight dropout probability. Defaults to 0.0.
rope_base (float) – RoPE base frequency. Defaults to 10000.0.
rope_spatial_dims (tuple[int, ...] | None) – Spatial grid shape for RoPE table precomputation. Required when use_rope=True. Not stored as an instance attribute; callers that need the spatial dims after construction (e.g. for serialisation or extra_repr) must track them themselves. The corresponding cos/sin tables are stored as non-persistent registered buffers (rope_cos, rope_sin, etc.).
- Raises:
AssertionError – If hidden_dim % num_heads != 0.
AssertionError – If use_rope=True and rope_spatial_dims is None.
AssertionError – If the RoPE head-dim divisibility requirement is not met (head_dim divisible by 2 for 1D, 4 for 2D, 6 for 3D).
ValueError – If rope_spatial_dims has a length other than 1, 2, or 3.
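The documented construction checks can be restated as a small standalone validator. This is a hypothetical helper, not part of the module; in particular, the order in which the real __init__ performs these checks is an assumption.

```python
def check_attention_config(hidden_dim, num_heads, use_rope, rope_spatial_dims=None):
    """Hypothetical restatement of the documented __init__ checks; returns head_dim."""
    assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
    head_dim = hidden_dim // num_heads
    if use_rope:
        assert rope_spatial_dims is not None, "rope_spatial_dims is required when use_rope=True"
        ndim = len(rope_spatial_dims)
        if ndim not in (1, 2, 3):
            raise ValueError(f"rope_spatial_dims must have length 1, 2, or 3, got {ndim}")
        # Each of the ndim axes gets head_dim / ndim channels, and that share must be
        # even (RoPE rotates channel pairs): divisor 2, 4 or 6 for 1D, 2D, 3D.
        divisor = 2 * ndim
        assert head_dim % divisor == 0, f"head_dim={head_dim} is not divisible by {divisor}"
    return head_dim


print(check_attention_config(256, 8, use_rope=True, rope_spatial_dims=(32, 32)))  # 32
```

For example, hidden_dim=64 with num_heads=8 gives head_dim=8, which is valid for 1D and 2D RoPE but fails the 3D requirement.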
- extra_repr()#
Return a concise string summary of this module’s configuration.
- Returns:
Comma-separated key=value pairs for num_heads, apply_qk_norm, is_causal, attn_dropout, use_rope, and rope_base.
- Return type:
str
- forward(query, key, value, cp_group=None)#
Apply multi-head scaled dot-product attention.
Computes:
\[\text{out} = \text{Concat}_{i=1}^{H}\left[\text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i\right]\]
where \(H\) = num_heads and \(d_k\) = head_dim. When apply_qk_norm=True, Q and K are L2-normalised before the logits are formed and the scale is 1.0 instead of 1/sqrt(d_k).
The forward pipeline is:
1. (CP guard) Raise ValueError if cp_group.size() > 1; pass cp_group=None for all current use cases.
2. Split the channel dim into heads: [B, *spatial, C] → [B*H, *spatial, d_k].
3. (Optional) Apply RoPE to Q and K.
4. (Optional) L2-normalise Q and K per head.
5. Flatten the spatial dims: [B*H, *spatial, d_k] → [B*H, L, d_k].
6. Reshape to the SDPA layout: [B*H, L, d_k] → [B, H, L, d_k].
7. Call F.scaled_dot_product_attention (FlashAttention / cuDNN / fallback).
8. Merge heads: [B, H, L, d_k] → [B, L, C].
9. Unflatten the spatial dims: [B, L, C] → [B, *spatial, C].
10. (Optional CP) Zigzag-split the output back to the local spatial slice. This step is currently unreachable because of the CP guard.
- Parameters:
query (torch.Tensor) – Query tensor of shape [B, *spatial_dims, C], where spatial_dims may be (T,), (H, W), or (D, H, W).
key (torch.Tensor) – Key tensor of shape [B, *spatial_dims, C]. Must match the query shape.
value (torch.Tensor) – Value tensor of shape [B, *spatial_dims, C]. Must match the query shape.
cp_group (torch.distributed.ProcessGroup | None) – Context-parallel process group, provided for future compatibility. The intended behaviour when cp_group is not None and cp_group.size() > 1 is to gather the full spatial sequence before attention and split it back afterwards; because ring attention is not yet implemented, such a group currently raises ValueError. Defaults to None.
- Returns:
Output of shape [B, *spatial_dims, C], in the same layout as the inputs.
- Return type:
torch.Tensor
- Raises:
ValueError – If cp_group is provided and cp_group.size() > 1 (context parallelism is not yet supported).