Ops

Low-level convolution primitives. Pure-PyTorch reference implementations double as the spec the CUDA kernels must match; the subquadratic_ops_torch wrappers are the production path on GPU.

FFT convolutions (reference fp32)

Use these for correctness checks and as the spec for the CUDA kernels below.

fftconv1d_fp32_blh(x, kernel[, shortcut])

Non-causal 1D FFT convolution (BLH layout, channels-last) with optional shortcut.

fftconv2d_fp32_blh(x, kernel[, shortcut])

2D FFT convolution (BLH layout, channels-last) with optional shortcut.

fftconv3d_fp32_blh(x, kernel[, shortcut])

3D FFT convolution (BLH layout, channels-last) with optional shortcut.

causal_fftconv1d_fp32_blh(x, kernel[, shortcut])

Causal 1D FFT convolution (BLH layout, channels-last) with optional shortcut.

fftconv1d_fp32_bhl(x, kernel[, shortcut])

1D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, length).

fftconv2d_fp32_bhl(x, kernel[, shortcut])

2D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, height, width).

fftconv3d_fp32_bhl(x, kernel[, shortcut])

3D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, depth, height, width).

causal_fftconv1d_fp32_bhl(x, kernel[, shortcut])

Causal 1D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, length).
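To make the causal reference semantics concrete, here is a minimal NumPy sketch of a causal 1D FFT convolution in BHL layout, checked against a direct O(L²) sum. It assumes the kernel has shape (H, L) (one full-length filter per channel) and that the optional shortcut is a per-channel weight D of shape (H,) added as D * x; neither shape is stated on this page, and `causal_fftconv1d_bhl_sketch` is a hypothetical name, not this module's code. `numpy.fft.rfft`/`irfft` map one-to-one onto `torch.fft.rfft`/`irfft`.

```python
import numpy as np


def causal_fftconv1d_bhl_sketch(x, kernel, shortcut=None):
    """y[b, h, t] = sum_{j <= t} kernel[h, j] * x[b, h, t - j]  (+ D[h] * x[b, h, t]).

    x: (B, H, L); kernel: (H, L) (assumed); shortcut: optional (H,) skip weight (assumed).
    """
    L = x.shape[-1]
    n = 2 * L  # zero-pad to 2L so circular wrap-around never reaches outputs t < L
    X = np.fft.rfft(x, n=n, axis=-1)        # (B, H, n // 2 + 1)
    K = np.fft.rfft(kernel, n=n, axis=-1)   # (H, n // 2 + 1), broadcasts over batch
    y = np.fft.irfft(X * K, n=n, axis=-1)[..., :L]  # keep the causal part only
    if shortcut is not None:
        y = y + shortcut[:, None] * x
    return y


def causal_conv_direct(x, kernel):
    """Direct O(L^2) reference for the same causal convolution (no shortcut)."""
    L = x.shape[-1]
    y = np.zeros_like(x)
    for t in range(L):
        for j in range(t + 1):
            y[..., t] += kernel[:, j] * x[..., t - j]
    return y
```

The 2L padding is the design choice that makes the FFT product a linear rather than circular convolution; the circular variants further down drop exactly this padding.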

FFT convolutions (CUDA-accelerated)

Drop-in wrappers around the subquadratic_ops_torch fused CUDA kernels. 2D non-causal and 1D causal long-conv variants share the same API as the fp32 reference ops above.

fftconv2d_blh(x, kernel[, shortcut])

Alias for fftconv2d_bhl_w_reshape().

fftconv2d_bhl(x, kernel[, shortcut])

2D FFT convolution via subq_ops CUDA kernel, BHL layout [B, H, X, Y].

fftconv2d_bhl_w_reshape(x, kernel[, shortcut])

2D FFT convolution via subq_ops for BLH inputs [B, X, Y, H].

causal_fftconv1d_blh(x, kernel[, shortcut])

Alias for causal_fftconv1d_bhl_w_reshape().

causal_fftconv1d_bhl(x, kernel[, shortcut])

1D causal FFT convolution via subq_ops CUDA kernel, BHL layout [B, H, L].

causal_fftconv1d_bhl_w_reshape(x, kernel[, ...])

1D causal FFT convolution via subq_ops for BLH inputs [B, L, H].
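The `_w_reshape` entries accept channels-last (BLH) input and, judging by their names, route it through the BHL kernel. A minimal sketch of such a wrapper, assuming the only difference is where the hidden axis sits; `with_blh_reshape` is a hypothetical helper, and how the real wrappers treat the kernel argument is not stated here.

```python
import numpy as np


def with_blh_reshape(bhl_fn):
    """Adapt an op written for x of shape (B, H, *spatial) to x of shape (B, *spatial, H).

    Extra positional/keyword arguments (kernel, shortcut, ...) are passed through unchanged.
    """
    def wrapped(x, *args, **kwargs):
        x_bhl = np.moveaxis(x, -1, 1)            # (B, *spatial, H) -> (B, H, *spatial)
        y_bhl = bhl_fn(x_bhl, *args, **kwargs)
        return np.moveaxis(y_bhl, 1, -1)         # (B, H, *spatial) -> (B, *spatial, H)
    return wrapped
```

In PyTorch the same permutation would be `x.movedim(-1, 1).contiguous()`; the `.contiguous()` copy is what a fused kernel expecting dense BHL memory would likely need, and it is the price of accepting BLH input.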

Direct 1D causal convolutions (CUDA-accelerated)

Non-FFT CUDA kernels for short and fused 1D causal convolutions. Useful for small kernel sizes (where FFT overhead dominates) and as building blocks for fused Hyena variants.

causal_conv1d(x, weight[, bias, activation])

Depthwise causal 1D conv via the subq_ops CUDA kernel.

b2b_causal_conv1d(x, weight_proj, ...)

Back-to-back fused causal 1D conv via the subq_ops CUDA kernel.
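For a short filter of width K, a direct depthwise causal convolution costs O(L·K) per channel, versus O(L log L) plus transform overhead for the FFT path, which is why the direct kernels win when K is small. The NumPy sketch below assumes x of shape (B, H, L), weight of shape (H, K) and the orientation of `torch.nn.functional.conv1d` with `padding=K-1`, `groups=H`, truncated to the first L outputs; that orientation, and `"silu"` as the activation string, are assumptions about `causal_conv1d`, not documented facts. `causal_conv1d_direct` is a hypothetical name.

```python
import numpy as np


def causal_conv1d_direct(x, weight, bias=None, activation=None):
    """Depthwise causal conv: x (B, H, L), weight (H, K), bias (H,) or None.

    y[b, h, t] = sum_k weight[h, k] * x[b, h, t - (K - 1) + k]
    so weight[:, -1] multiplies the current step and weight[:, 0] the oldest one.
    """
    B, H, L = x.shape
    K = weight.shape[-1]
    xp = np.pad(x, ((0, 0), (0, 0), (K - 1, 0)))  # left-pad only: no future leakage
    y = np.zeros((B, H, L), dtype=np.result_type(x, weight))
    for k in range(K):  # K shifted multiply-adds: O(L * K) work, no FFT
        y += weight[None, :, k, None] * xp[:, :, k:k + L]
    if bias is not None:
        y += bias[None, :, None]
    if activation in ("silu", "swish"):
        y = y / (1.0 + np.exp(-y))  # SiLU: y * sigmoid(y)
    return y
```

Under this orientation the operation equals a true convolution with the reversed filter, e.g. `np.convolve(x[b, h], weight[h, ::-1])[:L]` per channel.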

Circular FFT convolutions

Periodic-boundary FFT convolutions for global mixing without zero padding.

circular_fftconv1d_fp32_bhl(x, kernel[, ...])

1D circular FFT convolution with optional shortcut (BHL layout).

circular_fftconv2d_fp32_bhl(x, kernel[, ...])

2D circular FFT convolution with optional shortcut (BHL layout).

circular_fftconv3d_fp32_bhl(x, kernel[, ...])

3D circular FFT convolution with optional shortcut (BHL layout).
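With periodic boundaries the FFT length simply equals the signal length: no zero padding, so the wrap-around that the causal path pads away is exactly the intended boundary. A minimal 1D NumPy sketch, assuming a kernel of shape (H, L) and a per-channel shortcut D of shape (H,) (both assumptions); the 2D and 3D cases replace `rfft`/`irfft` with `np.fft.rfftn`/`irfftn` over the trailing axes. `circular_fftconv1d_bhl_sketch` is a hypothetical name.

```python
import numpy as np


def circular_fftconv1d_bhl_sketch(x, kernel, shortcut=None):
    """Periodic convolution y[b, h, t] = sum_j kernel[h, j] * x[b, h, (t - j) mod L].

    x: (B, H, L); kernel: (H, L) (assumed); shortcut: optional (H,) skip weight (assumed).
    """
    L = x.shape[-1]
    Xf = np.fft.rfft(x, axis=-1)        # FFT length L: no padding
    Kf = np.fft.rfft(kernel, axis=-1)
    y = np.fft.irfft(Xf * Kf, n=L, axis=-1)  # n=L recovers odd lengths exactly
    if shortcut is not None:
        y = y + shortcut[:, None] * x
    return y
```

Because the operator commutes with cyclic shifts, rolling the input by s rolls the output by s, a quick sanity check for any periodic implementation.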

Chunking utilities

Helpers to bound the FFT working-set memory by processing along the sequence axis in chunks.

enable_chunking([module_or_flag, chunk_size])

Enable chunked FFT conv globally, as a decorator, or as a context manager.

chunking_enabled([enabled, chunk_size])

Context manager to temporarily enable/disable chunked FFT conv.

set_default_chunk_size(chunk_size)

Set the default chunk size for chunked FFT convolutions.

get_default_chunk_size()

Return the default chunk size for chunked FFT convolutions.

Mixed boundary-condition FFT convolutions

FFT convolutions with per-axis boundary conditions — periodic on some spatial axes, zero-padded on others. See Mixed Boundary-Condition FFT Convolution for the per-axis algorithm and the fft_padding API.
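The per-axis idea can be sketched by choosing the FFT length axis by axis: length n on a periodic axis (wrap-around kept) and 2n on a zero-padded axis, then cropping back to n. The 2D NumPy sketch below assumes a full-size kernel of shape (H, X, Y) and that the zero-padded axes keep the first n outputs, as in the causal 1D case; the documented algorithm may crop differently, and `mixed_fftconv2d_bhl_sketch` is a hypothetical name.

```python
import numpy as np


def mixed_fftconv2d_bhl_sketch(x, kernel, periodic):
    """2D depthwise FFT conv with a boundary condition per spatial axis.

    x: (B, H, X, Y); kernel: (H, X, Y) (assumed); periodic: (bool, bool) for X and Y.
    """
    spatial = x.shape[-2:]
    # Periodic axis: FFT length n. Zero-padded axis: 2n, so no wrap-around reaches kept outputs.
    s = [n if p else 2 * n for n, p in zip(spatial, periodic)]
    Xf = np.fft.rfftn(x, s=s, axes=(-2, -1))       # rfftn zero-pads up to s
    Kf = np.fft.rfftn(kernel, s=s, axes=(-2, -1))
    y = np.fft.irfftn(Xf * Kf, s=s, axes=(-2, -1))
    return y[..., : spatial[0], : spatial[1]]       # crop padded axes back to n (assumed)
```

With `periodic=(True, True)` this reduces to the circular case and with `(False, False)` to a linear convolution keeping the leading quadrant.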

mixed_fftconv1d_fp32_bhl(x, kernel, periodic)

1D mixed-BC FFT convolution (BHL layout).

mixed_fftconv2d_fp32_bhl(x, kernel, periodic)

2D mixed-BC FFT convolution (BHL layout).

mixed_fftconv3d_fp32_bhl(x, kernel, periodic)

3D mixed-BC FFT convolution (BHL layout).

mixed_fftconv1d_fp32_bhl_w_reshape(x, ...[, ...])

1D mixed-BC FFT conv wrapper for BLH layout (batch, length, hidden).

mixed_fftconv2d_fp32_bhl_w_reshape(x, ...[, ...])

2D mixed-BC FFT conv wrapper for BLH layout (batch, X, Y, hidden).

mixed_fftconv3d_fp32_bhl_w_reshape(x, ...[, ...])

3D mixed-BC FFT conv wrapper for BLH layout (batch, X, Y, Z, hidden).

mixed_fftconv1d_fp32_bhl_chunked(x, kernel, ...)

Memory-efficient 1D mixed-BC FFT conv (BHL) via channel chunking.

mixed_fftconv2d_fp32_bhl_chunked(x, kernel, ...)

Memory-efficient 2D mixed-BC FFT conv (BHL) via channel chunking.

mixed_fftconv3d_fp32_bhl_chunked(x, kernel, ...)

Memory-efficient 3D mixed-BC FFT conv (BHL) via channel chunking.
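Channel chunking is exact for depthwise convolutions because no channel reads another: slicing the hidden axis, convolving each slice, and concatenating reproduces the unchunked result while only one slice's FFT buffers are alive at a time. A minimal NumPy sketch; `channel_chunked`, its `chunk` keyword, and the stand-in `periodic_conv2d` are hypothetical and do not reflect the elided parameters of the functions above.

```python
import numpy as np


def channel_chunked(conv_fn, x, kernel, *args, chunk=2, **kwargs):
    """Apply a depthwise conv_fn to channel slices of x (B, H, ...) and kernel (H, ...)."""
    outs = [
        conv_fn(x[:, h:h + chunk], kernel[h:h + chunk], *args, **kwargs)
        for h in range(0, x.shape[1], chunk)  # the last slice may be shorter than `chunk`
    ]
    return np.concatenate(outs, axis=1)


def periodic_conv2d(x, kernel):
    """Stand-in depthwise op: 2D circular FFT conv, x (B, H, X, Y), kernel (H, X, Y)."""
    return np.fft.irfft2(np.fft.rfft2(x) * np.fft.rfft2(kernel), s=x.shape[-2:])
```

A GPU implementation would more likely write each slice into a preallocated output than concatenate, but the peak FFT working set scales with `chunk` rather than H either way.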