ViT5HyenaAdapter
- class ViT5HyenaAdapter(inner_mixer_cfg, grid_w)
Bases: Module
Bridges ViT-5’s [B, T, C] token sequences and Hyena’s [B, H, W, C] spatial interface. The adapter is a parameter-free reshape wrapper: it does not own any QKV projection, output projection, or positional encoding. All learnable components live inside inner_mixer (typically a QKVSequenceMixer wrapping a Hyena instance).
Data flow:
    x: [B, T, C]
      │
      ▼  reshape (T → H × grid_w, where H = T // grid_w)
    x: [B, H, grid_w, C]
      │
      ▼  inner_mixer (any [B, H, W, C]-preserving mixer)
    x: [B, H, grid_w, C]
      │
      ▼  reshape back
    x: [B, T, C]

What the adapter handles vs. what the inner mixer handles:
- Adapter: the shape contract (flat ↔ 2-D) and flop_count delegation.
- Inner mixer: input projection (C → 3C), the Hyena global convolution, gating, output projection (C → C), any normalisation, and optional FiLM / AdaLN conditioning. The mixer receives tensors in channels-last layout [B, H, W, C]; if it uses channels-first convolution internally (as QKVSequenceMixer does), it handles the permutation itself.
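The reshape-wrap-reshape flow above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the actual implementation: SketchAdapter and PointwiseMixer are hypothetical names, the sketch takes an already-instantiated nn.Module instead of a LazyConfig, and a pointwise nn.Linear with an additive conditioning shift stands in for the real Hyena-based mixer.

```python
from typing import Optional

import torch
from torch import nn


class SketchAdapter(nn.Module):
    """Hypothetical stand-in for ViT5HyenaAdapter: [B, T, C] <-> [B, H, W, C]."""

    def __init__(self, inner_mixer: nn.Module, grid_w: int) -> None:
        super().__init__()
        # Assumption: the mixer is passed in already built (no LazyConfig).
        self.inner_mixer = inner_mixer
        self.grid_w = grid_w

    def forward(self, x: torch.Tensor, **mixer_kwargs) -> torch.Tensor:
        b, t, c = x.shape
        h = t // self.grid_w
        # View the flat sequence as an H x grid_w grid (raises if T % grid_w != 0).
        x2d = x.reshape(b, h, self.grid_w, c)
        # Keyword arguments are forwarded verbatim; the mixer must preserve shape.
        y2d = self.inner_mixer(x2d, **mixer_kwargs)
        # Flatten the grid back into a token sequence.
        return y2d.reshape(b, t, c)


class PointwiseMixer(nn.Module):
    """Hypothetical toy mixer: Linear over channels plus an optional shift."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, x: torch.Tensor, conditioning: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        y = self.proj(x)
        if conditioning is not None:
            # Hypothetical conditioning of shape [B, C], broadcast over H and W.
            y = y + conditioning[:, None, None, :]
        return y


adapter = SketchAdapter(PointwiseMixer(dim=8), grid_w=4)
tokens = torch.randn(2, 12, 8)  # B=2, T=12 -> H=3, W=4
out = adapter(tokens, conditioning=torch.zeros(2, 8))
print(out.shape)  # torch.Size([2, 12, 8])
```

Because the adapter only reshapes, every parameter returned by adapter.parameters() belongs to the inner mixer; keeping the wrapper parameter-free is what lets the same adapter wrap any shape-preserving 2-D mixer.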
- Register-token handling:
Register tokens (and the CLS token, if present) are treated as ordinary spatial positions within the reshaped grid; no masking or special-casing is applied. The upstream network is responsible for:
  - Padding the sequence so that T % grid_w == 0.
  - Choosing a grid_w that places register/CLS tokens in a predictable row (e.g. a dedicated “register row” at the bottom of the grid), so that the spatial convolution inside Hyena sees a consistent layout.
  - In the hierarchical case, supplying the correct grid_w at each stage, because patch merging changes the spatial width.
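As a worked example of the padding rule, assume a hypothetical 14 × 14 patch grid, one CLS token and 4 register tokens, with grid_w = 14. Then T = 196 + 1 + 4 = 201, which needs P = 9 padding tokens to reach 210 = 15 × 14. The helpers padding_needed and grid_position are illustrative names, not part of the library.

```python
from typing import Tuple


def padding_needed(num_tokens: int, grid_w: int) -> int:
    """Smallest P >= 0 such that (num_tokens + P) % grid_w == 0."""
    return (-num_tokens) % grid_w


def grid_position(index: int, grid_w: int) -> Tuple[int, int]:
    """(row, col) of a flat token index after reshaping to [H, grid_w]."""
    return divmod(index, grid_w)


# Hypothetical layout: [patch tokens, CLS, registers, padding].
h_patch, w_patch, num_cls, num_reg = 14, 14, 1, 4
grid_w = w_patch
t = h_patch * w_patch + num_cls + num_reg
p = padding_needed(t, grid_w)
print(t, p, (t + p) // grid_w)  # 201 9 15

# CLS sits at flat index 196: the first column of the extra row 14.
print(grid_position(h_patch * w_patch, grid_w))  # (14, 0)
```

With grid_w equal to the patch-grid width, CLS, the registers (indices 197 to 200) and the padding (201 to 209) all land in row 14, which is exactly the dedicated bottom “register row” layout described above.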
- Parameters:
inner_mixer_cfg (LazyConfig)
grid_w (int)
- inner_mixer
The instantiated 2-D sequence mixer. Accepts and returns [B, H, W, C] tensors in channels-last layout. Typically a QKVSequenceMixer wrapping Hyena.
- Type:
nn.Module
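A mixer that convolves in channels-first layout internally must permute on entry and exit so that it still honours the channels-last [B, H, W, C] contract. The sketch below uses a depthwise nn.Conv2d as a hypothetical stand-in (ChannelsFirstConvMixer is an illustrative name); it is neither Hyena nor QKVSequenceMixer, only an illustration of where the permutations go.

```python
import torch
from torch import nn


class ChannelsFirstConvMixer(nn.Module):
    """Hypothetical mixer: channels-last in/out, channels-first Conv2d inside."""

    def __init__(self, dim: int, kernel_size: int = 3) -> None:
        super().__init__()
        # Depthwise conv; padding keeps H and W unchanged for odd kernel sizes.
        self.conv = nn.Conv2d(
            dim, dim, kernel_size, padding=kernel_size // 2, groups=dim
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, H, W, C] -> [B, C, H, W], the layout nn.Conv2d expects.
        y = self.conv(x.permute(0, 3, 1, 2))
        # [B, C, H, W] -> [B, H, W, C], restoring the adapter's contract.
        return y.permute(0, 2, 3, 1)


mixer = ChannelsFirstConvMixer(dim=8)
x = torch.randn(2, 3, 4, 8)
y = mixer(x)
print(y.shape)  # torch.Size([2, 3, 4, 8])
```

Keeping the permutations inside the mixer means the adapter never needs to know which memory layout its inner module prefers.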
- __init__(inner_mixer_cfg, grid_w)
Instantiate the adapter and its inner 2-D mixer.
- Parameters:
inner_mixer_cfg (LazyConfig) – LazyConfig describing the 2-D sequence mixer to instantiate (e.g. QKVSequenceMixer wrapping Hyena). The instantiated module must accept (x: Tensor[B, H, W, C], **kwargs) in channels-last layout and return a tensor of the same shape. Any inner mixer that uses channels-first convolution (like QKVSequenceMixer) handles the permutation internally. Projection dimensions (hidden_dim, num_heads, etc.) must be set inside this config; the adapter itself accepts no hidden_dim argument.
grid_w (int) – Width of the 2-D spatial grid. Every call to forward must supply a sequence length T that satisfies T % grid_w == 0; the grid height is computed as H = T // grid_w. In a hierarchical network, pass the correct grid_w for each stage: after a 2× patch-merging step, grid_w halves. The network’s stage configuration (e.g. ViT5HierarchicalClassificationNet) is the source of truth for each stage’s grid_w.
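For a hierarchical network, each stage’s grid_w follows from the previous one by halving at every 2× patch merge, and each stage’s sequence length must still be a multiple of it. The sketch assumes a hypothetical first stage with grid_w = 56 (e.g. a 224 × 224 input with 4 × 4 patches) and three merges; stage_grid_widths and check_sequence are illustrative helpers, and the real values come from the network’s stage configuration.

```python
from typing import List


def stage_grid_widths(first_grid_w: int, num_merges: int) -> List[int]:
    """grid_w for each stage, halving after every 2x patch merge."""
    widths = [first_grid_w]
    for _ in range(num_merges):
        if widths[-1] % 2:
            raise ValueError(f"grid_w={widths[-1]} cannot be halved evenly")
        widths.append(widths[-1] // 2)
    return widths


def check_sequence(num_tokens: int, grid_w: int) -> int:
    """Return H = T // grid_w, or raise if T is not a multiple of grid_w."""
    if num_tokens % grid_w:
        raise ValueError(f"T={num_tokens} is not divisible by grid_w={grid_w}")
    return num_tokens // grid_w


widths = stage_grid_widths(56, num_merges=3)
print(widths)  # [56, 28, 14, 7]
# Square patch grids with no extra tokens: T = grid_w ** 2 at each stage.
print([check_sequence(w * w, w) for w in widths])  # [56, 28, 14, 7]
```

Validating divisibility up front, as check_sequence does, gives a clearer error than the RuntimeError that reshape would otherwise raise inside forward.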
- flop_count(num_tokens, inference=False)
Delegate FLOPs accounting to the inner mixer.
The adapter’s reshape operations are pure metadata re-strides with zero arithmetic FLOPs, so the total cost is determined entirely by inner_mixer.flop_count.
Note
flop_count is a de-facto protocol, not enforced by a formal interface. To guard against missing implementations, check hasattr(adapter.inner_mixer, "flop_count") before calling it.
- Parameters:
num_tokens – Total sequence length T; the cost is evaluated on a (T // grid_w, grid_w) spatial grid.
inference – Defaults to False.
- Returns:
Total FLOPs reported by the inner mixer for a (T // grid_w, grid_w) spatial grid.
- Raises:
AttributeError – If inner_mixer does not implement flop_count.
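The delegation and the hasattr guard from the note can be sketched as follows. Everything here is hypothetical: ToyMixer, ToyAdapter and safe_flop_count are illustrative names, the cost formula (2 · T · C² for one C → C projection, counting a multiply-add as two FLOPs) is a toy, and the assumption that the inner mixer’s flop_count takes the same (num_tokens, inference) arguments may not match the real inner mixers.

```python
from typing import Optional


class ToyMixer:
    """Hypothetical inner mixer exposing a flop_count method."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def flop_count(self, num_tokens: int, inference: bool = False) -> int:
        # Toy formula: one C -> C projection costs 2 * T * C^2 FLOPs.
        # This toy ignores `inference`; a real mixer defines what it changes.
        return 2 * num_tokens * self.dim * self.dim


class ToyAdapter:
    """Hypothetical adapter: reshapes cost nothing, so FLOPs are delegated."""

    def __init__(self, inner_mixer: object, grid_w: int) -> None:
        self.inner_mixer = inner_mixer
        self.grid_w = grid_w

    def flop_count(self, num_tokens: int, inference: bool = False) -> int:
        # Raises AttributeError if the inner mixer lacks flop_count.
        return self.inner_mixer.flop_count(num_tokens, inference=inference)


def safe_flop_count(adapter: ToyAdapter, num_tokens: int) -> Optional[int]:
    """Guarded call following the hasattr check suggested in the note."""
    if not hasattr(adapter.inner_mixer, "flop_count"):
        return None
    return adapter.flop_count(num_tokens, inference=True)


print(ToyAdapter(ToyMixer(dim=8), grid_w=4).flop_count(12))  # 1536
print(safe_flop_count(ToyAdapter(object(), grid_w=4), 12))  # None
```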
- forward(x, **mixer_kwargs)
Reshape to a 2-D grid, apply the inner mixer, and reshape back.
- Parameters:
x (Tensor) – Input token sequence of shape [B, T, C], where:
  - B: batch size.
  - T: total sequence length (must satisfy T % grid_w == 0). Typical layout (set by the network, not enforced here): [patch_tokens (H_patch * W_patch), CLS (0 or 1), register_tokens (R), padding (P)].
  - C: channel / hidden dimension.
**mixer_kwargs – Keyword arguments forwarded verbatim to inner_mixer.forward. Common keys include:
  - conditioning: FiLM/AdaLN conditioning tensor used by some Hyena configurations.
  - cp_group: process group for context-parallel (AllToAll) sharding inside the Hyena operator.
Any additional kwargs accepted by the concrete inner mixer are also forwarded; consult the inner mixer’s docstring for the full list.
- Returns:
Tensor of shape [B, T, C]: the token sequence after 2-D Hyena mixing. The first reshape (to [B, H, W, C]) is a zero-copy view when x is contiguous. If inner_mixer returns a non-contiguous tensor whose H and W axes cannot be merged by re-striding, the final reshape (back to [B, T, C]) materialises a contiguous copy; this does not affect correctness but adds memory traffic, which can matter in CUDA-graph or torch.compile contexts. In practice, QKVSequenceMixer returns a contiguous tensor (its output projection is a Linear on the last axis), so the final reshape is typically a free view.
- Raises:
RuntimeError – Raised by torch.Tensor.reshape if T % grid_w != 0, with a message reporting the mismatched total element count.
- Return type:
Tensor
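The memory behaviour and the error case described for forward can be checked directly with data_ptr(): a contiguous input is reshaped to the grid without a copy, a grid whose H and W axes were swapped by a permute cannot be flattened as a view, and a length that is not a multiple of grid_w makes reshape raise RuntimeError. The shapes below are arbitrary illustration values.

```python
import torch

b, h, w, c = 2, 3, 4, 8
x = torch.randn(b, h * w, c)

# Contiguous [B, T, C] -> [B, H, W, C]: a zero-copy view sharing storage.
grid = x.reshape(b, h, w, c)
print(grid.data_ptr() == x.data_ptr())  # True

# A [B, H, W, C] tensor obtained by swapping H and W is non-contiguous, and its
# H and W axes cannot be merged by re-striding, so reshape must copy.
swapped = torch.randn(b, w, h, c).permute(0, 2, 1, 3)
flat = swapped.reshape(b, h * w, c)
print(swapped.is_contiguous(), flat.data_ptr() == swapped.data_ptr())  # False False

# T % grid_w != 0: reshape itself raises RuntimeError.
try:
    torch.randn(b, 13, c).reshape(b, 13 // w, w, c)
except RuntimeError as err:
    print("RuntimeError:", err)
```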