Modules
High-level torch.nn.Module building blocks that compose the ops above
into sequence and spatial mixers, plus the kernels, norms, gates, and
residual blocks they rely on.
Mixers
Sequence and spatial mixers: Hyena, Mamba, and attention variants.
|
Gated global convolutional mixer for ND signals. |
|
Selective state-space mixer for ND signals. |
|
Multi-head scaled dot-product self-attention for 1D/2D/3D spatial inputs. |
|
ViT-5 multi-head self-attention with RMSNorm QK-Norm and register-aware RoPE. |
|
Bridges ViT-5's |
|
Operator-agnostic sequence mixer with shared QKV and output projections. |
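As an illustration of the attention row above, multi-head scaled dot-product self-attention can be applied to 1D/2D/3D grids by flattening the spatial axes into a token sequence. The following is a minimal sketch, not this library's implementation: the class name `SpatialSelfAttention`, its constructor arguments, and the channels-last layout are all assumptions made for the example. It relies on `torch.nn.functional.scaled_dot_product_attention`, which requires PyTorch 2.0 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSelfAttention(nn.Module):
    """Hypothetical sketch: self-attention over channels-last ND token grids."""

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim)   # shared QKV projection
        self.proj = nn.Linear(dim, dim)      # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, *spatial, dim), assumed layout with 1, 2, or 3 spatial axes.
        batch, *spatial, dim = x.shape
        tokens = x.reshape(batch, -1, dim)   # flatten the grid into a sequence
        n = tokens.shape[1]
        head_dim = dim // self.num_heads
        qkv = self.qkv(tokens).reshape(batch, n, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each: (batch, heads, n, head_dim)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(batch, n, dim)
        # Restore the original spatial layout after the output projection.
        return self.proj(out).reshape(batch, *spatial, dim)
```

Because the grid is flattened before attention, the same module accepts a `(batch, length, dim)` sequence or a `(batch, height, width, dim)` image without any change.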
Convolutions
Depthwise, multi-head, and continuous-kernel convolutions plus their context-parallel counterparts.
|
1D convolution with configurable causal (left-only) or symmetric padding. |
|
Depthwise causal 1D conv using |
|
N-dimensional Continuous Kernel Convolution (CKConv) operator. |
|
1-D depthwise convolution with CP-aware channel slicing and weight sharing. |
|
2-D depthwise convolution with CP-aware channel slicing and weight sharing. |
|
3-D depthwise convolution with CP-aware channel slicing and weight sharing. |
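The difference between causal (left-only) and symmetric padding can be sketched as follows. This is an illustrative example under stated assumptions, not the library's code: the class name `PaddedConv1d`, the `causal` flag, and the channels-first `(batch, channels, length)` layout are hypothetical, and the convolution is made depthwise via `groups=channels` purely to keep the example small.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PaddedConv1d(nn.Module):
    """Hypothetical sketch: depthwise 1D conv with causal or symmetric padding."""

    def __init__(self, channels: int, kernel_size: int, causal: bool = True) -> None:
        super().__init__()
        self.causal = causal
        self.kernel_size = kernel_size
        # groups=channels gives one filter per channel (depthwise).
        self.conv = nn.Conv1d(channels, channels, kernel_size, groups=channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, length), an assumed layout.
        total = self.kernel_size - 1
        if self.causal:
            pad = (total, 0)              # pad on the left only
        else:
            left = total // 2
            pad = (left, total - left)    # split padding across both sides
        # Padding by kernel_size - 1 in total keeps the output length equal to the input length.
        return self.conv(F.pad(x, pad))
```

With left-only padding, output position `t` depends only on inputs at positions `t - (kernel_size - 1)` through `t`, so perturbing later inputs leaves earlier outputs unchanged.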
Kernels & filters
Learned kernel parametrisations (SIREN, random Fourier features) and masks that produce the filters consumed by the FFT ops.
|
Convolutional kernel parametrised by a SIREN (sinusoidal representation network) MLP. |
|
N-dimensional positional embedding using a SIREN first layer. |
|
SIRENKernelND with a per-row ω₀ in the first (positional-embedding) layer. |
|
SIREN positional embedding with a per-row ω₀ in the first layer. |
|
Per-block ω₀ + (near-)block-diagonal MLP init for a SIREN kernel. |
|
SIRENKernelND whose first-layer ω₀ is multiplied by a learnable per-row scale. |
|
SIREN positional embedding with a learnable per-row ω₀ multiplier. |
|
Block-diagonal learnable-ω₀ SIREN kernel. |
|
|
Learned convolutional kernel parametrised via Random Fourier Features and an MLP. |
|
N-dimensional positional embedding using Random Fourier Features (RFF). |
|
Sine activation function used in SIREN networks. |
|
Fixed exponential-decay spatial window applied to implicit convolutional kernels. |
|
Learnable Gaussian-window spatial mask for ND convolutional kernels. |
|
Gaussian modulation with channel-reversed std ordering for block-structured SIRENs. |
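To make the idea of a SIREN-parametrised kernel concrete, the sketch below builds a small MLP with sine activations that maps a coordinate in [-1, 1] to per-channel filter values. It is a simplified 1D illustration only: `Sine`, `TinySirenKernel1d`, the hidden width, and the default ω₀ of 30 are assumptions chosen for the example and do not reflect the signatures or initialisation of the classes listed above.

```python
import torch
import torch.nn as nn


class Sine(nn.Module):
    """Sine activation sin(omega_0 * x), as used in SIREN networks."""

    def __init__(self, omega_0: float = 30.0) -> None:
        super().__init__()
        self.omega_0 = omega_0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.omega_0 * x)


class TinySirenKernel1d(nn.Module):
    """Hypothetical sketch: an MLP that maps positions to filter values."""

    def __init__(self, channels: int, hidden: int = 32, omega_0: float = 30.0) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden), Sine(omega_0),   # first (positional-embedding) layer
            nn.Linear(hidden, hidden), Sine(1.0),
            nn.Linear(hidden, channels),
        )

    def forward(self, length: int) -> torch.Tensor:
        # Sample the continuous kernel on a regular grid of `length` positions.
        coords = torch.linspace(-1.0, 1.0, length).unsqueeze(-1)  # (length, 1)
        return self.net(coords).transpose(0, 1)                   # (channels, length)
```

Since the kernel is a function of continuous coordinates, the same parameters can be sampled at any length, for example `TinySirenKernel1d(4)(64)` and `TinySirenKernel1d(4)(128)` both produce valid filters.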
Normalization
|
Root Mean Square Layer Normalization (Zhang & Sennrich, arXiv:1910.07467). |
|
RMSNorm applied independently to each attention head (QK-norm). |
|
Root Mean Square Layer Normalization — channel-first layout (arXiv:1910.07467). |
|
Global Response Normalisation (GRN) layer (Woo et al., arXiv:2301.00808). |
|
Learnable per-channel scalar gate for residual branch outputs. |
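For reference, RMS normalisation divides each feature vector by its root mean square and applies a learned per-channel scale. The sketch below covers both the channels-last and channel-first layouts mentioned above; the class name `SketchRMSNorm`, the `channel_first` flag, and the `eps` default are assumptions for illustration rather than the library's actual interface.

```python
import torch
import torch.nn as nn


class SketchRMSNorm(nn.Module):
    """Hypothetical sketch of RMSNorm: y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, dim: int, eps: float = 1e-6, channel_first: bool = False) -> None:
        super().__init__()
        self.eps = eps
        self.axis = 1 if channel_first else -1   # which axis holds the channels
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the channel axis only; no mean subtraction.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=self.axis, keepdim=True) + self.eps)
        # Reshape the weight so it broadcasts along the channel axis.
        shape = [1] * x.dim()
        shape[self.axis] = -1
        return x * inv_rms * self.weight.view(shape)
```

Unlike LayerNorm, no mean is subtracted and no bias is added, so with the initial unit weight the output has a root mean square of approximately one along the channel axis.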
Position encoding & patching
|
Axis-factorised learned positional encoding for ND spatial token grids. |
|
Conv-based patch embedding for ND spatial signals (channels-last I/O). |
|
Inverse patch-embedding layer: reconstruct spatial signal from token grid. |
|
Point-wise two-layer MLP — the channel-mixing branch of each residual block. |
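A conv-based patch embedding with channels-last I/O, together with its inverse, might look like the following 2D sketch. The names `PatchEmbed2d` and `PatchUnembed2d` and their arguments are hypothetical, and the example covers only the 2D case with non-overlapping patches (kernel size equal to stride); it is not the library's ND implementation.

```python
import torch
import torch.nn as nn


class PatchEmbed2d(nn.Module):
    """Hypothetical sketch: (B, H, W, C) image -> (B, H/p, W/p, dim) token grid."""

    def __init__(self, in_channels: int, dim: int, patch: int) -> None:
        super().__init__()
        # kernel_size == stride gives non-overlapping patches.
        self.proj = nn.Conv2d(in_channels, dim, kernel_size=patch, stride=patch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 3, 1, 2)         # channels-last -> channels-first for Conv2d
        x = self.proj(x)
        return x.permute(0, 2, 3, 1)      # back to channels-last


class PatchUnembed2d(nn.Module):
    """Hypothetical sketch: token grid -> (B, H, W, C) signal via transposed conv."""

    def __init__(self, dim: int, out_channels: int, patch: int) -> None:
        super().__init__()
        self.proj = nn.ConvTranspose2d(dim, out_channels, kernel_size=patch, stride=patch)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = tokens.permute(0, 3, 1, 2)
        x = self.proj(x)
        return x.permute(0, 2, 3, 1)
```

In this sketch the spatial sizes must be divisible by the patch size; the transposed convolution then restores the original height and width exactly.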
Gating & conditioning
Drop-path, FiLM-style conditioning, and the QKV conditioning mixer that feeds Hyena’s per-sample kernels.
|
Drop paths (stochastic depth) per sample — |
|
Cross-attention condition mixer that routes a conditioning signal into the feature map. |
|
MLP that generates per-layer FiLM (γ, β) pairs from a conditioning vector. |
|
Learnable softmax-weighted average over register tokens. |
|
Compress each register token and concatenate into a single conditioning vector. |
|
Apply per-sample stochastic depth (functional form). |
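Per-sample stochastic depth in functional form is commonly written as below. This is a sketch of the widely used pattern, not necessarily the signature exposed here: the function name `drop_path` and the parameters `drop_prob` and `training` are assumptions.

```python
import torch


def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """Hypothetical sketch: zero the whole residual branch for a random subset of samples."""
    if not 0.0 <= drop_prob < 1.0:
        raise ValueError("drop_prob must be in [0, 1)")
    if drop_prob == 0.0 or not training:
        return x                                   # identity at inference time
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining axes.
    shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    mask = x.new_empty(shape).bernoulli_(keep_prob)
    # Rescale the kept samples so the expected value matches the input.
    return x * mask / keep_prob
```

Each sample is either dropped entirely or scaled by `1 / keep_prob`, which keeps the expected output equal to the input during training.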
Residual blocks
|
Standard pre-norm residual block for ND signals. |
|
Pre-norm residual block with AdaLN-Zero conditioning (DiT-style). |
|
ViT-5 style residual block with LayerScale and stochastic depth. |
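The pre-norm residual pattern with LayerScale and stochastic depth can be summarised as `x + drop_path(gamma * mixer(norm(x)))`. The sketch below shows that generic pattern under stated assumptions: the class name `LayerScaleBlock`, the use of `nn.LayerNorm`, the `init_scale` default, and the channels-last layout are all hypothetical and are not taken from the ViT-5 block listed above.

```python
import torch
import torch.nn as nn


class LayerScaleBlock(nn.Module):
    """Hypothetical sketch: pre-norm residual branch with LayerScale and drop-path."""

    def __init__(self, dim: int, mixer: nn.Module,
                 init_scale: float = 1e-4, drop_prob: float = 0.0) -> None:
        super().__init__()
        if not 0.0 <= drop_prob < 1.0:
            raise ValueError("drop_prob must be in [0, 1)")
        self.norm = nn.LayerNorm(dim)
        self.mixer = mixer
        # LayerScale: a learnable per-channel gate on the branch output.
        self.gamma = nn.Parameter(init_scale * torch.ones(dim))
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, ..., dim), an assumed channels-last layout.
        branch = self.gamma * self.mixer(self.norm(x))
        if self.training and self.drop_prob > 0.0:
            keep = 1.0 - self.drop_prob
            shape = (x.shape[0],) + (1,) * (x.dim() - 1)
            branch = branch * x.new_empty(shape).bernoulli_(keep) / keep
        return x + branch
```

A small `init_scale` makes each block start close to the identity map, which is the usual motivation for LayerScale in deep residual stacks.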
Schedulers
|
|