Modules

High-level torch.nn.Module building blocks that compose the ops above into sequence and spatial mixers, plus the kernels, norms, gates, and residual blocks they rely on.

Mixers

Sequence/spatial mixers — Hyena, Mamba, attention variants.

Hyena(global_conv_cfg, short_conv_cfg, ...)

Gated global convolutional mixer for ND signals.
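The core of a gated global convolution can be sketched in 1D. An FFT computes a full-length causal convolution in O(L log L), and the result is multiplied elementwise by a gate. This is a minimal sketch that assumes a per-channel kernel `k` as long as the input and a gate computed elsewhere. `fft_causal_conv` and `gated_global_conv` are hypothetical names, not this library's API. The ND mixer presumably differs in its projections, kernel generation and padding.

```python
import torch


def fft_causal_conv(u: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal long convolution of u [B, C, L] with a per-channel kernel k [C, L]."""
    L = u.shape[-1]
    n = 2 * L  # zero-pad to 2L so the FFT's circular convolution equals linear convolution
    y = torch.fft.irfft(torch.fft.rfft(u, n=n) * torch.fft.rfft(k, n=n), n=n)
    return y[..., :L]  # y[t] = sum_{s <= t} k[s] * u[t - s]


def gated_global_conv(v: torch.Tensor, gate: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """One gated step (sketch): elementwise gate times the global convolution of v."""
    return gate * fft_causal_conv(v, k)
```

Padding to 2L is what makes the output match a direct causal convolution. Without it, the FFT product would wrap the tail of the signal around to the start.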

Mamba(mamba_layer_cfg[, bidirectional])

Selective state-space mixer for ND signals.

Attention(hidden_dim, num_heads, ...[, ...])

Multi-head scaled dot-product self-attention for 1D/2D/3D spatial inputs.
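Self-attention over a 1D, 2D or 3D grid reduces to ordinary sequence attention once the spatial axes are flattened into tokens. The sketch below assumes channels-last input `[B, *spatial, C]`, a fused QKV projection, and no masking, dropout or positional encoding. `SpatialSelfAttention` is a hypothetical class, not this library's `Attention`. It relies on PyTorch's `F.scaled_dot_product_attention` (PyTorch 2.0 and later).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSelfAttention(nn.Module):
    """Hypothetical sketch: multi-head self-attention over a channels-last ND grid."""

    def __init__(self, hidden_dim: int, num_heads: int):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must divide evenly into heads"
        self.num_heads = num_heads
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, *spatial, C = x.shape
        tokens = x.reshape(B, -1, C)  # flatten the 1D/2D/3D grid into T tokens
        T = tokens.shape[1]
        qkv = self.qkv(tokens).reshape(B, T, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: [B, heads, T, head_dim]
        y = F.scaled_dot_product_attention(q, k, v)
        y = y.transpose(1, 2).reshape(B, T, C)  # merge heads back into channels
        return self.out(y).reshape(B, *spatial, C)  # restore the spatial layout
```

Because every token attends to every other, cost grows quadratically with the number of grid points, which is the trade-off the subquadratic mixers in this section avoid.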

ViT5Attention(hidden_dim, num_heads, ...[, ...])

ViT-5 multi-head self-attention with RMSNorm QK-Norm and register-aware RoPE.

ViT5HyenaAdapter(inner_mixer_cfg, grid_w)

Bridges ViT-5's [B, T, C] token sequences and Hyena's [B, H, W, C] spatial interface.

QKVSequenceMixer(hidden_dim, mixer_cfg[, ...])

Operator-agnostic sequence mixer with shared QKV and output projections.

Convolutions

Depthwise, multi-head, and continuous-kernel convolutions plus their context-parallel counterparts.

CausalConv1D(*args[, is_causal])

1D convolution with configurable causal (left-only) or symmetric padding.
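The two padding modes differ only in where the `dilation * (kernel_size - 1)` padding samples go. Causal padding puts them all on the left, so output `t` depends on inputs up to `t`. Symmetric padding splits them so the output stays centred. `PaddedConv1d` is a hypothetical sketch that assumes stride 1; it is not the library's `CausalConv1D`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PaddedConv1d(nn.Conv1d):
    """Hypothetical sketch: Conv1d with left-only (causal) or symmetric padding."""

    def __init__(self, *args, is_causal: bool = True, **kwargs):
        kwargs["padding"] = 0  # padding is applied manually in forward()
        super().__init__(*args, **kwargs)
        self.is_causal = is_causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        total = self.dilation[0] * (self.kernel_size[0] - 1)  # samples needed to keep length
        if self.is_causal:
            pad = (total, 0)  # left-only: y[t] sees x[..., : t + 1]
        else:
            pad = (total // 2, total - total // 2)  # symmetric ("same") padding
        return super().forward(F.pad(x, pad))
```

With stride 1 both modes preserve sequence length; only the causal mode guarantees that perturbing a future input leaves earlier outputs unchanged.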

SubqOpsCausalConv1d(in_channels, ...[, ...])

Depthwise causal 1D conv using subquadratic_ops_torch.causal_conv1d.

CKConvND(data_dim, hidden_dim, kernel_cfg, ...)

N-dimensional Continuous Kernel Convolution (CKConv) operator.

DistributedDepthwiseConv1d(hidden_dim, ...)

1-D depthwise convolution with CP-aware channel slicing and weight sharing.

DistributedDepthwiseConv2d(hidden_dim, ...)

2-D depthwise convolution with CP-aware channel slicing and weight sharing.

DistributedDepthwiseConv3d(hidden_dim, ...)

3-D depthwise convolution with CP-aware channel slicing and weight sharing.

Kernels & filters

Learned kernel parametrisations (SIREN, random Fourier features) and masks that produce the filters consumed by the FFT ops.

SIRENKernelND(out_dim, data_dim, ...[, ...])

Convolutional kernel parametrised by a SIREN (sinusoidal representation network) MLP.
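A SIREN kernel maps each kernel coordinate to filter values through an MLP with sine activations, so the kernel size is set by the coordinate grid rather than by a parameter count. The sketch below assumes coordinates in [-1, 1] per axis, a channels-last output `[*sizes, out_dim]`, and the initialisation from Sitzmann et al. (2020). `SineActivation`, `coordinate_grid` and `TinySIRENKernel` are hypothetical names, not this library's classes.

```python
import math

import torch
import torch.nn as nn


class SineActivation(nn.Module):
    """sin(omega_0 * x), the SIREN nonlinearity."""

    def __init__(self, omega_0: float = 30.0):
        super().__init__()
        self.omega_0 = omega_0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.omega_0 * x)


def coordinate_grid(sizes) -> torch.Tensor:
    """Coordinates in [-1, 1]^D for a kernel of spatial shape `sizes`; shape [*sizes, D]."""
    axes = [torch.linspace(-1.0, 1.0, n) for n in sizes]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)


class TinySIRENKernel(nn.Module):
    """Hypothetical sketch: coordinates -> SIREN MLP -> kernel values."""

    def __init__(self, out_dim: int, data_dim: int, hidden: int = 32, omega_0: float = 30.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim, hidden), SineActivation(omega_0),
            nn.Linear(hidden, hidden), SineActivation(omega_0),
            nn.Linear(hidden, out_dim),
        )
        with torch.no_grad():
            # First layer: U(-1/in, 1/in); later layers: U(-sqrt(6/in)/omega_0, +sqrt(6/in)/omega_0)
            self.net[0].weight.uniform_(-1.0 / data_dim, 1.0 / data_dim)
            for layer in (self.net[2], self.net[4]):
                bound = math.sqrt(6.0 / layer.in_features) / omega_0
                layer.weight.uniform_(-bound, bound)

    def forward(self, sizes) -> torch.Tensor:
        return self.net(coordinate_grid(sizes))  # [*sizes, out_dim]
```

The same module produces a 5×7 or a 31×31 kernel without new parameters; ω₀ controls how much high-frequency detail the first layer can express.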

SIRENPositionalEmbeddingND(data_dim, ...[, ...])

N-dimensional positional embedding using a SIREN first layer.

MultiOmegaSIRENKernelND(out_dim, data_dim, ...)

SIRENKernelND with a per-row ω₀ in the first (positional-embedding) layer.
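A per-row ω₀ gives each output unit of the first layer its own frequency scale: unit `i` computes `sin(ω₀[i] * (W x + b)[i])`. The sketch below stores fixed ω₀ values as a buffer. An optional learnable per-row multiplier, initialised to one, approximates the learnable-ω₀ variants in this list. `MultiOmegaSineLayer` and its `learnable_scale` flag are hypothetical and are not this library's API.

```python
import torch
import torch.nn as nn


class MultiOmegaSineLayer(nn.Module):
    """Hypothetical sketch: first SIREN layer with one omega_0 per output row."""

    def __init__(self, in_features: int, out_features: int, omegas, learnable_scale: bool = False):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        omegas = torch.as_tensor(omegas, dtype=torch.float32)
        assert omegas.shape == (out_features,), "need one omega_0 per output row"
        self.register_buffer("omega_0", omegas)  # fixed frequencies, saved with the module
        # Optional trainable multiplier on omega_0 (starts at 1, so behaviour is unchanged at init)
        self.scale = nn.Parameter(torch.ones(out_features)) if learnable_scale else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        omega = self.omega_0 if self.scale is None else self.omega_0 * self.scale
        return torch.sin(omega * self.linear(x))  # broadcasts over the last (row) dim
```

Spreading ω₀ across rows lets one layer cover low and high frequencies at once, instead of committing every unit to a single bandwidth.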

MultiOmegaSIRENPositionalEmbeddingND(...[, ...])

SIREN positional embedding with a per-row ω₀ in the first layer.

BlockDiagonalMultiOmegaSIRENKernelND(...[, ...])

SIREN kernel with a per-block ω₀ and a (near-)block-diagonal MLP initialisation.

LearnableOmegaSIRENKernelND(out_dim, ...[, ...])

SIRENKernelND whose first-layer ω₀ is multiplied by a learnable per-row scale.

LearnableOmegaSIRENPositionalEmbeddingND(...)

SIREN positional embedding with a learnable per-row ω₀ multiplier.

BlockDiagonalLearnableOmegaSIRENKernelND(...)

Block-diagonal learnable-ω₀ SIREN kernel.

RandomFourierKernelND(out_dim, data_dim, ...)

Learned convolutional kernel parametrised via Random Fourier Features and an MLP.

RandomFourierPositionalEmbeddingND(data_dim, ...)

N-dimensional positional embedding using Random Fourier Features (RFF).
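Random Fourier Features lift a coordinate `v` to `[cos(2π B v), sin(2π B v)]`, where the rows of `B` are drawn from a Gaussian with standard deviation σ (Tancik et al., 2020). The sketch below assumes a fixed, untrained `B` and the cos/sin concatenation order shown. `RFFEmbedding` and `sigma` are hypothetical names, not this library's parameters.

```python
import math

import torch
import torch.nn as nn


class RFFEmbedding(nn.Module):
    """Hypothetical sketch: coordinates [..., data_dim] -> features [..., 2 * num_features]."""

    def __init__(self, data_dim: int, num_features: int, sigma: float = 10.0):
        super().__init__()
        # Fixed Gaussian projection; a buffer, so it is saved with the module but not trained
        self.register_buffer("B", torch.randn(num_features, data_dim) * sigma)

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        proj = 2 * math.pi * coords @ self.B.T  # [..., num_features]
        return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)
```

σ sets the bandwidth: a larger σ yields higher-frequency features and therefore sharper kernels downstream, at the risk of noisier ones.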

Sine(*args, **kwargs)

Sine activation function used in SIREN networks.

ExponentialModulationND(data_dim, num_channels)

Fixed exponential-decay spatial window applied to implicit convolutional kernels.

GaussianModulationND(data_dim, num_channels, ...)

Learnable Gaussian-window spatial mask for ND convolutional kernels.
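Both modulations multiply a generated kernel by a window that shrinks away from the kernel centre. The sketch below is one plausible ND reading of these entries, not their actual formulas. The exponential window uses fixed per-channel decay rates applied to the Euclidean distance from the centre. The Gaussian window learns a per-channel σ, parametrised through its logarithm to keep it positive. `ExpDecayWindowND`, `GaussianWindowND`, the rate range and `init_sigma` are all hypothetical.

```python
import math

import torch
import torch.nn as nn


class ExpDecayWindowND(nn.Module):
    """Hypothetical sketch: fixed window exp(-rate_c * ||x||) per channel c."""

    def __init__(self, num_channels: int, min_rate: float = 0.5, max_rate: float = 5.0):
        super().__init__()
        # Non-learnable decay rates, fastest (channel 0) to slowest (last channel)
        self.register_buffer("rates", torch.linspace(max_rate, min_rate, num_channels))

    def forward(self, kernel: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        # kernel: [*spatial, C]; coords: [*spatial, D] with the centre at the origin
        dist = coords.norm(dim=-1, keepdim=True)  # [*spatial, 1]
        return kernel * torch.exp(-self.rates * dist)


class GaussianWindowND(nn.Module):
    """Hypothetical sketch: learnable window exp(-||x||^2 / (2 * sigma_c^2))."""

    def __init__(self, num_channels: int, init_sigma: float = 0.5):
        super().__init__()
        self.log_sigma = nn.Parameter(torch.full((num_channels,), math.log(init_sigma)))

    def forward(self, kernel: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
        sq_dist = coords.pow(2).sum(dim=-1, keepdim=True)  # [*spatial, 1]
        sigma = self.log_sigma.exp()  # always positive
        return kernel * torch.exp(-sq_dist / (2 * sigma**2))
```

Either window biases the implicit kernel towards local structure at initialisation; the learnable σ lets each channel widen its receptive field during training.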

BlockAlignedGaussianModulationND(data_dim, ...)

Gaussian modulation with channel-reversed std ordering for block-structured SIRENs.

Normalization

RMSNorm(dim[, eps, use_quack])

Root Mean Square Layer Normalization (Zhang & Sennrich, arXiv:1910.07467).
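RMSNorm rescales each feature vector by its root mean square and applies a learned per-channel gain. Unlike LayerNorm, it does not subtract the mean. The sketch below normalises over the last (channel) dimension and places `eps` inside the square root, which is one common convention. `RMSNormSketch` is a hypothetical name and has no `use_quack` path.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Hypothetical sketch: y = x / sqrt(mean(x^2) + eps) * weight, over the last dim."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # per-channel gain g

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```

For a channel-first tensor such as `[B, C, H, W]`, the same reduction would run over `dim=1` with the weight reshaped to `[C, 1, 1]`. That is presumably the difference in `RMSNormChannelFirst`.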

PerHeadRMSNorm(num_heads, head_dim[, eps, ...])

RMSNorm applied independently to each attention head (QK-norm).
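QK-norm applies RMSNorm to queries and keys inside each head, so each head's vectors have unit RMS before the dot product. This bounds attention-logit growth. The sketch below assumes the `[B, num_heads, T, head_dim]` layout and a separate gain per head. `PerHeadRMSNormSketch` is hypothetical, and the library may share gains across heads.

```python
import torch
import torch.nn as nn


class PerHeadRMSNormSketch(nn.Module):
    """Hypothetical sketch: RMSNorm over head_dim, with a gain vector per head."""

    def __init__(self, num_heads: int, head_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_heads, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [B, num_heads, T, head_dim]
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight[:, None, :]  # [H, 1, head_dim] broadcasts over B and T
```

Normalising per head, rather than over the full hidden dimension, prevents one large head from shrinking the others.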

RMSNormChannelFirst(dim[, eps, use_quack])

Root Mean Square Layer Normalization — channel-first layout (arXiv:1910.07467).

GlobalResponseNorm(dim[, eps])

Global Response Normalisation (GRN) layer (Woo et al., arXiv:2301.00808).
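GRN, as described in ConvNeXt V2, aggregates each channel's L2 norm over all spatial positions. It divides that norm by its mean across channels and uses the result to recalibrate the features, with a residual connection. The sketch below follows the paper's channels-last formulation and generalises the spatial reduction to any number of spatial axes. `GRNSketch` is a hypothetical name. `gamma` and `beta` start at zero, so the layer is the identity at initialisation.

```python
import torch
import torch.nn as nn


class GRNSketch(nn.Module):
    """Hypothetical sketch of Global Response Normalisation for x: [B, *spatial, C]."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial_dims = tuple(range(1, x.dim() - 1))
        gx = torch.linalg.vector_norm(x, ord=2, dim=spatial_dims, keepdim=True)  # per-channel L2
        nx = gx / (gx.mean(dim=-1, keepdim=True) + self.eps)  # relative channel importance
        return self.gamma * (x * nx) + self.beta + x
```

The division by the cross-channel mean is what makes the layer promote competition between channels rather than simply rescale them.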

LayerScale(dim[, init_value])

Learnable per-channel scalar gate for residual branch outputs.

Position encoding & patching

PositionEmbeddingND(embedding_dim, data_dim, ...)

Axis-factorised learned positional encoding for ND spatial token grids.
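An axis-factorised encoding learns one table per spatial axis and sums the broadcast tables. A grid of shape `(n_1, ..., n_D)` then needs `(n_1 + ... + n_D) * C` parameters instead of `n_1 * ... * n_D * C`. The sketch below assumes additive combination and a fixed grid size. `FactorisedPosEmb` is a hypothetical name, not this library's `PositionEmbeddingND`.

```python
import torch
import torch.nn as nn


class FactorisedPosEmb(nn.Module):
    """Hypothetical sketch: sum of per-axis learned embeddings added to x: [B, *grid, C]."""

    def __init__(self, embedding_dim: int, grid_size):
        super().__init__()
        self.tables = nn.ParameterList(
            [nn.Parameter(torch.randn(n, embedding_dim) * 0.02) for n in grid_size]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        D = len(self.tables)
        pos = 0
        for axis, table in enumerate(self.tables):
            shape = [1] * D + [table.shape[-1]]
            shape[axis] = table.shape[0]  # e.g. axis 0 of a 2D grid -> [n_0, 1, C]
            pos = pos + table.view(shape)  # broadcast-add into a [*grid, C] tensor
        return x + pos
```

The trade-off is expressiveness. The embedding at `(i, j)` is always `E_0[i] + E_1[j]`, so it cannot represent arbitrary per-cell offsets.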

Patchify(in_features, out_features, ...[, ...])

Conv-based patch embedding for ND spatial signals (channels-last I/O).

Unpatchify(in_features, out_features, ...[, ...])

Inverse patch-embedding layer: reconstruct spatial signal from token grid.
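A convolution whose kernel size equals its stride cuts the input into non-overlapping patches and projects each one to a token. A transposed convolution with the same settings maps tokens back to pixels. The 2D sketch below assumes this conv/transposed-conv pairing and channels-last tensors at both ends. `Patchify2d` and `Unpatchify2d` are hypothetical, and the library's ND layers may reconstruct differently.

```python
import torch
import torch.nn as nn


class Patchify2d(nn.Module):
    """Hypothetical sketch: [B, H, W, C_in] -> [B, H/p, W/p, C_out]."""

    def __init__(self, in_features: int, out_features: int, patch_size: int):
        super().__init__()
        self.proj = nn.Conv2d(in_features, out_features, patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv2d expects channels-first, so permute in and out
        return self.proj(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)


class Unpatchify2d(nn.Module):
    """Hypothetical sketch: [B, H/p, W/p, C_in] -> [B, H, W, C_out]."""

    def __init__(self, in_features: int, out_features: int, patch_size: int):
        super().__init__()
        self.proj = nn.ConvTranspose2d(in_features, out_features, patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
```

With patch size 4, a 32×32 image becomes an 8×8 token grid, and the transposed convolution restores the 32×32 resolution.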

MLP(dim, activation, dropout_cfg[, ...])

Point-wise two-layer MLP — the channel-mixing branch of each residual block.

Gating & conditioning

Drop-path, FiLM-style conditioning, and the QKV conditioning mixer that feeds Hyena’s per-sample kernels.

DropPath([drop_prob])

Drop paths (stochastic depth) per sample — nn.Module wrapper.

QKVConditionMixer(hidden_dim, mixer_cfg[, ...])

Cross-attention condition mixer that routes a conditioning signal into the feature map.

KernelFiLMGenerator(cond_dim, ...[, ...])

MLP that generates per-layer FiLM (γ, β) pairs from a conditioning vector.
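A FiLM generator turns one conditioning vector into a (γ, β) pair for every conditioned layer, and each layer then applies a per-channel affine transform. The sketch below assumes a two-layer MLP, a zero-initialised output layer, and the `h * (1 + γ) + β` form, which makes the whole scheme the identity at initialisation. `FiLMGenerator` and `apply_film` are hypothetical names, not this library's API.

```python
import torch
import torch.nn as nn


class FiLMGenerator(nn.Module):
    """Hypothetical sketch: cond [B, cond_dim] -> gamma, beta [B, num_layers, channels]."""

    def __init__(self, cond_dim: int, num_layers: int, channels: int, hidden: int = 64):
        super().__init__()
        self.num_layers, self.channels = num_layers, channels
        self.net = nn.Sequential(
            nn.Linear(cond_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, num_layers * 2 * channels),
        )
        nn.init.zeros_(self.net[-1].weight)  # start with gamma = beta = 0
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, cond: torch.Tensor):
        out = self.net(cond).view(-1, self.num_layers, 2, self.channels)
        return out[:, :, 0], out[:, :, 1]  # gamma, beta


def apply_film(h: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """h: [B, *spatial, C]; gamma/beta: [B, C] for one layer. (1 + gamma) keeps init at identity."""
    shape = (h.shape[0],) + (1,) * (h.dim() - 2) + (h.shape[-1],)
    return h * (1 + gamma.reshape(shape)) + beta.reshape(shape)
```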

RegisterPooling(num_registers)

Learnable softmax-weighted average over register tokens.

RegisterCompressConcat(num_registers, ...)

Compress each register token and concatenate into a single conditioning vector.
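The two register summaries trade size for detail. Pooling learns softmax weights over the R register tokens and returns a single `C`-dimensional vector. Compress-and-concatenate keeps every register, each projected to a smaller width, in an `R * c` vector. The sketch below assumes `[B, R, C]` register tensors and, for compression, one linear layer shared by all registers. Both class names are hypothetical.

```python
import torch
import torch.nn as nn


class RegisterPoolingSketch(nn.Module):
    """Hypothetical sketch: learnable softmax-weighted average over register tokens."""

    def __init__(self, num_registers: int):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_registers))  # uniform weights at init

    def forward(self, registers: torch.Tensor) -> torch.Tensor:  # [B, R, C] -> [B, C]
        w = torch.softmax(self.logits, dim=0)  # [R], sums to 1
        return torch.einsum("r,brc->bc", w, registers)


class RegisterCompressConcatSketch(nn.Module):
    """Hypothetical sketch: compress each register, then flatten into one vector."""

    def __init__(self, dim: int, compressed_dim: int):
        super().__init__()
        self.compress = nn.Linear(dim, compressed_dim)  # assumed shared across registers

    def forward(self, registers: torch.Tensor) -> torch.Tensor:  # [B, R, C] -> [B, R * c]
        return self.compress(registers).flatten(1)
```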

drop_path(x, drop_prob, training)

Apply per-sample stochastic depth (functional form).
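Per-sample stochastic depth zeroes an entire residual branch for a random subset of the batch during training. It rescales the surviving samples by `1 / keep_prob` so the expected output is unchanged. The sketch below follows the widely used timm formulation and may differ in details from this library's `drop_path`.

```python
import torch


def drop_path_sketch(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Zero whole samples of x with probability drop_prob (training only), rescaling survivors."""
    if drop_prob == 0.0 or not training:
        return x  # identity at inference time
    keep_prob = 1.0 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # one Bernoulli draw per sample
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    return x * mask / keep_prob
```

It is applied to the branch output before the residual add, as in `x = x + drop_path_sketch(branch(x), p, self.training)`, so a dropped sample simply skips that block.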

Residual blocks

ResidualBlock(sequence_mixer_cfg, ...)

Standard pre-norm residual block for ND signals.

AdaLNZeroResidualBlock(sequence_mixer_cfg, ...)

Pre-norm residual block with AdaLN-Zero conditioning (DiT-style).
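In the DiT formulation of AdaLN-Zero, a conditioning vector is mapped to a shift, scale and gate for each of the two branches. The final projection is zero-initialised, so every gate starts at zero and the block begins as the identity. The sketch below follows that recipe on `[B, T, C]` tokens, with affine-free LayerNorm. `AdaLNZeroBlock` and `modulate` are hypothetical names, and the library's ND block may use a different norm or layout.

```python
import torch
import torch.nn as nn


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x * (1 + scale) + shift


class AdaLNZeroBlock(nn.Module):
    """Hypothetical sketch of a DiT-style AdaLN-Zero residual block."""

    def __init__(self, dim: int, cond_dim: int, mixer: nn.Module, mlp: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mixer, self.mlp = mixer, mlp
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 6 * dim))
        nn.init.zeros_(self.ada[-1].weight)  # "Zero": all shifts, scales and gates start at 0
        nn.init.zeros_(self.ada[-1].bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: [B, T, C]; c: [B, cond_dim] -> six [B, 1, C] modulation tensors
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(c).unsqueeze(1).chunk(6, dim=-1)
        x = x + gate1 * self.mixer(modulate(self.norm1(x), shift1, scale1))
        x = x + gate2 * self.mlp(modulate(self.norm2(x), shift2, scale2))
        return x
```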

ViT5ResidualBlock(sequence_mixer_cfg, ...[, ...])

ViT-5 style residual block with LayerScale and stochastic depth.
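LayerScale multiplies each residual branch by a learnable per-channel vector initialised to a small constant, as in CaiT. Deep stacks therefore start close to the identity. The sketch below is a generic pre-norm block with LayerScale on both branches. It assumes LayerNorm and omits stochastic depth, which would wrap each scaled branch output. `LayerScaleSketch` and `PreNormLayerScaleBlock` are hypothetical; this is not ViT-5's exact block.

```python
import torch
import torch.nn as nn


class LayerScaleSketch(nn.Module):
    """Hypothetical sketch: per-channel learnable gate gamma, initialised to init_value."""

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.full((dim,), init_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma


class PreNormLayerScaleBlock(nn.Module):
    """Hypothetical sketch: x + LS(mixer(norm(x))), then x + LS(mlp(norm(x)))."""

    def __init__(self, dim: int, mixer: nn.Module, mlp: nn.Module, init_value: float = 1e-5):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.mixer, self.mlp = mixer, mlp
        self.ls1 = LayerScaleSketch(dim, init_value)
        self.ls2 = LayerScaleSketch(dim, init_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.ls1(self.mixer(self.norm1(x)))  # sequence-mixing branch
        x = x + self.ls2(self.mlp(self.norm2(x)))  # channel-mixing branch
        return x
```

With `init_value=1.0`, the block starts out as a plain pre-norm residual block. Small values let very deep stacks train stably from the outset.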

Schedulers

ResumableSequentialLR(optimizer, schedulers, ...)

SequentialLR with a corrected load_state_dict.