Hyena
class Hyena(global_conv_cfg, short_conv_cfg, gate_nonlinear_cfg, pixelhyena_norm_cfg, qk_norm_cfg, output_norm_cfg=LazyConfig(torch.nn.Identity)(), gate_nonlinear_2_cfg=None)
Bases: Module
Gated global convolutional mixer for ND signals.
The Hyena operator computes the following gated convolution (all tensors are channels-first internally and channels-last on the public interface):
\[\begin{split}z &= Q \odot \sigma(K) \\ h &= \mathrm{GlobalConv}(z) \\ y &= h \odot \sigma_2(V)\end{split}\]
where \(\sigma\) is gate_nonlinear, \(\sigma_2\) is gate_nonlinear_2 (defaults to \(\sigma\)), and \(\mathrm{GlobalConv}\) is a depthwise FFT convolution whose kernel is generated on the fly by an implicit MLP (CKConvND).
Setting both gates to Identity gives a linear gating variant (element-wise products only). Setting \(\sigma = \mathrm{SiLU}\) and \(\sigma_2 = \mathrm{Sigmoid}\) matches the gated attention formulation used in the original Hyena paper.
Paper references
The two-gate structure follows the H3 block (Fu et al., “Hungry Hungry Hippos”, ICLR 2023, arXiv:2212.14052, Section 3.2) and is generalised in Hyena (Poli et al., “Hyena Hierarchy: Towards Larger Convolutional Language Models”, ICML 2023, arXiv:2302.10866, Section 3 “The Hyena Recurrence”). The ND extension replaces the causal 1D FFT conv with a non-causal ND FFT conv (CKConvND).
Optional components (each disabled by passing Identity or None):
- Short depthwise convolution on concatenated [Q, K, V]
- QK normalisation (Q always; K only when \(\sigma = \mathrm{Identity}\))
- PixelHyena normalisation between first gate and global conv
- Output normalisation after second gate
- Context parallelism via AllToAll communication (cp_group argument)
Example:
    # Minimal 2D Hyena block (non-causal, no normalisation).
    # In practice global_conv_cfg wraps a fully-configured CKConvND.
    import torch
    from nvsubquadratic.lazy_config import LazyConfig
    from nvsubquadratic.modules.hyena_nd import Hyena

    hyena = Hyena(
        global_conv_cfg=...,  # LazyConfig wrapping CKConvND
        short_conv_cfg=LazyConfig(torch.nn.Conv2d)(
            192, 192, 3, padding=1, groups=192
        ),
        gate_nonlinear_cfg=LazyConfig(torch.nn.SiLU)(),
        pixelhyena_norm_cfg=LazyConfig(torch.nn.Identity)(),
        qk_norm_cfg=None,
    )

    B, H, W, C = 2, 16, 16, 64
    q = k = v = torch.randn(B, H, W, C)
    y = hyena(q, k, v)  # [2, 16, 16, 64]
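The three equations that define the operator can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the module's implementation: `fft_conv_nd` and `hyena_reference` are hypothetical names, the kernel is passed in as an explicit tensor instead of being generated by CKConvND, and the FFT convolution uses a circular boundary for brevity (how the library pads is not stated here). Short conv and all normalisation layers are omitted.

```python
import torch
import torch.nn.functional as F


def fft_conv_nd(z, kernel):
    """Depthwise, non-causal, circular ND convolution via FFT (assumed boundary).

    z:      [B, C, *spatial]
    kernel: [C, *spatial], one filter per channel
    """
    spatial = tuple(z.shape[2:])
    dims = tuple(range(2, z.dim()))
    z_f = torch.fft.rfftn(z, s=spatial, dim=dims)
    k_f = torch.fft.rfftn(kernel.unsqueeze(0), s=spatial, dim=dims)
    # Pointwise product in frequency space == per-channel circular convolution.
    return torch.fft.irfftn(z_f * k_f, s=spatial, dim=dims)


def hyena_reference(q, k, v, kernel, sigma=F.silu, sigma_2=None):
    """Channels-last in, channels-last out: [B, *spatial, C]."""
    sigma_2 = sigma if sigma_2 is None else sigma_2  # sigma_2 defaults to sigma
    q, k, v = (t.movedim(-1, 1) for t in (q, k, v))  # -> [B, C, *spatial]
    z = q * sigma(k)            # first gate
    h = fft_conv_nd(z, kernel)  # global convolution
    y = h * sigma_2(v)          # second gate
    return y.movedim(1, -1)     # -> [B, *spatial, C]


B, H, W, C = 2, 16, 16, 64
q, k, v = (torch.randn(B, H, W, C) for _ in range(3))
kernel = torch.randn(C, H, W)
y = hyena_reference(q, k, v, kernel)
print(y.shape)
```

With a delta kernel (a single 1 at the spatial origin of every channel) the global convolution becomes the identity, so the output reduces to `q * sigma(k) * sigma_2(v)`, which is a convenient sanity check.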
- global_conv
Long-range global convolution, typically CKConvND. Must expose hidden_dim and flop_count(spatial_dims, inference) for FLOP counting.
- short_conv
Short depthwise convolution applied to the concatenated [Q, K, V] tensor (3·C input channels). Must be one of torch.nn.Conv{1,2,3}d, DistributedDepthwiseConv{1,2,3}d, or torch.nn.Identity.
- gate_nonlinear
Activation \(\sigma\) for the first gate. Applied to K before multiplying with Q.
- gate_nonlinear_2
Activation \(\sigma_2\) for the second gate. Applied to V before multiplying with h. Shares the same object as gate_nonlinear when gate_nonlinear_2_cfg is None.
- pixelhyena_norm
Normalisation layer applied to z = Q ⊙ σ(K) before the global conv. Parameters are excluded from weight decay via _no_weight_decay = True.
- output_norm
Normalisation layer applied to y = h ⊙ σ₂(V) before returning. Parameters are excluded from weight decay.
- q_norm
Per-channel normalisation for Q. None when qk_norm_cfg is None.
Type: torch.nn.Module | None
- k_norm
Per-channel normalisation for K. One of three values:
None when qk_norm_cfg is None (QK-norm entirely disabled);
torch.nn.Identity when the gate is nonlinear (\(\sigma\) already bounds K’s magnitude);
a fresh instance of qk_norm_cfg when the gate is Identity (linear gating).
Type: torch.nn.Module | None
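The selection rule for q_norm and k_norm described above can be written out as a small function. This is a sketch under assumptions: `build_qk_norms` is a hypothetical helper, and a zero-argument factory `make_norm` stands in for instantiating `qk_norm_cfg`.

```python
import torch


def build_qk_norms(make_norm, gate_nonlinear):
    """Return (q_norm, k_norm) following the documented rule.

    make_norm: zero-argument callable returning a fresh norm module,
               or None to disable QK-norm (stand-in for qk_norm_cfg).
    """
    if make_norm is None:
        return None, None  # QK-norm entirely disabled
    q_norm = make_norm()   # Q is always normalised when QK-norm is on
    if isinstance(gate_nonlinear, torch.nn.Identity):
        k_norm = make_norm()          # linear gating: K gets its own instance
    else:
        k_norm = torch.nn.Identity()  # nonlinear gate: K is left unnormalised
    return q_norm, k_norm


make_norm = lambda: torch.nn.LayerNorm(8)
print(build_qk_norms(make_norm, torch.nn.SiLU()))
print(build_qk_norms(make_norm, torch.nn.Identity()))
print(build_qk_norms(None, torch.nn.SiLU()))
```

Calling the factory twice in the linear-gating case is what keeps the learnable parameters of the two norms independent.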
- __init__(global_conv_cfg, short_conv_cfg, gate_nonlinear_cfg, pixelhyena_norm_cfg, qk_norm_cfg, output_norm_cfg=LazyConfig(torch.nn.Identity)(), gate_nonlinear_2_cfg=None)
Construct a Hyena gated global convolutional mixer.
All *_cfg arguments are LazyConfig objects that are instantiated inside __init__ via nvsubquadratic.lazy_config.instantiate. This pattern allows full Python configurability without importing module classes at config-definition time.
Parameters:
- global_conv_cfg (LazyConfig) – LazyConfig for the long-range global convolution (e.g. CKConvND). The instantiated module must expose hidden_dim: int and flop_count(spatial_dims, inference) -> int.
- short_conv_cfg (LazyConfig) – LazyConfig for the short depthwise conv applied to the concatenated [Q; K; V] tensor (3·C input channels). Must instantiate to one of torch.nn.Conv{1,2,3}d, DistributedDepthwiseConv{1,2,3}d, or torch.nn.Identity. Use Identity to skip the short conv entirely.
- gate_nonlinear_cfg (LazyConfig) – LazyConfig for the first-gate activation \(\sigma(K)\) (e.g. SiLU). Use Identity for linear gating.
- pixelhyena_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied between the first gate and the global conv. Use Identity to disable. Parameters receive _no_weight_decay = True.
- qk_norm_cfg (LazyConfig | None) – LazyConfig for per-channel normalisation of Q (and of K when the gate is Identity). Pass None to disable QK-norm entirely. When both Q and K are normalised, two separate instances are created (one for Q, one for K) so that stateful norms (e.g. RMSNorm with a learnable scale) keep independent parameters.
- output_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied after the second gate. Defaults to a LazyConfig wrapping torch.nn.Identity (no normalisation). Do not pass an already-instantiated module; pass a LazyConfig object that wraps the class. Parameters receive _no_weight_decay = True.
- gate_nonlinear_2_cfg (LazyConfig | None) – LazyConfig for the second-gate activation \(\sigma_2(V)\). If None (default), both gates share the same activation object (self.gate_nonlinear).
Raises:
AssertionError – If the instantiated short_conv is not one of the supported Conv / DistributedDepthwiseConv / Identity types.
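The `_no_weight_decay = True` marker on the norm parameters only has an effect if the optimizer setup reads it. One way a training script might consume the flag is sketched below. Assumptions: the flag is an attribute set on each parameter tensor, and `split_weight_decay_groups` is a hypothetical helper, not part of the library.

```python
import torch


def split_weight_decay_groups(module, weight_decay):
    """Build optimizer param groups, honouring a per-parameter _no_weight_decay flag."""
    decay, no_decay = [], []
    for p in module.parameters():
        if not p.requires_grad:
            continue
        # Parameters flagged with _no_weight_decay = True get weight_decay = 0.
        if getattr(p, "_no_weight_decay", False):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Toy model: flag the norm parameters the way the docs describe for Hyena norms.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
for p in model[1].parameters():
    p._no_weight_decay = True

optimizer = torch.optim.AdamW(split_weight_decay_groups(model, 0.1), lr=1e-3)
print([len(g["params"]) for g in optimizer.param_groups])
```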
- extra_repr()
Return a compact summary of key configuration choices.
Included fields:
- q_norm / k_norm class names (or "None"). When QK-norm is disabled both are None; the strings "q_norm=None" and "k_norm=None" are still emitted so the disabled state is explicit in repr(module).
- gates=<σ>/<σ₂> when the two gate activations differ.
- is_causal when the global conv exposes that attribute.
Returns:
Comma-separated string suitable for repr(module) output.
- flop_count(spatial_dims, inference=False)
Count FLOPs for one forward pass of the Hyena mixer.
Let C = self.global_conv.hidden_dim (the per-head channel count) and S = prod(spatial_dims) (total number of spatial positions). All counts use the multiply-add = 1 FLOP convention (i.e. a MAC counts as 1).
FLOP breakdown:
1. Short depthwise conv on concatenated [Q; K; V] (3·C input channels):
\[2 \cdot \frac{in\_ch}{groups} \cdot out\_ch \cdot S \cdot k\_prod\]
where \(k\_prod = \prod_d kernel\_size_d\). Skipped when short_conv is Identity. For a pure depthwise conv (groups == in_ch == out_ch) this simplifies to 2 · out_ch · S · k_prod; the grouped formula is written here to handle partially-grouped convolutions (e.g. DistributedDepthwiseConvNd).
2. QK-Norm (when self.q_norm is not None): 3·C·S for Q; an additional 3·C·S for K only when gate_nonlinear is Identity (linear gating). The factor of 3 assumes an RMSNorm-like norm (sum-of-squares + rsqrt + elementwise scale). Other norm types will differ; this is an approximation.
3. First gate \(z = Q \odot \sigma(K)\): C·S for the elementwise multiply, plus C·S for the activation \(\sigma\) when it is not Identity.
4. PixelHyena norm (when not Identity): 3·C·S.
5. Global convolution: delegated to self.global_conv.flop_count(spatial_dims, inference).
6. Second gate \(y = h \odot \sigma_2(V)\): C·S for the multiply, plus C·S for \(\sigma_2\) when not Identity.
7. Output norm (when not Identity): 3·C·S.
Returns:
Total FLOP count as an integer (multiply-add = 1 FLOP convention).
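The breakdown above can be reproduced as straightforward arithmetic. The function below is a sketch that follows the listed formulas term by term and leaves out the global-convolution term, which the module delegates to `self.global_conv.flop_count`. The function name and its boolean arguments are hypothetical and chosen only for this illustration.

```python
from math import prod


def hyena_flops_excluding_global_conv(
    C,
    spatial_dims,
    short_kernel=None,        # e.g. (3, 3); None means short_conv is Identity
    qk_norm=True,
    gate_is_identity=False,
    gate_2_is_identity=False,
    pixelhyena_norm=False,
    output_norm=False,
):
    S = prod(spatial_dims)
    total = 0
    if short_kernel is not None:
        # Pure depthwise conv over the 3C concatenated channels.
        in_ch = out_ch = groups = 3 * C
        total += 2 * (in_ch // groups) * out_ch * S * prod(short_kernel)
    if qk_norm:
        total += 3 * C * S                  # Q norm
        if gate_is_identity:
            total += 3 * C * S              # K norm, linear gating only
    total += C * S                          # first gate multiply
    if not gate_is_identity:
        total += C * S                      # sigma(K)
    if pixelhyena_norm:
        total += 3 * C * S
    total += C * S                          # second gate multiply
    if not gate_2_is_identity:
        total += C * S                      # sigma_2(V)
    if output_norm:
        total += 3 * C * S
    return total


# Configuration of the class-level example: C=64, 16x16 grid, 3x3 short conv,
# SiLU gates, no QK-norm, no PixelHyena norm, no output norm.
print(hyena_flops_excluding_global_conv(64, (16, 16), (3, 3), qk_norm=False))  # 950272
```

For that configuration the short conv contributes 2 · 192 · 256 · 9 = 884736 and each gate contributes 2 · 64 · 256 = 32768, giving 950272 before the global-convolution term is added.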
- forward(query, key, value, cp_group=None, **mixer_kwargs)
Compute the Hyena gated global convolution.
Implements:
\[y = \mathrm{OutputNorm}\!\bigl( \mathrm{GlobalConv}\!\bigl( \mathrm{Norm}(Q \odot \sigma(K)) \bigr) \odot \sigma_2(V) \bigr)\]
Tensors enter and leave in channels-last layout [B, *spatial, C]. Internally the module works channels-first [B, C, *spatial] for the short conv and global conv.
Context parallelism (cp_group)
When cp_group is provided and has size > 1, the method applies two AllToAll communications around the short conv so that each device sees the full spatial extent during the convolution:
1. Before short conv: split_to_full, which gathers spatial shards along dim=2 (the first spatial axis) and splits along dim=1 (channels).
2. After short conv: full_to_split, which scatters spatial and gathers channels back.
After step 1, each device holds the full spatial extent but only C / cp_size channels. After step 2, the original C channels are restored and each device holds spatial_0 / cp_size positions along the first spatial axis. The global conv receives only the local spatial slice and is expected to handle its own CP communication internally.
Implementation note
The query tensor is overwritten after the first gate to hold the gated intermediate z = Q ⊙ σ(K); the original Q tensor is no longer accessible after that point. This is intentional to avoid an extra allocation.
Parameters:
- query (Tensor) – [B, *spatial, C]: query tensor, typically the output of a linear projection W_Q · x.
- key (Tensor) – [B, *spatial, C]: key tensor, typically W_K · x.
- value (Tensor) – [B, *spatial, C]: value tensor, typically W_V · x.
- cp_group (ProcessGroup) – torch.distributed.ProcessGroup for context parallelism. None disables CP (the default for single-GPU runs).
- **mixer_kwargs – Extra keyword arguments forwarded verbatim to self.global_conv (e.g. conditioning for FiLM-conditioned CKConvND).
Returns:
[B, *spatial, C]: output tensor in channels-last layout, same shape as the inputs.
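The shape bookkeeping of the two AllToAll steps is easier to follow with a single-process simulation. In the sketch below a Python list stands in for the ranks of `cp_group`, and the two functions are stand-ins that mimic only the data movement described above. They reuse the names split_to_full and full_to_split for readability, but they are not the library's implementations and perform no `torch.distributed` communication.

```python
import torch


def split_to_full(shards):
    """shards[r]: [B, C, S0/cp, ...]  ->  out[j]: [B, C/cp, S0, ...]."""
    cp = len(shards)
    # Each rank cuts its local tensor into cp channel chunks (dim=1) ...
    chunks = [s.chunk(cp, dim=1) for s in shards]
    # ... and rank j collects chunk j from every rank along the first spatial axis.
    return [torch.cat([chunks[r][j] for r in range(cp)], dim=2) for j in range(cp)]


def full_to_split(shards):
    """shards[j]: [B, C/cp, S0, ...]  ->  out[r]: [B, C, S0/cp, ...]."""
    cp = len(shards)
    # Each rank cuts its tensor into cp spatial chunks (dim=2) ...
    chunks = [s.chunk(cp, dim=2) for s in shards]
    # ... and rank r collects spatial chunk r from every rank along channels.
    return [torch.cat([chunks[j][r] for j in range(cp)], dim=1) for r in range(cp)]


cp_size = 4
x = torch.randn(2, 8, 16, 16)              # global tensor [B, C, H, W]
local = list(x.chunk(cp_size, dim=2))      # each "rank" holds H / cp_size rows

full = split_to_full(local)                # step 1: before the short conv
print(full[0].shape)                       # torch.Size([2, 2, 16, 16])

restored = full_to_split(full)             # step 2: after the short conv
print(restored[0].shape)                   # torch.Size([2, 8, 4, 16])
```

Because a depthwise conv never mixes channels, holding C / cp_size channels at full spatial extent is enough for each rank to compute its share of the short conv exactly.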