Attention

class Attention(
hidden_dim,
num_heads,
apply_qk_norm,
use_rope,
is_causal=False,
attn_dropout=0.0,
rope_base=10000.0,
rope_spatial_dims=None,
)

Bases: Module

Multi-head scaled dot-product self-attention for 1D/2D/3D spatial inputs.

Computes standard multi-head attention:

\[ \begin{aligned} \text{head}_i &= \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i \\ \text{MultiHead}(Q,K,V) &= \text{Concat}(\text{head}_1,\ldots,\text{head}_H) \end{aligned} \]

where \(d_k = C / H\) is the per-head dimension, \(H\) is the number of heads, and \(C\) is the hidden (channel) dimension.
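The formula can be checked numerically. The sketch below transcribes it directly for a single [B, L, C] input with no input or output projections, matching the formula above. The helper names split_heads and reference_mha are hypothetical, and the sketch assumes head i owns the contiguous channel slice [i*d_k, (i+1)*d_k). It compares the explicit computation with F.scaled_dot_product_attention, whose default scale is 1/sqrt(d_k).

```python
import math

import torch
import torch.nn.functional as F


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # [B, L, C] -> [B, H, L, d_k]; assumes head i owns channels [i*d_k, (i+1)*d_k)
    b, l, c = x.shape
    return x.reshape(b, l, num_heads, c // num_heads).transpose(1, 2)


def reference_mha(q, k, v, num_heads):
    # Per head: softmax(Q_i K_i^T / sqrt(d_k)) V_i, then concatenate heads along C.
    qh, kh, vh = (split_heads(t, num_heads) for t in (q, k, v))
    d_k = qh.shape[-1]
    logits = qh @ kh.transpose(-2, -1) / math.sqrt(d_k)  # [B, H, L, L]
    heads = torch.softmax(logits, dim=-1) @ vh            # [B, H, L, d_k]
    b, h, l, _ = heads.shape
    return heads.transpose(1, 2).reshape(b, l, h * d_k)   # Concat -> [B, L, C]


torch.manual_seed(0)
q, k, v = (torch.randn(2, 16, 64) for _ in range(3))
out_ref = reference_mha(q, k, v, num_heads=8)

# Same computation through the fused kernel (default scale = 1/sqrt(d_k)).
out_sdpa = F.scaled_dot_product_attention(
    split_heads(q, 8), split_heads(k, 8), split_heads(v, 8)
).transpose(1, 2).reshape(2, 16, 64)

print(torch.allclose(out_ref, out_sdpa, atol=1e-5))
```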

Spatial layout

Inputs and outputs use channels-last layout:

  • 1D sequences: [B, T, C]

  • 2D images: [B, H, W, C]

  • 3D volumes: [B, D, H, W, C]

Internally, spatial dimensions are flattened to a single sequence axis L = prod(spatial_dims) for the SDPA kernel, then unflattened on output.
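A minimal sketch of the flatten/unflatten step for all three layouts. The helper names are hypothetical; the row-major ordering shown at the end is what torch.reshape produces and is assumed, not confirmed, to match the module's internal flattening.

```python
import math

import torch


def flatten_spatial(x: torch.Tensor):
    # [B, *spatial, C] -> [B, L, C] with L = prod(spatial); also return spatial for the inverse
    spatial = tuple(x.shape[1:-1])
    return x.reshape(x.shape[0], math.prod(spatial), x.shape[-1]), spatial


def unflatten_spatial(x: torch.Tensor, spatial: tuple) -> torch.Tensor:
    # [B, L, C] -> [B, *spatial, C]
    return x.reshape(x.shape[0], *spatial, x.shape[-1])


for shape in [(2, 128, 32), (2, 16, 16, 32), (2, 4, 8, 8, 32)]:  # 1D, 2D, 3D
    x = torch.randn(shape)
    flat, spatial = flatten_spatial(x)
    assert flat.shape == (2, math.prod(spatial), 32)
    assert torch.equal(unflatten_spatial(flat, spatial), x)

# Row-major flattening: pixel (i, j) of a 16 x 16 grid lands at l = i * 16 + j.
img = torch.randn(1, 16, 16, 32)
flat_img, _ = flatten_spatial(img)
assert torch.equal(flat_img[0, 3 * 16 + 5], img[0, 3, 5])
```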

Multi-head splitting

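The channel axis C is split into H heads of size d_k = C / H. The heads are folded into the batch axis, giving tensors of shape [B * H, *spatial, d_k] while RoPE and QK normalisation are applied; these are then reshaped to [B, H, L, d_k] for the SDPA call, and the heads are merged back into the channel axis afterwards.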
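The head split can be traced with plain tensor reshapes. This is a hypothetical sketch rather than the module's code: it assumes head i owns the contiguous channel slice [i*d_k, (i+1)*d_k), and it names the image axes height and width to avoid a clash with the head count H.

```python
import torch

batch, height, width, channels, num_heads = 2, 8, 8, 64, 4
d_k = channels // num_heads
x = torch.randn(batch, height, width, channels)  # [B, height, width, C]

# Split C into heads and fold them into the batch axis: [B*H, height, width, d_k].
xh = x.reshape(batch, height, width, num_heads, d_k)
xh = xh.permute(0, 3, 1, 2, 4).reshape(batch * num_heads, height, width, d_k)

# Flatten spatial dims and expose the head axis for SDPA: [B, H, L, d_k].
x_sdpa = xh.reshape(batch, num_heads, height * width, d_k)

# Head 1 is exactly channel slice [d_k, 2*d_k) of the input, flattened over space.
assert torch.equal(x_sdpa[:, 1], x[..., d_k:2 * d_k].reshape(batch, height * width, d_k))

# Merge heads back: [B, H, L, d_k] -> [B, L, C] -> [B, height, width, C].
merged = x_sdpa.permute(0, 2, 1, 3).reshape(batch, height * width, channels)
restored = merged.reshape(batch, height, width, channels)
assert torch.equal(restored, x)
```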

QK normalisation (cosine attention)

When apply_qk_norm=True, queries and keys are L2-normalised per head along the last dimension before the attention logits are formed. Each logit is then a cosine similarity bounded in [-1, 1], so the usual 1/sqrt(d_k) scaling is replaced by a fixed scale of 1.0: scaling by 1/sqrt(d_k) would shrink these already-bounded logits further and flatten the softmax.

Rotary Positional Embeddings (RoPE)

RoPE is applied to Q and K before QK-normalisation and before the SDPA call. The cos/sin buffers are precomputed once at __init__ from rope_spatial_dims and stored as non-persistent registered buffers (persistent=False) so they are reconstructed from __init__ args and never serialised to checkpoints. Head-dim divisibility requirements:

  • 1D: head_dim divisible by 2

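  • 2D: head_dim divisible by 4 (two RoPE tables of head_dim / 2 channels each, one per axis; each table rotates channel pairs, so head_dim / 2 must be even)

  • 3D: head_dim divisible by 6 (three RoPE tables of head_dim / 3 channels each, one per axis; head_dim / 3 must be even)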
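The sketch below shows where the divisibility rules come from and how non-persistent buffers behave. It is illustrative only: rope_tables, apply_rope, and Rope2D are hypothetical names, the interleaved-pair rotation is an assumed convention (some implementations rotate the two halves of the vector instead), and the row/column split of head_dim is assumed to follow the two half-dim tables described above.

```python
import torch
import torch.nn as nn


def rope_tables(n_pos: int, dim: int, base: float = 10000.0):
    # Frequencies theta_j = base^(-2j/dim), j = 0 .. dim/2 - 1 (dim must be even).
    assert dim % 2 == 0, "RoPE rotates channel pairs, so dim must be even"
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.arange(n_pos, dtype=torch.float32)[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()  # each [n_pos, dim / 2]


def apply_rope(x, cos, sin):
    # Rotate interleaved pairs (x0, x1), (x2, x3), ...; cos/sin broadcast to x[..., 0::2].
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)


class Rope2D(nn.Module):
    """Hypothetical 2D RoPE: first half of head_dim encodes the row, second half the column."""

    def __init__(self, head_dim: int, spatial_dims: tuple, base: float = 10000.0):
        super().__init__()
        assert head_dim % 4 == 0, "2D RoPE requires head_dim divisible by 4"
        height, width = spatial_dims
        cos_h, sin_h = rope_tables(height, head_dim // 2, base)
        cos_w, sin_w = rope_tables(width, head_dim // 2, base)
        # persistent=False: the tables follow .to()/.cuda() but are excluded from
        # state_dict(), so they are rebuilt from the constructor arguments.
        self.register_buffer("cos_h", cos_h, persistent=False)
        self.register_buffer("sin_h", sin_h, persistent=False)
        self.register_buffer("cos_w", cos_w, persistent=False)
        self.register_buffer("sin_w", sin_w, persistent=False)

    def forward(self, x):  # x: [N, height, width, head_dim]
        half = x.shape[-1] // 2
        rows = apply_rope(x[..., :half], self.cos_h[:, None, :], self.sin_h[:, None, :])
        cols = apply_rope(x[..., half:], self.cos_w[None, :, :], self.sin_w[None, :, :])
        return torch.cat((rows, cols), dim=-1)


rope = Rope2D(head_dim=32, spatial_dims=(8, 8))
x = torch.randn(3, 8, 8, 32)
y = rope(x)
print(len(rope.state_dict()), len(dict(rope.named_buffers())))

# 1D relative-position property: <R_m q, R_n k> depends only on m - n.
cos, sin = rope_tables(16, 8)
qv, kv = torch.randn(8), torch.randn(8)
rq = apply_rope(qv.repeat(16, 1), cos, sin)
rk = apply_rope(kv.repeat(16, 1), cos, sin)
```

Rotations preserve vector norms, which is why RoPE can be applied before QK normalisation without disturbing it.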

Context parallelism (CP)

Not yet functional. Passing a cp_group with size() > 1 to forward immediately raises ValueError("Context parallelism must be revisited."). The zigzag all-gather/split code that follows the raise in forward is unreachable; it is retained as a sketch for a future ring-attention implementation. Pass cp_group=None (the default) for all current use cases.

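Backend selection

Attention is computed with torch.nn.functional.scaled_dot_product_attention, which selects a backend at call time (FlashAttention, cuDNN attention, memory-efficient attention, or the reference math implementation) based on device capability, dtype, input shapes, and arguments such as masks and dropout. For example, it typically uses FlashAttention on A100-class GPUs and, depending on the PyTorch version, cuDNN attention on H100-class GPUs, falling back to the memory-efficient or math kernels when a fused kernel is not applicable.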

hidden_dim

Total channel dimension C.

Type:

int

num_heads

Number of attention heads H. In the current implementation all heads are computed on every rank (there is no head-parallel CP split). A # TODO(@farhad) in forward flags that local_num_heads is always equal to num_heads, which may need revisiting for tensor-parallel training.

Type:

int

head_dim

Per-head dimension d_k = C / H.

Type:

int

scale

Attention logit scale 1 / sqrt(d_k); set to 1.0 when apply_qk_norm=True.

Type:

float

apply_qk_norm

Whether L2 QK normalisation is active.

Type:

bool

use_rope

Whether RoPE positional encoding is active.

Type:

bool

rope_base

Geometric base for RoPE frequency bands.

Type:

float

is_causal

Whether to apply a causal (auto-regressive) mask.

Type:

bool

attn_dropout

Dropout probability applied to attention weights during training. At inference (when the module is in eval mode) a dropout probability of 0.0 is used regardless of this value.

Type:

float

_rope_ndim

Spatial rank for which RoPE was initialised (1, 2, or 3). Present only when use_rope=True; not defined otherwise. Used in forward to dispatch to the correct RoPE apply function.

Type:

int

Parameters:
  • hidden_dim (int) – Total hidden-state dimension C. Must be divisible by num_heads.

  • num_heads (int) – Number of parallel attention heads H.

  • apply_qk_norm (bool) – If True, L2-normalise Q and K per head along the last dimension (cosine attention).

  • use_rope (bool) – If True, apply Rotary Positional Embeddings to Q and K before the attention logits are formed.

  • is_causal (bool) – If True, apply a causal attention mask so that each position attends only to itself and earlier positions. For 2D and 3D inputs, "earlier" refers to the order of the flattened sequence axis L. Defaults to False.

  • attn_dropout (float) – Dropout rate on attention weights (active only during training). Defaults to 0.0.

  • rope_base (float) – Base frequency for RoPE; controls how quickly the rotation frequency decays across head-dim pairs. Defaults to 10000.0.

  • rope_spatial_dims (tuple[int, ...] | None) – Spatial grid shape used to precompute RoPE tables. Required when use_rope=True. Examples: (4096,) for 1D, (64, 64) for 2D, (8, 64, 64) for 3D. Must match the spatial shape seen during forward.
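Because spatial dimensions are flattened before the SDPA call, a causal mask acts on the flattened axis L rather than on any 2D or 3D notion of "earlier". A minimal sketch, assuming the module passes is_causal straight through to F.scaled_dot_product_attention and that flattening is row-major (as torch.reshape produces), for a 4 × 4 grid:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, height, width, d_k = 1, 2, 4, 4, 8
L = height * width
q, k, v = (torch.randn(batch, heads, L, d_k) for _ in range(3))  # already flattened

out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit boolean mask (True = may attend): position l sees itself
# and every earlier position in the flattened order.
allowed = torch.ones(L, L, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

# Pixel (1, 0) is l = 4: it sees the entire first row (l = 0..3) ...
assert allowed[4, :4].all()
# ... while pixel (0, 3), l = 3, cannot see pixel (1, 0) directly below-left of it.
assert not allowed[3, 4]
```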

Example:

import torch
from nvsubquadratic.modules.attention import Attention

# 2D image attention with 8 heads, RoPE, and cosine-attention QK norm
attn = Attention(
    hidden_dim=256,
    num_heads=8,
    apply_qk_norm=True,
    use_rope=True,
    rope_spatial_dims=(32, 32),
)
q = k = v = torch.randn(2, 32, 32, 256)  # [B, H, W, C]
out = attn(q, k, v)  # [B, H, W, C]
assert out.shape == q.shape

__init__(
hidden_dim,
num_heads,
apply_qk_norm,
use_rope,
is_causal=False,
attn_dropout=0.0,
rope_base=10000.0,
rope_spatial_dims=None,
)

Initialise the Attention module and precompute RoPE buffers.

Parameters:
  • hidden_dim (int) – Total channel dimension C. Must be divisible by num_heads.

  • num_heads (int) – Number of attention heads H.

  • apply_qk_norm (bool) – Whether to L2-normalise Q and K per head.

  • use_rope (bool) – Whether to apply Rotary Positional Embeddings.

  • is_causal (bool) – Whether to use a causal attention mask. Defaults to False.

  • attn_dropout (float) – Attention-weight dropout probability. Defaults to 0.0.

  • rope_base (float) – RoPE base frequency. Defaults to 10000.0.

  • rope_spatial_dims (tuple[int, ...] | None) – Spatial grid shape for RoPE table precomputation. Required when use_rope=True. Not stored as an instance attribute; the caller is responsible for tracking the spatial dims if they need to recover them after construction (e.g. for serialisation or extra_repr). The corresponding cos/sin buffers are stored as non-persistent registered buffers (rope_cos, rope_sin, etc.).

Raises:
  • AssertionError – If hidden_dim % num_heads != 0.

  • AssertionError – If use_rope=True and rope_spatial_dims is None.

  • AssertionError – If RoPE head-dim divisibility requirements are not met (divisible by 2 for 1D, 4 for 2D, 6 for 3D).

  • ValueError – If rope_spatial_dims has length other than 1, 2, or 3.
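The documented checks can be restated as a small standalone function, which makes the RoPE rule explicit: each of the ndim axes gets head_dim / ndim channels, rotated in pairs, so head_dim must be divisible by 2 * ndim. The function name and the order of the checks are hypothetical. If the module uses plain assert statements, as the AssertionError entries suggest, those checks are skipped when Python runs with -O.

```python
def validate_attention_config(hidden_dim, num_heads, use_rope, rope_spatial_dims=None):
    """Hypothetical restatement of the documented __init__ checks; returns head_dim."""
    assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
    head_dim = hidden_dim // num_heads
    if not use_rope:
        return head_dim
    assert rope_spatial_dims is not None, "rope_spatial_dims is required when use_rope=True"
    ndim = len(rope_spatial_dims)
    if ndim not in (1, 2, 3):
        raise ValueError(f"rope_spatial_dims must have length 1, 2 or 3, got {ndim}")
    # ndim tables of head_dim / ndim channels each, rotated in pairs -> even size each.
    divisor = 2 * ndim  # 2 for 1D, 4 for 2D, 6 for 3D
    assert head_dim % divisor == 0, f"{ndim}D RoPE requires head_dim divisible by {divisor}"
    return head_dim


def raises(exc_type, fn, *args, **kwargs):
    # Small helper for the usage checks below.
    try:
        fn(*args, **kwargs)
    except exc_type:
        return True
    return False


print(validate_attention_config(256, 8, use_rope=True, rope_spatial_dims=(32, 32)))
print(raises(AssertionError, validate_attention_config, 256, 8, True, (8, 64, 64)))
```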

extra_repr()

Return a concise string summary of this module’s configuration.

Returns:

Comma-separated key=value pairs for num_heads, apply_qk_norm, is_causal, attn_dropout, use_rope, and rope_base.

Return type:

str
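For reference, nn.Module.__repr__ places the extra_repr() string inside the parentheses after the class name when the module has no child modules. The class below is a hypothetical stand-in with the same six fields; the exact formatting of the real module's string may differ.

```python
import torch.nn as nn


class ReprSketch(nn.Module):
    """Hypothetical stand-in that reports the listed fields via extra_repr()."""

    def __init__(self, num_heads=8, apply_qk_norm=True, is_causal=False,
                 attn_dropout=0.0, use_rope=True, rope_base=10000.0):
        super().__init__()
        self.num_heads = num_heads
        self.apply_qk_norm = apply_qk_norm
        self.is_causal = is_causal
        self.attn_dropout = attn_dropout
        self.use_rope = use_rope
        self.rope_base = rope_base

    def extra_repr(self) -> str:
        # Inserted by nn.Module.__repr__ between "ReprSketch(" and ")".
        return (
            f"num_heads={self.num_heads}, apply_qk_norm={self.apply_qk_norm}, "
            f"is_causal={self.is_causal}, attn_dropout={self.attn_dropout}, "
            f"use_rope={self.use_rope}, rope_base={self.rope_base}"
        )


print(ReprSketch())
# ReprSketch(num_heads=8, apply_qk_norm=True, is_causal=False, attn_dropout=0.0, use_rope=True, rope_base=10000.0)
```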

forward(query, key, value, cp_group=None)

Apply multi-head scaled dot-product attention.

Computes:

\[\text{out} = \text{Concat}_{i=1}^{H} \left[ \text{softmax}\!\left( \frac{Q_i K_i^\top}{\sqrt{d_k}} \right) V_i \right]\]

where \(H\) = num_heads and \(d_k\) = head_dim. When apply_qk_norm=True, Q and K are L2-normalised before the logits are formed and the scale is 1.0 instead of 1/sqrt(d_k).
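Written out for the cosine-attention case, with \(\hat{Q}_i, \hat{K}_i\) denoting the queries and keys of head \(i\) (after RoPE, if enabled) normalised row-wise over \(d_k\), and \(\theta_{lm}\) the angle between query \(l\) and key \(m\):

\[ \hat{Q}_i = \frac{Q_i}{\lVert Q_i \rVert_2}, \quad \hat{K}_i = \frac{K_i}{\lVert K_i \rVert_2}, \qquad \text{out} = \text{Concat}_{i=1}^{H} \left[ \text{softmax}\!\left( \hat{Q}_i \hat{K}_i^\top \right) V_i \right], \qquad \left(\hat{Q}_i \hat{K}_i^\top\right)_{lm} = \cos\theta_{lm} \in [-1, 1] \]

Assuming RoPE is implemented as a rotation \(R\) of channel pairs, it preserves the L2 norm, so \(R x / \lVert R x \rVert_2 = R\,(x / \lVert x \rVert_2)\): applying RoPE before or after the normalisation gives the same \(\hat{Q}_i, \hat{K}_i\), up to the small epsilon used to guard the division.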

The forward pipeline is:

  1. (CP guard) Raises ValueError if cp_group.size() > 1; pass cp_group=None for all current use cases.

  2. Split channel dim into heads: [B, *spatial, C] → [B*H, *spatial, d_k].

  3. (Optional) Apply RoPE to Q and K.

  4. (Optional) L2-normalise Q and K per head.

  5. Flatten spatial dims: [B*H, *spatial, d_k] → [B*H, L, d_k].

  6. Reshape to SDPA layout: [B*H, L, d_k] → [B, H, L, d_k].

  7. F.scaled_dot_product_attention (FlashAttention / cuDNN / fallback).

  8. Merge heads: [B, H, L, d_k] → [B, L, C].

  9. Unflatten spatial dims: [B, L, C] → [B, *spatial, C].

  10. (CP only; currently unreachable because step 1 raises) Zigzag-split the output back to the local spatial slice.

Parameters:
  • query (torch.Tensor) – Query tensor of shape [B, *spatial_dims, C]. spatial_dims may be (T,), (H, W), or (D, H, W).

  • key (torch.Tensor) – Key tensor of shape [B, *spatial_dims, C]. Must match query shape.

  • value (torch.Tensor) – Value tensor of shape [B, *spatial_dims, C]. Must match query shape.

  • cp_group (torch.distributed.ProcessGroup | None) – Context-parallel process group, provided for future compatibility. Passing a group with cp_group.size() > 1 currently raises ValueError because ring-attention is not yet implemented; the intended behaviour (gathering the full spatial sequence before attention and splitting it back afterwards) is unreachable. Defaults to None.

Returns:

Output of shape [B, *spatial_dims, C], the same layout as the inputs.

Return type:

torch.Tensor

Raises:

ValueError – If cp_group is provided and has size > 1 (context parallelism is not yet supported).

Parameters: