nvsubquadratic.ops: FFT convolution primitives

This folder contains the lowest-level building blocks of the library: FFT-based convolution operators that turn an O(N · K) spatial convolution into an O(N log N) frequency-domain product. They are the workhorses behind every subquadratic mixer in the library (Hyena, CKConv), and are kept here as plain functions — no nn.Module state, no learned parameters — so that higher-level modules can compose them freely.

If you are reading the paper alongside this codebase, this is the folder to start with.


Why FFT convolution?

A standard spatial convolution between an input x of length N and a kernel k of length K,

\[ y[n] = \sum_{m} x[n - m] \cdot k[m] \]

costs O(N · K) per channel. When K is small (e.g. a 3×3 image kernel) that is fine. When K is comparable to N — the regime Hyena-style models live in, where each layer’s effective receptive field can span the whole input — the spatial cost grows quadratically with sequence length.

The convolution theorem lets us replace the spatial convolution with an element-wise product in the frequency domain:

\[ y = \mathcal{F}^{-1}\!\bigl( \mathcal{F}(x) \odot \mathcal{F}(k) \bigr) \]

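The two forward FFTs and the inverse FFT each cost O(M log M), where M is the padded transform length (M = N + K − 1 for a linear convolution, M = N for a circular one), and the element-wise product is O(M). For K ≤ N this is O(N log N) in total, so the cost no longer scales with the product N · K. That is what makes “global-kernel” convolutional sequence models subquadratic.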
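As a minimal illustration of the convolution theorem above, the following NumPy sketch computes a full linear convolution through the frequency domain and compares it with the direct O(N · K) sum. It is a reference sketch only: the function name `fft_linear_conv` is hypothetical and is not part of nvsubquadratic.ops, which operates on PyTorch tensors.

```python
import numpy as np

def fft_linear_conv(x, k):
    """Full linear convolution of two 1D arrays via the convolution theorem."""
    # Pad both signals to N + K - 1 so the circular wrap-around of the
    # FFT product cannot contaminate the linear result.
    n_fft = len(x) + len(k) - 1
    x_f = np.fft.rfft(x, n=n_fft)
    k_f = np.fft.rfft(k, n=n_fft)
    # Element-wise product in frequency space, then back to signal space.
    return np.fft.irfft(x_f * k_f, n=n_fft)

rng = np.random.default_rng(0)
x = rng.standard_normal(256)
k = rng.standard_normal(256)        # "global" kernel: K == N

y_fft = fft_linear_conv(x, k)       # frequency-domain path
y_direct = np.convolve(x, k)        # direct O(N * K) reference ("full" mode)
max_err = float(np.max(np.abs(y_fft - y_direct)))
```

The two results should agree up to floating-point rounding; only the cost of getting there differs.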

Two flavours show up throughout the folder:

| Flavour | What it computes | When to use |
| --- | --- | --- |
| Linear (fftconv*) | Standard convolution, zero-padded so no wrap-around occurs, then cropped to “same” size. | Default choice; matches torch.nn.ConvNd semantics. |
| Circular (circular_fftconv*) | Periodic convolution where the kernel wraps around the input boundary. | When you want global mixing under periodic boundary conditions, or when input and kernel are the same size (no padding needed → cheaper). |
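The difference between the two flavours can be seen on a five-element signal. The sketch below is a NumPy illustration under two assumptions that the source does not spell out: the linear variant crops with a centred kernel (offset (K − 1) // 2), and the circular variant places the kernel origin at index 0. The function names are hypothetical.

```python
import numpy as np

def linear_same_conv(x, k):
    """Zero-padded (linear) convolution, cropped back to len(x)."""
    n, kk = len(x), len(k)
    n_fft = n + kk - 1                      # enough padding to avoid wrap-around
    full = np.fft.irfft(np.fft.rfft(x, n=n_fft) * np.fft.rfft(k, n=n_fft), n=n_fft)
    start = (kk - 1) // 2                   # centred "same" crop (assumption)
    return full[start:start + n]

def circular_conv(x, k):
    """Periodic convolution: no padding, the kernel wraps around the boundary."""
    n = len(x)
    # The kernel is zero-extended to the signal length, the signal is not padded.
    return np.fft.irfft(np.fft.rfft(x) * np.fft.rfft(k, n=n), n=n)

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
k = np.array([1.0, 2.0, 3.0])

y_linear = linear_same_conv(x, k)
y_circular = circular_conv(x, k)
```

In the circular result the tail of the full convolution is folded back onto the first outputs, which is exactly the wrap-around that the linear variant pads away.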


File map

| File | Precision | Conv type | Channel mixing | When you’d reach for it |
| --- | --- | --- | --- | --- |
| fftconv.py | fp32 | linear | depthwise | The default. 1D/2D/3D, causal & non-causal. |
| circular_fftconv.py | fp32 | circular | depthwise | Periodic boundaries (e.g. PDEs, ARC grids), or when K = N so padding is wasteful. |
| mixed_fftconv.py | fp32 | per-axis BC | depthwise | Mixed boundaries: periodic on some spatial axes, zero-padded on others (e.g. Well’s rayleigh_benard, viscoelastic_instability, turbulent_radiative_layer). Routes to the existing linear/circular ops in the all-False/all-True cases. |
| fftconv_chunked.py | fp32 | linear | depthwise | Memory-constrained training; processes channels in chunks. Has a global flag so models can opt in transparently. |
| fftconv_custom.py | fp32 | linear | depthwise | Wraps optional fused CUDA kernels (subquadratic_ops_torch.fft_conv2d for 2D non-causal, fft_causal_conv1d for 1D causal) behind the same API as fftconv.py. |
| causal_conv1d_custom.py | fp32 | direct causal | depthwise | Non-FFT 1D causal kernels (causal_conv1d short conv, b2b_causal_conv1d fused proj-gate-mixer-gate). Use for kernels short enough that FFT overhead dominates, or as a fused-Hyena building block. |

mixed_boundary_conditions.md describes the per-axis boundary-condition support (periodic on some spatial axes, zero-padded on others) used by the Well PDE datasets.
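To make the per-axis boundary-condition idea concrete, here is a small NumPy sketch of a 2D convolution that is periodic on some axes and zero-padded on others. It is not the code in mixed_fftconv.py: the function name `mixed_fftconv2d`, the `periodic` argument, and the centred-kernel convention are all assumptions made for illustration.

```python
import numpy as np

def mixed_fftconv2d(x, k, periodic):
    """2D convolution of x [N0, N1] with k [K0, K1].

    periodic[i] is True for wrap-around on axis i, False for zero padding.
    """
    centre = [(kk - 1) // 2 for kk in k.shape]
    # FFT length per axis: N on periodic axes (wrap-around is wanted),
    # N + K - 1 on zero-padded axes (wrap-around must not occur).
    s = [n if p else n + kk - 1 for n, kk, p in zip(x.shape, k.shape, periodic)]
    y = np.fft.irfft2(np.fft.rfft2(x, s=s) * np.fft.rfft2(k, s=s), s=s)
    for axis, (n, c, p) in enumerate(zip(x.shape, centre, periodic)):
        if p:
            # Periodic axis: re-centre the kernel with a cyclic shift.
            y = np.roll(y, -c, axis=axis)
        else:
            # Zero-padded axis: crop the full result back to "same" size.
            y = np.take(y, np.arange(c, c + n), axis=axis)
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 5))
k = rng.standard_normal((3, 3))
y_mixed = mixed_fftconv2d(x, k, periodic=(True, False))
```

With `periodic=(False, False)` this reduces to the linear “same” convolution and with `periodic=(True, True)` to the circular one, mirroring the routing behaviour described in the file map.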


Naming convention

Every function name encodes its contract:

[causal_] fftconv {1d|2d|3d} _ fp32 _ {bhl|blh} [_w_reshape] [_chunked]

| Part | Meaning |
| --- | --- |
| causal_ | Output at position n only sees inputs at positions ≤ n. 1D only. |
| 1d / 2d / 3d | Spatial rank. |
| fp32 | Internal compute precision. The output dtype always matches x.dtype regardless. |
| bhl / blh | Memory layout. bhl = channels-first ([B, H, *spatial]). blh = channels-last ([B, *spatial, H]). |
| _w_reshape | Wrapper that accepts BLH input, internally reshapes to BHL (faster), and reshapes back. The recommended entry point for channels-last callers. |
| _chunked | Processes channels in groups to reduce peak GPU memory. |

So causal_fftconv1d_fp32_bhl_w_reshape is: causal 1D FFT conv, fp32 internal, accepts channels-last input, internally uses the channels-first kernel.

The CUDA-accelerated wrappers in fftconv_custom.py drop the _fp32_ token because the underlying kernel manages its own precision internally — so the same name in fftconv_custom is causal_fftconv1d_bhl_w_reshape. The direct-conv wrappers in causal_conv1d_custom.py (causal_conv1d, b2b_causal_conv1d) do not follow this scheme because they are thin pass-throughs to the upstream API; see their docstrings for shapes.
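The causal_ contract in the table above can be sketched in a few lines of NumPy. This is an illustrative reference, not the library's implementation: the name `causal_fftconv1d_ref` is hypothetical, and the sketch assumes the kernel's tap 0 multiplies the current input, so that the causal output is the first L samples of the full linear convolution.

```python
import numpy as np

def causal_fftconv1d_ref(x, k):
    """Causal depthwise FFT conv. x: [B, H, L], k: [1 or B, H, K] -> [B, H, L]."""
    seq_len = x.shape[-1]
    # Pad to L + K (one more than the L + K - 1 minimum) so nothing wraps around.
    n_fft = seq_len + k.shape[-1]
    y = np.fft.irfft(np.fft.rfft(x, n=n_fft) * np.fft.rfft(k, n=n_fft), n=n_fft)
    # Keep the first L outputs: y[n] depends only on x[0..n].
    return y[..., :seq_len]

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 16))
k = rng.standard_normal((1, 3, 16))     # shared kernel, as long as the input

y = causal_fftconv1d_ref(x, k)

# Perturb only the "future" (positions >= 10) and recompute.
x_future = x.copy()
x_future[..., 10:] += 1.0
y_future = causal_fftconv1d_ref(x_future, k)
```

If the operator is causal, `y` and `y_future` agree on every position before 10 and differ from position 10 onward.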


Shape conventions

Everything in this folder follows two layouts. Pick whichever matches your surrounding module:

  • BHL (channels-first): x: [B, H, *spatial], kernel: [1|B, H, *K_dims]. Standard for torch.nn.ConvNd-style modules. Faster under the hood because FFT runs on contiguous spatial axes without a transpose.

  • BLH (channels-last): x: [B, *spatial, H], kernel: [1|B, *K_dims, H]. Common in transformer-style code. Use the _w_reshape variants.

The kernel’s leading dim is either 1 (shared kernel across the batch — the standard depthwise case) or B (per-sample kernel, e.g. FiLM-conditioned Hyena where each sample gets its own kernel).
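The two layouts and the two kernel batch modes can be illustrated with a NumPy stand-in. The functions `ref_fftconv1d_bhl` and `ref_fftconv1d_w_reshape` below are hypothetical names that only mimic the shape contract described above (1D, non-causal, centred “same” crop assumed); they are not the library's functions.

```python
import numpy as np

def ref_fftconv1d_bhl(x, k):
    """Channels-first depthwise conv. x: [B, H, L], k: [1 or B, H, K]."""
    seq_len, k_len = x.shape[-1], k.shape[-1]
    n_fft = seq_len + k_len - 1
    # A leading kernel dim of 1 broadcasts over the batch; a dim of B gives
    # each sample its own kernel.
    y = np.fft.irfft(np.fft.rfft(x, n=n_fft) * np.fft.rfft(k, n=n_fft), n=n_fft)
    start = (k_len - 1) // 2
    return y[..., start:start + seq_len]

def ref_fftconv1d_w_reshape(x_blh, k_blh):
    """Channels-last wrapper. x: [B, L, H], k: [1 or B, K, H] -> [B, L, H]."""
    x = np.moveaxis(x_blh, -1, 1)          # [B, L, H] -> [B, H, L]
    k = np.moveaxis(k_blh, -1, 1)          # [*, K, H] -> [*, H, K]
    return np.moveaxis(ref_fftconv1d_bhl(x, k), 1, -1)

rng = np.random.default_rng(0)
B, L, H, K = 4, 12, 3, 5
x_blh = rng.standard_normal((B, L, H))
k_shared = rng.standard_normal((1, K, H))          # one kernel for the whole batch
k_per_sample = np.repeat(k_shared, B, axis=0)      # same values, explicit [B, K, H]

y_shared = ref_fftconv1d_w_reshape(x_blh, k_shared)
y_per_sample = ref_fftconv1d_w_reshape(x_blh, k_per_sample)
```

Because the per-sample kernel here just repeats the shared one, both calls should produce the same output; in a conditioned model each row of `k_per_sample` would differ.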

The shortcut term

Every operator accepts an optional shortcut: [H] tensor and computes

\[ y \leftarrow y + \mathrm{shortcut} \odot x \]

i.e. a per-channel residual scale. This is not a generic skip connection — it fuses a specific algebraic shortcut that shows up repeatedly in Hyena-style gating, saving a separate kernel launch. Pass None if you don’t need it.
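The update above is a plain per-channel broadcast. A minimal NumPy sketch for the channels-first layout, with the hypothetical helper name `add_shortcut_bhl`, might look like this:

```python
import numpy as np

def add_shortcut_bhl(y, x, shortcut=None):
    """Return y + shortcut * x for channels-first tensors [B, H, *spatial].

    shortcut has shape [H]; None disables the term.
    """
    if shortcut is None:
        return y
    # Reshape [H] -> [1, H, 1, ..., 1] so it broadcasts over batch and space.
    shape = (1, -1) + (1,) * (x.ndim - 2)
    return y + shortcut.reshape(shape) * x

x = np.ones((2, 3, 4, 5))               # [B, H, X, Y]
y = np.zeros_like(x)                    # stand-in for the convolution output
shortcut = np.array([1.0, 2.0, 3.0])    # one scale per channel

out = add_shortcut_bhl(y, x, shortcut)
out_none = add_shortcut_bhl(y, x, None)
```

Each channel of `out` is the input scaled by that channel's shortcut value, added to the convolution output.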


Choosing a function: a decision tree

  1. Do I need periodic boundaries?

    • Yes → circular_fftconv*. The kernel wraps around the input; useful for PDE-like signals or whenever the input is naturally periodic.

    • No → fftconv*. The default.

  2. Is my model causal (1D sequence)?

    • Yes → use the causal_* variant. Slightly more padding (L + K instead of L + K/2), but enforces no information leak from the future.

    • No → use the non-causal variant. Cheaper, since you only pad by K/2.

  3. What’s my hidden layout?

    • Channels-first ([B, H, …]) → use _bhl directly.

    • Channels-last ([B, …, H]) → use _bhl_w_reshape. Benchmarks show this is faster than a true _blh op because the FFT runs on contiguous spatial axes.

  4. Am I OOMing?

    • Try fftconv_chunked — splits the channel dim into groups to cap peak memory. Default chunk size 128 gives ~26% memory savings for ~11% overhead.

  5. Is there a fused CUDA kernel for my shape?

    • 2D non-causal or 1D causal long-conv → fftconv_custom.py exposes the upstream fused FFT kernels (fft_conv2d, fft_causal_conv1d) through the same API. Wire in via the fft_backend="subq_ops" flag on CKConvND. The 1D path requires data_dim=1, is_causal=True; the 2D path requires data_dim=2, is_causal=False.

    • 1D causal short conv (typical short_conv slot in a Hyena block) → causal_conv1d_custom.py exposes causal_conv1d directly, and nvsubquadratic.modules.subq_ops_causal_conv1d.SubqOpsCausalConv1d wraps it as a depthwise nn.Conv1d-compatible module.

    • 1D causal fused proj+gate+mixer+gate block → b2b_causal_conv1d in causal_conv1d_custom.py. Not yet wired into a Hyena variant; exposed as a building block.
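The chunking idea from step 4 can be sketched as follows. This is a NumPy illustration under assumptions, not the code in fftconv_chunked.py: the function names are hypothetical, and only the default chunk size of 128 is taken from the text above. Because the convolution is depthwise, channels are independent, so processing them in groups changes peak memory but not the result.

```python
import numpy as np

def ref_fftconv1d_bhl(x, k):
    """Unchunked depthwise conv. x: [B, H, L], k: [1 or B, H, K]."""
    seq_len, k_len = x.shape[-1], k.shape[-1]
    n_fft = seq_len + k_len - 1
    y = np.fft.irfft(np.fft.rfft(x, n=n_fft) * np.fft.rfft(k, n=n_fft), n=n_fft)
    start = (k_len - 1) // 2
    return y[..., start:start + seq_len]

def ref_fftconv1d_bhl_chunked(x, k, chunk_size=128):
    """Same result, but only chunk_size channels are transformed at a time."""
    out = np.empty_like(x)
    for start in range(0, x.shape[1], chunk_size):
        sl = slice(start, start + chunk_size)
        # Only this slice's padded spectra are alive during the iteration.
        out[:, sl] = ref_fftconv1d_bhl(x[:, sl], k[:, sl])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 10, 32))
k = rng.standard_normal((1, 10, 7))

y_full = ref_fftconv1d_bhl(x, k)
y_chunked = ref_fftconv1d_bhl_chunked(x, k, chunk_size=4)   # chunks of 4, 4, 2
```

The trade-off is more, smaller FFT calls in exchange for a lower memory peak, which is consistent with the overhead-versus-savings figures quoted in step 4.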


Numerical notes

  • All operators accept any input dtype but cast to fp32 before the FFT. The output is returned in the original dtype of x — no need for a manual cast on the caller side.

  • The fp32 ops are correct for any input range.

  • The non-causal linear ops match a standard torch.nn.ConvNd(padding='same') up to floating-point rounding. The circular ops match torch.nn.functional.conv*d after a circular pad. Both are exercised in tests/.
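The cast-in, cast-out contract from the first bullet can be sketched like this. It is a NumPy illustration with a hypothetical function name; note that NumPy's FFT may itself compute in higher precision than fp32 depending on the NumPy version, so the sketch demonstrates only the dtype handling, not the library's exact numerics.

```python
import numpy as np

def ref_fftconv1d_keep_dtype(x, k):
    """Depthwise 'same' conv that upcasts to fp32 and returns x's dtype."""
    orig_dtype = x.dtype
    x32 = x.astype(np.float32)            # upcast before the transform
    k32 = k.astype(np.float32)
    seq_len, k_len = x.shape[-1], k.shape[-1]
    n_fft = seq_len + k_len - 1
    y = np.fft.irfft(np.fft.rfft(x32, n=n_fft) * np.fft.rfft(k32, n=n_fft), n=n_fft)
    start = (k_len - 1) // 2
    # Cast back so the caller never sees the internal precision.
    return y[..., start:start + seq_len].astype(orig_dtype)

rng = np.random.default_rng(0)
x_half = rng.standard_normal((1, 2, 64)).astype(np.float16)
k_half = rng.standard_normal((1, 2, 5)).astype(np.float16)

y_half = ref_fftconv1d_keep_dtype(x_half, k_half)

# Double-precision reference on the same (already rounded) inputs.
y_ref = np.stack([
    np.convolve(x_half[0, h].astype(np.float64), k_half[0, h].astype(np.float64), mode="same")
    for h in range(2)
])[None]
```

The output keeps the half-precision dtype of the input, and its values differ from the double-precision reference only by the final rounding to float16.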