Ops#
Low-level convolution primitives. Pure-PyTorch reference implementations
double as the spec the CUDA kernels must match; the
subquadratic_ops_torch wrappers are the production path on GPU.
FFT convolutions (reference fp32)#
Use these for correctness checks and as the spec for the CUDA kernels below.
|
Non-causal 1D FFT convolution (BLH layout, channels-last) with optional shortcut. |
|
2D FFT convolution with optional shortcut. |
|
3D FFT convolution with optional shortcut. |
|
Causal 1D FFT convolution (BLH layout, channels-last) with optional shortcut. |
|
1D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, length). |
|
2D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, height, width). |
|
3D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, depth, height, width). |
|
1D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, length). |
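To make the causal, channels-last entry above concrete, here is a minimal sketch of a causal 1D FFT convolution in BLH layout. It is not the library's implementation: the function name, the argument names, and the reading of "shortcut" as a per-channel skip term `x * shortcut` added to the output are all assumptions made for illustration.

```python
import torch


def causal_fftconv1d_blh_sketch(x, k, shortcut=None):
    """Hypothetical sketch, not the library API.

    x:        (batch, length, hidden) input, channels-last (BLH)
    k:        (hidden, kernel_length) one filter per channel
    shortcut: optional (hidden,) per-channel skip weight (assumed meaning)
    """
    L = x.shape[1]
    # Pad the FFT length so circular wrap-around lands beyond the first L
    # outputs; those first L samples then equal the causal linear convolution.
    n = L + k.shape[-1] - 1
    x_f = torch.fft.rfft(x.float(), n=n, dim=1)    # (B, n//2 + 1, H)
    k_f = torch.fft.rfft(k.float(), n=n, dim=-1)   # (H, n//2 + 1)
    y = torch.fft.irfft(x_f * k_f.transpose(0, 1), n=n, dim=1)[:, :L]
    if shortcut is not None:
        y = y + x * shortcut                       # assumed skip connection
    return y
```

Output position `t` depends only on inputs at positions `<= t`, which is the property a causal kernel has to preserve.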
FFT convolutions (CUDA-accelerated)#
Drop-in wrappers around the subquadratic_ops_torch (abbreviated subq_ops below) fused CUDA kernels.
The 2D non-causal and 1D causal long-conv variants share the same API as the
fp32 reference ops above.
|
Alias for |
|
2D FFT convolution via subq_ops CUDA kernel, BHL layout |
|
2D FFT convolution via subq_ops for BLH inputs |
|
Alias for |
|
1D causal FFT convolution via subq_ops CUDA kernel, BHL layout |
|
1D causal FFT convolution via subq_ops for BLH inputs |
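Because the accelerated wrappers are described as sharing the reference API, a parity check between the two paths can be written once and reused. The helper below is a sketch under that assumption; its name and tolerances are hypothetical, and the two callables stand in for whichever reference and accelerated ops are being compared.

```python
import torch


def assert_matches_reference(ref_fn, fast_fn, *args, atol=1e-4, rtol=1e-4, **kwargs):
    """Hypothetical helper: call both ops with identical arguments and compare.

    Returns the maximum absolute difference so callers can log it.
    """
    expected = ref_fn(*args, **kwargs)
    actual = fast_fn(*args, **kwargs)
    # Compare in fp32 so a reduced-precision fast path is judged on values,
    # not on dtype.
    expected, actual = expected.float(), actual.float()
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    return (actual - expected).abs().max().item()
```

Tolerances usually need loosening when the accelerated path runs in reduced precision; the defaults here are placeholders, not values taken from the library.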
Direct 1D causal convolutions (CUDA-accelerated)#
Non-FFT CUDA kernels for short and fused 1D causal convolutions. Useful for small kernel sizes (where FFT overhead dominates) and as building blocks for fused Hyena variants.
|
Depthwise causal 1D conv via the subq_ops CUDA kernel. |
|
Back-to-back fused causal 1D conv via the subq_ops CUDA kernel. |
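For reference, a depthwise causal 1D convolution can be expressed in plain PyTorch with left padding and a grouped `conv1d`. This is only a sketch of the operation, assuming BHL layout and one filter per channel; the function name is hypothetical and the subq_ops kernel's actual signature is not shown here.

```python
import torch
import torch.nn.functional as F


def causal_depthwise_conv1d_sketch(x, w):
    """Hypothetical reference, not the subq_ops API.

    x: (batch, hidden, length) input in BHL layout
    w: (hidden, kernel_size) one short filter per channel
    Computes y[b, h, t] = sum_s w[h, s] * x[b, h, t - s], with x = 0 for t < 0.
    """
    K = w.shape[-1]
    # Left-pad by K - 1 so no output position sees future samples.
    x_pad = F.pad(x, (K - 1, 0))
    # conv1d computes cross-correlation, so flip the taps to get convolution.
    weight = w.flip(-1).unsqueeze(1)               # (H, 1, K)
    return F.conv1d(x_pad, weight, groups=x.shape[1])
```

For small `K` this direct form avoids the forward and inverse transforms entirely, which is why it tends to win over FFT for short filters.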
Circular FFT convolutions#
Periodic-boundary FFT convolutions for global mixing without zero padding.
|
1D circular FFT convolution with optional shortcut (BHL layout). |
|
2D circular FFT convolution with optional shortcut (BHL layout). |
|
3D circular FFT convolution with optional shortcut (BHL layout). |
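The 1D circular case is the simplest to sketch: with periodic boundaries the FFT length equals the signal length, so no padding or cropping is needed. The function below is an illustrative sketch with a hypothetical name, assuming BHL layout, a kernel no longer than the signal, and a per-channel skip term as the meaning of "shortcut".

```python
import torch


def circular_fftconv1d_sketch(x, k, shortcut=None):
    """Hypothetical sketch, not the library API.

    x: (batch, hidden, length), k: (hidden, kernel_length <= length)
    Computes y[..., t] = sum_s k[..., s] * x[..., (t - s) mod length].
    """
    L = x.shape[-1]
    assert k.shape[-1] <= L, "kernel must fit in one period"
    # Transform at the native length: the wrap-around of the DFT is exactly
    # the periodic boundary condition we want.
    y = torch.fft.irfft(
        torch.fft.rfft(x.float(), n=L) * torch.fft.rfft(k.float(), n=L), n=L
    )
    if shortcut is not None:
        y = y + x * shortcut[:, None]              # assumed (hidden,) skip weight
    return y
```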
Chunking utilities#
Helpers to bound the FFT working-set memory by processing the input in chunks along the sequence axis.
|
Enable chunked FFT conv globally, as decorator, or as context manager. |
|
Context manager to temporarily enable/disable chunked FFT conv. |
|
Set the default chunk size for chunked FFT convolutions. |
|
Return the default chunk size for chunked FFT convolutions. |
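One common way to offer a switch that works both as a decorator and as a context manager is `contextlib.ContextDecorator`. The sketch below illustrates that pattern only; every name in it (`chunking`, `is_chunking_enabled`, the setter and getter, and the default of 128) is hypothetical and should not be read as the library's actual API or defaults.

```python
import contextlib

# Hypothetical module-level state; the real defaults are not documented here.
_STATE = {"enabled": False, "chunk_size": 128}


def set_default_chunk_size(n):
    if n <= 0:
        raise ValueError("chunk size must be positive")
    _STATE["chunk_size"] = int(n)


def get_default_chunk_size():
    return _STATE["chunk_size"]


def is_chunking_enabled():
    return _STATE["enabled"]


class chunking(contextlib.ContextDecorator):
    """Usable as `with chunking():` or as `@chunking()` on a function."""

    def __init__(self, enabled=True):
        self.enabled = enabled
        self._previous = []            # stack, so nested use restores correctly

    def __enter__(self):
        self._previous.append(_STATE["enabled"])
        _STATE["enabled"] = self.enabled
        return self

    def __exit__(self, *exc):
        _STATE["enabled"] = self._previous.pop()
        return False                   # never swallow exceptions
```

Restoring the previous value on exit, rather than hard-coding `False`, keeps nested scopes and a globally enabled setting from interfering with each other.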
Mixed boundary-condition FFT convolutions#
FFT convolutions with per-axis boundary conditions — periodic on some
spatial axes, zero-padded on others. See
Mixed Boundary-Condition FFT Convolution for the per-axis algorithm and the
fft_padding API.
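The per-axis idea can be sketched in 2D as follows: axes treated as periodic are transformed at their native length, while zero-padded axes get a longer FFT so that wrap-around falls in a region that is cropped away. This is an assumption-laden illustration, not the documented algorithm: the function name and the `periodic` argument are hypothetical (the fft_padding API is not reproduced here), and the kernel origin is assumed to sit at index 0 on every axis.

```python
import torch


def mixed_bc_fftconv2d_sketch(x, k, periodic=(True, False)):
    """Hypothetical sketch, not the library API.

    x: (batch, hidden, X, Y) input in BHL-style layout
    k: (hidden, Kx, Ky) per-channel kernel with Kx <= X and Ky <= Y
    periodic: per-axis flag; True wraps around, False zero-pads
    """
    sizes = x.shape[-2:]
    ksizes = k.shape[-2:]
    # Periodic axis: keep length n so the DFT wrap-around is the boundary rule.
    # Zero-padded axis: extend to n + kn - 1 so wrap-around cannot reach the
    # first n outputs.
    s = tuple(n if p else n + kn - 1 for n, kn, p in zip(sizes, ksizes, periodic))
    x_f = torch.fft.rfftn(x.float(), s=s, dim=(-2, -1))
    k_f = torch.fft.rfftn(k.float(), s=s, dim=(-2, -1))
    y = torch.fft.irfftn(x_f * k_f, s=s, dim=(-2, -1))
    return y[..., : sizes[0], : sizes[1]]
```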
|
1D mixed-BC FFT convolution (BHL layout). |
|
2D mixed-BC FFT convolution (BHL layout). |
|
3D mixed-BC FFT convolution (BHL layout). |
|
1D mixed-BC FFT conv wrapper for BLH layout (batch, length, hidden). |
|
2D mixed-BC FFT conv wrapper for BLH layout (batch, X, Y, hidden). |
|
3D mixed-BC FFT conv wrapper for BLH layout (batch, X, Y, Z, hidden). |
|
Memory-efficient 1D mixed-BC FFT conv (BHL) via channel chunking. |
|
Memory-efficient 2D mixed-BC FFT conv (BHL) via channel chunking. |
|
Memory-efficient 3D mixed-BC FFT conv (BHL) via channel chunking. |
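Channel chunking works because a depthwise FFT convolution treats every channel independently, so slices of the hidden axis can be processed one at a time and concatenated. The sketch below shows the idea for a plain zero-padded 1D case rather than the mixed-BC variants; the function name and the `chunk` parameter are hypothetical.

```python
import torch


def fftconv1d_channel_chunked_sketch(x, k, chunk=64):
    """Hypothetical sketch, not the library API.

    x: (batch, hidden, length), k: (hidden, kernel_length)
    Only `chunk` channels of frequency-domain data are alive at any time.
    """
    L = x.shape[-1]
    n = L + k.shape[-1] - 1            # linear (zero-padded) convolution length
    outputs = []
    for h0 in range(0, x.shape[1], chunk):
        xs = x[:, h0 : h0 + chunk].float()
        ks = k[h0 : h0 + chunk].float()
        y = torch.fft.irfft(torch.fft.rfft(xs, n=n) * torch.fft.rfft(ks, n=n), n=n)
        outputs.append(y[..., :L])
    return torch.cat(outputs, dim=1)
```

The result should not depend on `chunk` beyond floating-point rounding; smaller chunks lower peak memory at the cost of more FFT launches.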