nvsubquadratic.ops: FFT convolution primitives
This folder contains the lowest-level building blocks of the library: FFT-based convolution operators that turn an O(N · K) spatial convolution into an O(N log N) frequency-domain product. They are the workhorses behind every subquadratic mixer in the library (Hyena, CKConv), and are kept here as plain functions — no nn.Module state, no learned parameters — so that higher-level modules can compose them freely.
If you are reading the paper alongside this codebase, this is the folder to start with.
Why FFT convolution?
A standard spatial convolution between an input x of length N and a kernel k of length K,

$$
y[n] = \sum_{m=0}^{K-1} k[m]\, x[n-m],
$$

costs O(N · K) per channel. When K is small (e.g. a 3×3 image kernel) that is fine. When K is comparable to N, which is the regime Hyena-style models live in because each layer's effective receptive field can span the whole input, the spatial cost grows quadratically with sequence length.
The convolution theorem lets us replace the spatial convolution with an element-wise product in the frequency domain:

$$
y = \mathrm{FFT}^{-1}\big(\mathrm{FFT}(x) \odot \mathrm{FFT}(k)\big)
$$

The two forward FFTs and the inverse each cost O(N log N) and the element-wise product is O(N), so for K ≤ N the total cost is O(N log N) regardless of kernel size. That is what makes “global-kernel” convolutional sequence models subquadratic.
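The equivalence can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation; `direct_conv` and `fft_conv` are illustrative names, and the input is zero-padded to `N + K - 1` so the FFT product yields a linear (non-wrapping) convolution.

```python
import numpy as np

def direct_conv(x, k):
    """O(N*K) full linear convolution: y[n] = sum_m k[m] * x[n - m]."""
    n, m = len(x), len(k)
    y = np.zeros(n + m - 1)
    for i in range(n):
        for j in range(m):
            y[i + j] += x[i] * k[j]
    return y

def fft_conv(x, k):
    """Same result via the convolution theorem (zero-padded to avoid wrap-around)."""
    size = len(x) + len(k) - 1
    X = np.fft.rfft(x, n=size)   # forward FFT of the input
    K = np.fft.rfft(k, n=size)   # forward FFT of the kernel
    return np.fft.irfft(X * K, n=size)  # element-wise product, then inverse FFT

rng = np.random.default_rng(0)
x = rng.standard_normal(256)
k = rng.standard_normal(256)     # kernel as long as the input
print(np.allclose(direct_conv(x, k), fft_conv(x, k)))
```

With `K == N` the double loop does `N * N` multiply-adds, while the FFT path does three transforms of length `2N - 1`.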
Two flavours show up throughout the folder:
| Flavour | What it computes | When to use |
|---|---|---|
| Linear (`fftconv*`) | Standard convolution, zero-padded so no wrap-around occurs, then cropped to “same” size. | Default choice; matches a standard `torch.nn.ConvNd(padding='same')` up to floating-point rounding. |
| Circular (`circular_fftconv*`) | Periodic convolution where the kernel wraps around the input boundary. | When you want global mixing under periodic boundary conditions, or when input and kernel are the same size (no padding needed → cheaper). |
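The difference between the two flavours is easiest to see on an impulse at the boundary. This is a NumPy sketch under assumptions: the function names are illustrative, and the centred crop offset `(K - 1) // 2` is one common “same” convention, not necessarily the one this library uses.

```python
import numpy as np

def linear_fftconv_same(x, k):
    """Zero-pad to N + K - 1, multiply spectra, crop back to length N."""
    n, m = len(x), len(k)
    size = n + m - 1
    full = np.fft.irfft(np.fft.rfft(x, n=size) * np.fft.rfft(k, n=size), n=size)
    start = (m - 1) // 2          # assumed crop offset
    return full[start:start + n]

def circular_fftconv(x, k):
    """No padding of the input: the kernel wraps around the boundary (needs K <= N)."""
    n = len(x)
    return np.fft.irfft(np.fft.rfft(x) * np.fft.rfft(k, n=n), n=n)

x = np.zeros(8)
x[-1] = 1.0                       # impulse at the last position
k = np.array([0.0, 1.0])          # kernel that delays the signal by one step

lin = linear_fftconv_same(x, k)   # impulse is pushed past the end and cropped away
circ = circular_fftconv(x, k)     # impulse wraps around to position 0
```

The linear variant pays for the extra `K - 1` padded samples; the circular variant runs the FFT at the input length.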
File map
| File | Precision | Conv type | Channel mixing | When you’d reach for it |
|---|---|---|---|---|
| | fp32 | linear | depthwise | The default. 1D/2D/3D, causal & non-causal. |
| | fp32 | circular | depthwise | Periodic boundaries (e.g. PDEs, ARC grids), or when input and kernel are the same size. |
| | fp32 | per-axis BC | depthwise | Mixed boundaries: periodic on some spatial axes, zero-padded on others (e.g. the Well PDE datasets). |
| | fp32 | linear | depthwise | Memory-constrained training; processes channels in chunks. Has a global flag so models can opt in transparently. |
| `fftconv_custom.py` | fp32 | linear | depthwise | Wraps optional fused CUDA kernels (`fft_conv2d`, `fft_causal_conv1d`). |
| `causal_conv1d_custom.py` | fp32 | direct causal | depthwise | Non-FFT 1D causal kernels (`causal_conv1d`, `b2b_causal_conv1d`). |
mixed_boundary_conditions.md describes the per-axis boundary-condition support (periodic on some spatial axes, zero-padded on others) used by the Well PDE datasets.
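The idea behind per-axis boundary conditions can be sketched in a few lines of NumPy: pad only the axes that are not periodic, run one N-dimensional FFT product, then crop. The function `sketch_mixed_bc_fftconv2d` and its `periodic` argument are hypothetical and written for illustration; see mixed_boundary_conditions.md for what the library actually does.

```python
import numpy as np

def sketch_mixed_bc_fftconv2d(x, k, periodic=(True, False)):
    """2D FFT conv that wraps on periodic axes and zero-pads on the others.

    x: [N0, N1], k: [K0, K1] with K <= N on every periodic axis.
    """
    # Periodic axes keep the input size; other axes are padded to N + K - 1.
    sizes = [n if per else n + m - 1
             for n, m, per in zip(x.shape, k.shape, periodic)]
    y = np.fft.irfft2(np.fft.rfft2(x, s=sizes) * np.fft.rfft2(k, s=sizes), s=sizes)
    # Centre the kernel on both axes (assumed "same" convention), then crop.
    for axis, m in enumerate(k.shape):
        y = np.roll(y, -((m - 1) // 2), axis=axis)
    return y[:x.shape[0], :x.shape[1]]

x = np.zeros((4, 4))
x[3, 3] = 1.0                     # impulse in the bottom-right corner
k = np.zeros((3, 3))
k[2, 1] = 1.0                     # kernel that shifts by +1 along axis 0

wrapped = sketch_mixed_bc_fftconv2d(x, k, periodic=(True, False))
padded = sketch_mixed_bc_fftconv2d(x, k, periodic=(False, False))
```

With axis 0 periodic the impulse re-enters at row 0; with both axes zero-padded it leaves the domain and the output is all zeros.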
Naming convention
Every function name encodes its contract:
[causal_] fftconv {1d|2d|3d} _ fp32 _ {bhl|blh} [_w_reshape] [_chunked]
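| Part | Meaning |
|---|---|
| `causal_` | Output at position t depends only on inputs at positions ≤ t. |
| `1d` / `2d` / `3d` | Spatial rank. |
| `fp32` | Internal compute precision. The output dtype always matches the dtype of the input `x`. |
| `bhl` / `blh` | Memory layout. `bhl` is channels-first (`[B, H, *spatial]`); `blh` is channels-last (`[B, *spatial, H]`). |
| `_w_reshape` | Wrapper that accepts BLH input, internally reshapes to BHL (faster), and reshapes back. The recommended entry point for channels-last callers. |
| `_chunked` | Processes channels in groups to reduce peak GPU memory. |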
So causal_fftconv1d_fp32_bhl_w_reshape is: causal 1D FFT conv, fp32 internal, accepts channels-last input, internally uses the channels-first kernel.
The CUDA-accelerated wrappers in fftconv_custom.py drop the _fp32_ token because the underlying kernel manages its own precision internally — so the same name in fftconv_custom is causal_fftconv1d_bhl_w_reshape. The direct-conv wrappers in causal_conv1d_custom.py (causal_conv1d, b2b_causal_conv1d) do not follow this scheme because they are thin pass-throughs to the upstream API; see their docstrings for shapes.
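Because the names are this regular, they can be decoded mechanically. The helper below is hypothetical (it is not part of the library) and only covers the pattern shown above, with the `_fp32` token optional to accommodate the `fftconv_custom` names.

```python
import re

# Pattern: [causal_] fftconv {1d|2d|3d} [_fp32] _{bhl|blh} [_w_reshape] [_chunked]
NAME_RE = re.compile(
    r"^(?P<causal>causal_)?fftconv(?P<rank>[123])d"
    r"(?:_(?P<precision>fp32))?"
    r"_(?P<layout>bhl|blh)"
    r"(?P<w_reshape>_w_reshape)?"
    r"(?P<chunked>_chunked)?$"
)

def parse_op_name(name):
    """Decode an operator name into its contract; raise if it does not fit the scheme."""
    m = NAME_RE.match(name)
    if m is None:
        raise ValueError(f"{name!r} does not follow the fftconv naming scheme")
    return {
        "causal": m["causal"] is not None,
        "rank": int(m["rank"]),
        "precision": m["precision"],      # None when the _fp32 token is dropped
        "layout": m["layout"],            # layout of the underlying kernel
        "w_reshape": m["w_reshape"] is not None,
        "chunked": m["chunked"] is not None,
    }

info = parse_op_name("causal_fftconv1d_fp32_bhl_w_reshape")
```

Names outside the scheme, such as `causal_conv1d`, raise `ValueError`, which matches the note that the direct-conv wrappers do not follow it.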
Shape conventions
Everything in this folder follows two layouts. Pick whichever matches your surrounding module:
- BHL (channels-first): `x: [B, H, *spatial]`, `kernel: [1|B, H, *K_dims]`. Standard for `torch.nn.ConvNd`-style modules. Faster under the hood because the FFT runs on contiguous spatial axes without a transpose.
- BLH (channels-last): `x: [B, *spatial, H]`, `kernel: [1|B, *K_dims, H]`. Common in transformer-style code. Use the `_w_reshape` variants.
The kernel’s leading dim is either 1 (shared kernel across the batch — the standard depthwise case) or B (per-sample kernel, e.g. FiLM-conditioned Hyena where each sample gets its own kernel).
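A small NumPy sketch of how the two layouts and the two kernel leading dims interact. The `sketch_*` functions are illustrative stand-ins, not the library's operators: the channels-last wrapper simply moves the channel axis, calls the channels-first routine, and moves it back, and broadcasting handles a leading kernel dim of 1 or B.

```python
import numpy as np

def sketch_fftconv1d_bhl(x, kernel):
    """Depthwise linear 'same' conv. x: [B, H, L], kernel: [1|B, H, K]."""
    L, K = x.shape[-1], kernel.shape[-1]
    size = L + K - 1
    # FFT over the last (spatial) axis; the leading kernel dim broadcasts over B.
    y = np.fft.irfft(np.fft.rfft(x, n=size) * np.fft.rfft(kernel, n=size), n=size)
    start = (K - 1) // 2          # assumed centred crop
    return y[..., start:start + L]

def sketch_fftconv1d_bhl_w_reshape(x_blh, kernel_blh):
    """Accept channels-last tensors, compute channels-first, return channels-last."""
    x = np.moveaxis(x_blh, -1, 1)
    kernel = np.moveaxis(kernel_blh, -1, 1)
    return np.moveaxis(sketch_fftconv1d_bhl(x, kernel), 1, -1)

rng = np.random.default_rng(0)
B, H, L, K = 2, 3, 16, 5
x = rng.standard_normal((B, H, L))
k_shared = rng.standard_normal((1, H, K))      # one kernel for the whole batch
k_per_sample = rng.standard_normal((B, H, K))  # one kernel per sample

y_shared = sketch_fftconv1d_bhl(x, k_shared)
y_per_sample = sketch_fftconv1d_bhl(x, k_per_sample)
y_blh = sketch_fftconv1d_bhl_w_reshape(np.moveaxis(x, 1, -1), np.moveaxis(k_shared, 1, -1))
```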
The shortcut term
Every operator accepts an optional shortcut: [H] tensor and computes

$$
y = (k * x) + \mathrm{shortcut} \odot x
$$

i.e. a per-channel residual scale applied to the input and added to the convolution output. This is not a generic skip connection: it fuses a specific algebraic shortcut that shows up repeatedly in Hyena-style gating, saving a separate kernel launch. Pass None if you don’t need it.
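A sketch of the shortcut term, assuming it enters as a per-channel scale on the input that is added to the convolution output (the reading suggested by “per-channel residual scale”). The function name is illustrative, and a circular convolution is used only to keep the example short.

```python
import numpy as np

def sketch_circular_fftconv_bhl(x, kernel, shortcut=None):
    """x: [B, H, L], kernel: [1|B, H, K], shortcut: [H] or None."""
    L = x.shape[-1]
    y = np.fft.irfft(np.fft.rfft(x) * np.fft.rfft(kernel, n=L), n=L)
    if shortcut is not None:
        # Assumed form: broadcast the [H] vector over batch and spatial axes.
        y = y + shortcut[None, :, None] * x
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))
shortcut = np.array([1.0, 2.0, 3.0])

# With an all-zero kernel only the shortcut path contributes.
y_zero_kernel = sketch_circular_fftconv_bhl(x, np.zeros((1, 3, 8)), shortcut)
```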
Choosing a function: a decision tree
- Do I need periodic boundaries?
  - Yes → `circular_fftconv*`. The kernel wraps around the input; useful for PDE-like signals or whenever the input is naturally periodic.
  - No → `fftconv*`. The default.
- Is my model causal (1D sequence)?
  - Yes → use the `causal_*` variant. Slightly more padding (`L + K` instead of `L + K/2`), but enforces no information leak from the future.
  - No → use the non-causal variant. Cheaper, since you only pad by `K/2`.
- What’s my hidden layout?
  - Channels-first (`[B, H, …]`) → use `_bhl` directly.
  - Channels-last (`[B, …, H]`) → use `_bhl_w_reshape`. Benchmarks show this is faster than a true `_blh` op because the FFT runs on contiguous spatial axes.
- Am I OOMing?
  - Try `fftconv_chunked`: it splits the channel dim into groups to cap peak memory. Default chunk size 128 gives ~26% memory savings for ~11% overhead.
- Is there a fused CUDA kernel for my shape?
  - 2D non-causal or 1D causal long-conv → `fftconv_custom.py` exposes the upstream fused FFT kernels (`fft_conv2d`, `fft_causal_conv1d`) through the same API. Wire in via the `fft_backend="subq_ops"` flag on `CKConvND`. The 1D path requires `data_dim=1, is_causal=True`; the 2D path requires `data_dim=2, is_causal=False`.
  - 1D causal short conv (typical short_conv slot in a Hyena block) → `causal_conv1d_custom.py` exposes `causal_conv1d` directly, and `nvsubquadratic.modules.subq_ops_causal_conv1d.SubqOpsCausalConv1d` wraps it as a depthwise `nn.Conv1d`-compatible module.
  - 1D causal fused proj+gate+mixer+gate block → `b2b_causal_conv1d` in `causal_conv1d_custom.py`. Not yet wired into a `Hyena` variant; exposed as a building block.
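Two of the branches above, causality and channel chunking, can be illustrated together. This is a NumPy sketch with hypothetical helper names, not the library's `causal_*` or `fftconv_chunked` code; it pads to `L + K - 1`, which is the minimum that avoids wrap-around, whereas the library's exact padding may differ.

```python
import numpy as np

def sketch_causal_fftconv1d_bhl(x, kernel):
    """y[t] = sum_{s <= t} kernel[s] * x[t - s]. x: [B, H, L], kernel: [1|B, H, K]."""
    L, K = x.shape[-1], kernel.shape[-1]
    size = L + K - 1                       # enough zero-padding that nothing wraps
    y = np.fft.irfft(np.fft.rfft(x, n=size) * np.fft.rfft(kernel, n=size), n=size)
    return y[..., :L]                      # keep the first L outputs: no look-ahead

def sketch_chunked(conv, x, kernel, chunk_size=128):
    """Apply a depthwise `conv` to groups of channels to cap the size of FFT temporaries."""
    outs = []
    for h0 in range(0, x.shape[1], chunk_size):
        sl = slice(h0, h0 + chunk_size)
        outs.append(conv(x[:, sl], kernel[:, sl]))
    return np.concatenate(outs, axis=1)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 300, 64))
kernel = rng.standard_normal((1, 300, 64))

y = sketch_causal_fftconv1d_bhl(x, kernel)

# Causality check: changing the input from position 40 onward
# must leave outputs before position 40 untouched.
x_future = x.copy()
x_future[..., 40:] += 1.0
y_future = sketch_causal_fftconv1d_bhl(x_future, kernel)

# Chunking is exact for depthwise ops because channels never interact.
y_chunked = sketch_chunked(sketch_causal_fftconv1d_bhl, x, kernel, chunk_size=128)
```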
Numerical notes
- All operators accept any input dtype but cast to fp32 before the FFT. The output is returned in the original dtype of `x`, so no manual cast is needed on the caller side.
- The fp32 ops are correct for any input range.
- The non-causal linear ops match a standard `torch.nn.ConvNd(padding='same')` up to floating-point rounding. The circular ops match `torch.nn.functional.conv*d` after a circular pad. Both are exercised in `tests/`.
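The cast-in, cast-back pattern described in the first note looks roughly like this. It is a NumPy sketch with an illustrative function name; NumPy may run the FFT itself at a higher precision than fp32 depending on version, so only the dtype handling is meant to carry over.

```python
import numpy as np

def sketch_fftconv_fp32(x, kernel):
    """Circular FFT conv that upcasts to float32 and returns x's original dtype."""
    out_dtype = x.dtype                    # remember the caller's dtype
    x32 = x.astype(np.float32)             # cast before the FFT
    k32 = kernel.astype(np.float32)
    L = x.shape[-1]
    y = np.fft.irfft(np.fft.rfft(x32) * np.fft.rfft(k32, n=L), n=L)
    return y.astype(out_dtype)             # cast back, so the caller does not have to

x_half = np.linspace(-1.0, 1.0, 32).astype(np.float16).reshape(1, 1, 32)
k_half = np.ones((1, 1, 4), dtype=np.float16) / 4
y_half = sketch_fftconv_fp32(x_half, k_half)
```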