Hyena#

class Hyena(
global_conv_cfg,
short_conv_cfg,
gate_nonlinear_cfg,
pixelhyena_norm_cfg,
qk_norm_cfg,
output_norm_cfg=LazyConfig(torch.nn.Identity)(),
gate_nonlinear_2_cfg=None,
)#

Bases: Module

Gated global convolutional mixer for ND signals.

The Hyena operator computes the following gated convolution (all tensors channels-first internally, channels-last on the public interface):

\[\begin{split}z &= Q \odot \sigma(K) \\ h &= \mathrm{GlobalConv}(z) \\ y &= h \odot \sigma_2(V)\end{split}\]

where \(\sigma\) is gate_nonlinear, \(\sigma_2\) is gate_nonlinear_2 (defaults to \(\sigma\)), and \(\mathrm{GlobalConv}\) is a depthwise FFT convolution whose kernel is generated on-the-fly by an implicit MLP (CKConvND).

Setting both gates to Identity gives a linear gating variant (element-wise products only). Setting \(\sigma = \mathrm{SiLU}\) and \(\sigma_2 = \mathrm{Sigmoid}\) matches the gated attention formulation used in the original Hyena paper.
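As a rough illustration of the three equations above, the following sketch implements them in plain PyTorch for a 2D signal. It is not the library's implementation: the explicit `kernel` tensor and the circular (periodic) FFT convolution are assumptions standing in for CKConvND, whose kernel generation and boundary handling are not described here. The names `depthwise_fft_conv2d` and `hyena_reference` are hypothetical.

```python
import torch


def depthwise_fft_conv2d(z: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Circular depthwise convolution via FFT (assumed stand-in for CKConvND).

    z:      [B, C, H, W] channels-first signal
    kernel: [C, H, W]    one full-size kernel per channel
    """
    H, W = z.shape[-2:]
    z_f = torch.fft.rfft2(z, s=(H, W))
    k_f = torch.fft.rfft2(kernel, s=(H, W))
    # Pointwise product in frequency space == circular convolution per channel.
    return torch.fft.irfft2(z_f * k_f, s=(H, W))


def hyena_reference(q, k, v, kernel, sigma, sigma_2=None):
    """z = Q * sigma(K); h = GlobalConv(z); y = h * sigma_2(V)."""
    sigma_2 = sigma if sigma_2 is None else sigma_2  # second gate defaults to the first
    z = q * sigma(k)
    h = depthwise_fft_conv2d(z, kernel)
    return h * sigma_2(v)


B, C, H, W = 2, 4, 8, 8
q, k, v = (torch.randn(B, C, H, W) for _ in range(3))

# A unit impulse at the origin makes the convolution an identity map,
# which gives an easy sanity check of the gating structure.
delta = torch.zeros(C, H, W)
delta[:, 0, 0] = 1.0

sigma = torch.nn.SiLU()
y = hyena_reference(q, k, v, delta, sigma)
```

With the impulse kernel the global convolution leaves `z` unchanged, so `y` reduces to `q * sigma(k) * sigma(v)`.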

Paper references#

The two-gate structure follows the H3 block (Fu et al., “Hungry Hungry Hippos”, ICLR 2023, arXiv:2212.14052, Section 3.2) and is generalised in Hyena (Poli et al., “Hyena Hierarchy: Towards Larger Convolutional Language Models”, ICML 2023, arXiv:2302.10866, Section 3 “The Hyena Recurrence”). The ND extension replaces the causal 1D FFT conv with a non-causal ND FFT conv (CKConvND).

Optional components (each disabled by passing Identity or None):
  • Short depthwise convolution on concatenated [Q, K, V]

  • QK normalisation (Q always; K only when \(\sigma = \mathrm{Identity}\))

  • PixelHyena normalisation between first gate and global conv

  • Output normalisation after second gate

  • Context parallelism via AllToAll communication (cp_group argument)

Example:

# Minimal 2D Hyena block (non-causal, no normalisation).
# In practice global_conv_cfg wraps a fully-configured CKConvND.
import torch
from nvsubquadratic.lazy_config import LazyConfig
from nvsubquadratic.modules.hyena_nd import Hyena

hyena = Hyena(
    global_conv_cfg=...,          # LazyConfig wrapping CKConvND
    # 192 = 3 * C: the depthwise conv runs on the concatenated [Q, K, V].
    short_conv_cfg=LazyConfig(torch.nn.Conv2d)(
        192, 192, 3, padding=1, groups=192
    ),
    gate_nonlinear_cfg=LazyConfig(torch.nn.SiLU)(),
    pixelhyena_norm_cfg=LazyConfig(torch.nn.Identity)(),
    qk_norm_cfg=None,
)
B, H, W, C = 2, 16, 16, 64
q = k = v = torch.randn(B, H, W, C)
y = hyena(q, k, v)  # [2, 16, 16, 64]
global_conv#

Long-range global convolution, typically CKConvND. Must expose hidden_dim and flop_count(spatial_dims, inference) for FLOP counting.

Type:

torch.nn.Module

short_conv#

Short depthwise convolution applied to the concatenated [Q, K, V] tensor (3·C input channels). Must be one of torch.nn.Conv{1,2,3}d, DistributedDepthwiseConv{1,2,3}d, or torch.nn.Identity.

Type:

torch.nn.Module

gate_nonlinear#

Activation \(\sigma\) for the first gate. Applied to K before multiplying with Q.

Type:

torch.nn.Module

gate_nonlinear_2#

Activation \(\sigma_2\) for the second gate. Applied to V before multiplying with h. This is the same object as gate_nonlinear when gate_nonlinear_2_cfg is None.

Type:

torch.nn.Module

pixelhyena_norm#

Normalisation layer applied to z = Q ⊙ σ(K) before the global conv. Parameters are excluded from weight-decay via _no_weight_decay = True.

Type:

torch.nn.Module

output_norm#

Normalisation layer applied to y = h ⊙ σ₂(V) before returning. Parameters are excluded from weight-decay.

Type:

torch.nn.Module
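How the `_no_weight_decay = True` flag on the pixelhyena_norm and output_norm parameters is consumed is not specified here. One plausible sketch, assuming the flag is an attribute set on each `torch.nn.Parameter` and that the training code builds optimizer parameter groups from it (the helper name `split_weight_decay_groups` is hypothetical):

```python
import torch


def split_weight_decay_groups(model: torch.nn.Module, weight_decay: float):
    """Build two optimizer groups based on a per-parameter `_no_weight_decay` flag."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # Parameters without the flag fall back to regular weight decay.
        if getattr(p, "_no_weight_decay", False):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Toy model: tag only the norm layer's parameters (assumed tagging scheme).
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
for p in model[1].parameters():
    p._no_weight_decay = True

groups = split_weight_decay_groups(model, weight_decay=0.1)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```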

q_norm#

Per-channel normalisation for Q. None when qk_norm_cfg is None.

Type:

torch.nn.Module | None

k_norm#

Per-channel normalisation for K. Takes one of three values: None when qk_norm_cfg is None (QK-norm entirely disabled); torch.nn.Identity when the gate is nonlinear (\(\sigma\) already bounds K’s magnitude); or a fresh instance of qk_norm_cfg when the gate is Identity (linear gating).

Type:

torch.nn.Module | None
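The selection rule for q_norm and k_norm can be sketched as follows. This is a reading of the attribute descriptions above, not the library's code; `make_norm` is a hypothetical factory standing in for instantiating qk_norm_cfg, and `select_qk_norms` is a hypothetical helper name.

```python
import torch


def select_qk_norms(make_norm, gate: torch.nn.Module):
    """Return (q_norm, k_norm) following the rule described for the attributes."""
    if make_norm is None:
        # QK-norm entirely disabled.
        return None, None
    q_norm = make_norm()
    if isinstance(gate, torch.nn.Identity):
        # Linear gating: K gets its own, independent norm instance.
        k_norm = make_norm()
    else:
        # Nonlinear gate: K is passed through unchanged.
        k_norm = torch.nn.Identity()
    return q_norm, k_norm


make_norm = lambda: torch.nn.LayerNorm(64)

q1, k1 = select_qk_norms(make_norm, torch.nn.SiLU())      # nonlinear gate
q2, k2 = select_qk_norms(make_norm, torch.nn.Identity())  # linear gating
q3, k3 = select_qk_norms(None, torch.nn.SiLU())           # QK-norm disabled
```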

__init__(
global_conv_cfg,
short_conv_cfg,
gate_nonlinear_cfg,
pixelhyena_norm_cfg,
qk_norm_cfg,
output_norm_cfg=LazyConfig(torch.nn.Identity)(),
gate_nonlinear_2_cfg=None,
)#

Construct a Hyena gated global convolutional mixer.

All *_cfg arguments are LazyConfig objects that are instantiated inside __init__ via nvsubquadratic.lazy_config.instantiate. This pattern allows full Python configurability without importing module classes at config-definition time.
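The deferred-construction idea can be illustrated with a toy stand-in. `ToyLazyConfig` and `toy_instantiate` below are hypothetical names written for this sketch; they only mimic the `LazyConfig(cls)(*args)` call shape shown in this document and say nothing about how nvsubquadratic.lazy_config is actually implemented.

```python
import torch


class ToyLazyConfig:
    """Records a target class and its constructor arguments without building it."""

    def __init__(self, target):
        self.target = target
        self.args = ()
        self.kwargs = {}

    def __call__(self, *args, **kwargs):
        # Store the arguments; nothing is constructed yet.
        self.args, self.kwargs = args, kwargs
        return self


def toy_instantiate(cfg: ToyLazyConfig):
    """Build a fresh object from the recorded class and arguments."""
    return cfg.target(*cfg.args, **cfg.kwargs)


cfg = ToyLazyConfig(torch.nn.Conv2d)(192, 192, 3, padding=1, groups=192)

# Each call produces an independent module with its own parameters.
conv_a = toy_instantiate(cfg)
conv_b = toy_instantiate(cfg)
```

Because construction happens only at instantiation time, one config can yield several independent modules, which is the property the separate Q and K norm instances rely on.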

Parameters:
  • global_conv_cfg (LazyConfig) – LazyConfig for the long-range global convolution (e.g. CKConvND). The instantiated module must expose hidden_dim: int and flop_count(spatial_dims, inference) -> int.

  • short_conv_cfg (LazyConfig) – LazyConfig for the short depthwise conv applied to the concatenated [Q; K; V] tensor (3·C input channels). Must instantiate to one of torch.nn.Conv{1,2,3}d, DistributedDepthwiseConv{1,2,3}d, or torch.nn.Identity. Use Identity to skip the short conv entirely.

  • gate_nonlinear_cfg (LazyConfig) – LazyConfig for the first-gate activation \(\sigma(K)\) (e.g. SiLU). Use Identity for linear gating.

  • pixelhyena_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied between the first gate and the global conv. Use Identity to disable. Parameters receive _no_weight_decay = True.

  • qk_norm_cfg (LazyConfig | None) – LazyConfig for per-channel normalisation of Q (and K when the gate is Identity). Pass None to disable QK-norm entirely. When K is also normalised, Q and K each get their own instance, so that stateful norms (e.g. RMSNorm with a learnable scale) keep independent parameters.

  • output_norm_cfg (LazyConfig) – LazyConfig for the normalisation applied after the second gate. Defaults to a LazyConfig wrapping torch.nn.Identity (no normalisation). Do not pass an already-instantiated module — pass a LazyConfig object that wraps the class. Parameters receive _no_weight_decay = True.

  • gate_nonlinear_2_cfg (LazyConfig | None) – LazyConfig for the second-gate activation \(\sigma_2(V)\). If None (default), both gates share the same activation object (self.gate_nonlinear).

Raises:

AssertionError – If the instantiated short_conv is not one of the supported Conv / DistributedDepthwiseConv / Identity types.

extra_repr()#

Return a compact summary of key configuration choices.

Included fields:
  • q_norm / k_norm class names (or "None"). When QK-norm is disabled both are None; the strings "q_norm=None" and "k_norm=None" are still emitted so the disabled state is explicit in repr(module).

  • gates=<σ>/<σ₂> when the two gate activations differ.

  • is_causal when the global conv exposes that attribute.

Returns:

Comma-separated string suitable for repr(module) output.

Return type:

str

flop_count(spatial_dims, inference=False)#

Count FLOPs for one forward pass of the Hyena mixer.

Let C = self.global_conv.hidden_dim (the per-head channel count) and S = prod(spatial_dims) (total number of spatial positions). All counts use the multiply-add = 1 FLOP convention (i.e. a MAC counts as 1).

FLOP breakdown:

  1. Short depthwise conv on concatenated [Q; K; V] (3·C input channels):

    \[2 \cdot \frac{in\_ch}{groups} \cdot out\_ch \cdot S \cdot k\_prod\]

    where \(k\_prod = \prod_d kernel\_size_d\). Skipped when short_conv is Identity. For a pure depthwise conv (groups == in_ch == out_ch) this simplifies to 2 · out_ch · S · k_prod; the grouped formula is written here to handle partially-grouped convolutions (e.g. DistributedDepthwiseConvNd).

  2. QK-Norm (when self.q_norm is not None): 3·C·S for Q; additional 3·C·S for K only when gate_nonlinear is Identity (linear gating). The factor of 3 assumes an RMSNorm-like norm (sum-of-squares + rsqrt + elementwise scale). Other norm types will differ; this is an approximation.

  3. First gate \(z = Q \odot \sigma(K)\): C·S for the elementwise multiply, plus C·S for the activation \(\sigma\) when it is not Identity.

  4. PixelHyena norm (when not Identity): 3·C·S.

  5. Global convolution: delegated to self.global_conv.flop_count(spatial_dims, inference).

  6. Second gate \(y = h \odot \sigma_2(V)\): C·S for the multiply, plus C·S for \(\sigma_2\) when not Identity.

  7. Output norm (when not Identity): 3·C·S.
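The breakdown above can be turned into a small calculator. The sketch below follows the listed terms literally; the function name and its flags are hypothetical, and the global-convolution term is passed in as a number because it is delegated to self.global_conv.flop_count.

```python
import math


def hyena_flops_sketch(
    C,
    spatial_dims,
    *,
    short_conv=None,            # (in_ch, out_ch, groups, kernel_size) or None for Identity
    qk_norm=False,              # True when q_norm is not None
    gate_is_identity=False,     # sigma is Identity (linear gating)
    gate_2_is_identity=False,   # sigma_2 is Identity
    pixelhyena_norm=False,      # True when the norm is not Identity
    output_norm=False,          # True when the norm is not Identity
    global_conv_flops=0,        # result of global_conv.flop_count(...)
):
    S = math.prod(spatial_dims)
    total = 0
    # 1. Short depthwise conv on the concatenated [Q; K; V].
    if short_conv is not None:
        in_ch, out_ch, groups, kernel_size = short_conv
        total += 2 * (in_ch // groups) * out_ch * S * math.prod(kernel_size)
    # 2. QK-norm: Q always, K only under linear gating.
    if qk_norm:
        total += 3 * C * S
        if gate_is_identity:
            total += 3 * C * S
    # 3. First gate: multiply, plus activation when not Identity.
    total += C * S + (0 if gate_is_identity else C * S)
    # 4. PixelHyena norm.
    if pixelhyena_norm:
        total += 3 * C * S
    # 5. Global convolution (delegated).
    total += global_conv_flops
    # 6. Second gate: multiply, plus activation when not Identity.
    total += C * S + (0 if gate_2_is_identity else C * S)
    # 7. Output norm.
    if output_norm:
        total += 3 * C * S
    return total


# C = 64 on a 16 x 16 grid, 3 x 3 depthwise short conv over 3 * C = 192 channels,
# nonlinear gates, no norms, global-conv term left at 0.
flops = hyena_flops_sketch(64, (16, 16), short_conv=(192, 192, 192, (3, 3)))
```

In this configuration the short conv contributes 2 · 1 · 192 · 256 · 9 = 884,736 and the two gates contribute 4 · 64 · 256 = 65,536, so `flops` is 950,272 before the global-conv term is added.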

Parameters:
  • spatial_dims (tuple[int, ...]) – Spatial extent of the input per axis, e.g. (H, W) for a 2-D input of shape [B, H, W, C] (channels-last, as accepted by forward).

  • inference (bool) – Forwarded to self.global_conv.flop_count; some implementations skip re-generating the kernel at inference time when it is cached.

Returns:

Total FLOP count as an integer (multiply-add = 1 FLOP convention).

Return type:

int

forward(query, key, value, cp_group=None, **mixer_kwargs)#

Compute the Hyena gated global convolution.

Implements:

\[y = \mathrm{OutputNorm}\!\bigl( \mathrm{GlobalConv}\!\bigl( \mathrm{Norm}(Q \odot \sigma(K)) \bigr) \odot \sigma_2(V) \bigr)\]

Tensors enter and leave in channels-last layout [B, *spatial, C]. Internally the module works channels-first [B, C, *spatial] for the short conv and global conv.
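The layout handling can be sketched as follows. This is an illustrative reading of the paragraph above, not the module's code: it assumes `movedim` for the layout change and a plain torch.nn.Conv2d as the short depthwise conv over the concatenated 3·C channels. The gating and global conv steps are omitted.

```python
import torch

B, H, W, C = 2, 16, 16, 64
q, k, v = (torch.randn(B, H, W, C) for _ in range(3))

# Channels-last [B, H, W, C] -> concatenated channels-first [B, 3C, H, W].
qkv = torch.cat([q, k, v], dim=-1).movedim(-1, 1)

# Depthwise short conv: one 3 x 3 filter per channel (groups == channels).
short_conv = torch.nn.Conv2d(3 * C, 3 * C, 3, padding=1, groups=3 * C)
qkv = short_conv(qkv)

# Split back into Q, K, V, each [B, C, H, W].
q_cf, k_cf, v_cf = qkv.chunk(3, dim=1)

# ... gating and global conv would operate on the channels-first tensors here ...

# Channels-first -> channels-last for the public interface.
out = q_cf.movedim(1, -1)
```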

Context parallelism (cp_group)#

When cp_group is provided and has size > 1, the method applies two AllToAll communications around the short conv so that each device sees the full spatial extent during the convolution:

  1. Before short conv: split_to_full — gather spatial shards along dim=2 (the first spatial axis), split along dim=1 (channels).

  2. After short conv: full_to_split — scatter spatial, gather channels back.

After step 1, each device holds the full spatial extent but only C / cp_size channels. After step 2, the original C channels are restored and each device holds spatial_0 / cp_size positions along the first spatial axis. The global conv receives only the local spatial slice and is expected to handle its own CP communication internally.
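The effect of the two layout changes can be checked in a single process without any communication. The sketch below simulates cp_size ranks as a Python list; the helper names mirror split_to_full and full_to_split from the text, but their bodies are assumptions that reproduce only the resulting data layout, not the library's AllToAll implementation.

```python
import torch


def split_to_full_sim(shards):
    """Per-rank [B, C, S0/cp, W] -> per-rank [B, C/cp, S0, W]."""
    cp = len(shards)
    full = torch.cat(shards, dim=2)      # gather spatial shards along dim=2
    return list(full.chunk(cp, dim=1))   # split channels along dim=1


def full_to_split_sim(shards):
    """Per-rank [B, C/cp, S0, W] -> per-rank [B, C, S0/cp, W]."""
    cp = len(shards)
    full = torch.cat(shards, dim=1)      # gather channels back
    return list(full.chunk(cp, dim=2))   # scatter the first spatial axis


torch.manual_seed(0)
cp, B, C3, H, W = 2, 1, 12, 8, 8         # C3 plays the role of 3 * C
x = torch.randn(B, C3, H, W)
conv = torch.nn.Conv2d(C3, C3, 3, padding=1, groups=C3)

local = list(x.chunk(cp, dim=2))         # what each simulated rank starts with
gathered = split_to_full_sim(local)      # full spatial extent, C3 / cp channels

# Each rank applies only the depthwise filters for its own channel slice.
ch = C3 // cp
outs = []
for r, t in enumerate(gathered):
    w = conv.weight[r * ch:(r + 1) * ch]
    b = conv.bias[r * ch:(r + 1) * ch]
    outs.append(torch.nn.functional.conv2d(t, w, b, padding=1, groups=ch))

result = full_to_split_sim(outs)
reference = list(conv(x).chunk(cp, dim=2))  # unsharded conv, then sharded
```

Because the conv is depthwise, splitting along channels loses nothing: each simulated rank's `result` matches the corresponding spatial slice of the unsharded convolution, including rows at the shard boundary.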

Implementation note#

The query tensor is overwritten after the first gate to hold the gated intermediate z = Q ⊙ σ(K); the original Q tensor is no longer accessible after that point. This is intentional to avoid an extra allocation.

Parameters:
  • query – [B, *spatial, C]. Query tensor, typically the output of a linear projection W_Q · x.

  • key – [B, *spatial, C]. Key tensor, typically W_K · x.

  • value – [B, *spatial, C]. Value tensor, typically W_V · x.

  • cp_group – torch.distributed.ProcessGroup for context parallelism. None disables CP (the default for single-GPU runs).

  • **mixer_kwargs – Extra keyword arguments forwarded verbatim to self.global_conv (e.g. conditioning for FiLM-conditioned CKConvND).

Returns:

[B, *spatial, C]. Output tensor in channels-last layout, same shape as the inputs.

Return type:

Tensor