circular_fftconv1d_fp32_bhl#
circular_fftconv1d_fp32_bhl(x, kernel, shortcut=None, use_phase_shift=True)
1D circular FFT convolution with optional shortcut (BHL layout).
Circular convolution#
Spatially, circular (periodic) convolution is
circular_conv(x, k) = roll(irfft(rfft(x) * rfft(k))).
Here roll applies the integer alignment shift to the output. To avoid the explicit spatial roll, we apply an equivalent frequency-domain phase ramp:
circular_conv(x, k) = irfft(rfft(x) * rfft(k) * phase_ramp),
where phase_ramp has shape [L//2 + 1] and encodes the integer shift.
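The equivalence between the spatial roll and the phase ramp can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation: the helper name `phase_ramp`, the zero-padding of the kernel to length `L` via `n=L`, and the sign convention of the ramp are assumptions made for illustration.

```python
import numpy as np

def phase_ramp(L, shift):
    # Hypothetical helper: rolling a length-L signal by `shift` samples
    # multiplies rfft bin f by exp(-2j*pi*f*shift/L), for f = 0..L//2.
    f = np.arange(L // 2 + 1)
    return np.exp(-2j * np.pi * f * shift / L)

rng = np.random.default_rng(0)
L, K = 16, 5
x = rng.standard_normal(L)
k = rng.standard_normal(K)
shift = -((K - 1) // 2)

X = np.fft.rfft(x)
Kf = np.fft.rfft(k, n=L)  # assumption: kernel is zero-padded to length L

# Path 1: inverse transform, then explicit spatial roll.
y_roll = np.roll(np.fft.irfft(X * Kf, n=L), shift)
# Path 2: fold the shift into the spectrum, no roll needed.
y_ramp = np.fft.irfft(X * Kf * phase_ramp(L, shift), n=L)

assert np.allclose(y_roll, y_ramp)
```

Because the shift is an integer, the ramp value at the Nyquist bin (for even `L`) is real, so the inverse real transform reproduces the rolled signal exactly up to floating-point error.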
Layout and shapes#
Layout: BHL ([batch, hidden, length])
Inputs:
- x: [B, H, L]
- kernel: [1|B, H, K]
Output:
- y: [B, H, L]
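The `[1|B, H, K]` kernel shape means the kernel is either shared across the batch or supplied per sample. A small NumPy sketch of how both cases can flow through the same spectral multiply by broadcasting; the variable names are illustrative and the alignment shift is left out for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, L, K = 2, 3, 16, 5
x = rng.standard_normal((B, H, L)).astype(np.float32)
k_shared = rng.standard_normal((1, H, K)).astype(np.float32)  # one kernel for all samples
k_batched = np.repeat(k_shared, B, axis=0)                    # same kernel, explicit [B, H, K]

def spectral_product(x, kernel):
    L = x.shape[-1]
    Xf = np.fft.rfft(x, axis=-1)             # [B, H, L//2 + 1]
    Kf = np.fft.rfft(kernel, n=L, axis=-1)   # [1 or B, H, L//2 + 1]
    # A leading dimension of 1 broadcasts against the batch dimension.
    return np.fft.irfft(Xf * Kf, n=L, axis=-1)

y_shared = spectral_product(x, k_shared)
y_batched = spectral_product(x, k_batched)
assert y_shared.shape == (B, H, L)
assert np.allclose(y_shared, y_batched, atol=1e-5)
```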
Alignment and shifts#
We align to the “same” output by shifting with shift = -((K - 1) // 2).
If use_phase_shift=True, we multiply the kernel spectrum by the cached phase ramp in the frequency domain. Otherwise we roll the spatial output by shift after the inverse transform.
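To see what this shift does, note that rolling by `-((K - 1) // 2)` lines up kernel tap `(K - 1) // 2` with the current output position. A minimal NumPy sketch of the spatial-roll path follows, checked against a direct periodic sum; the function names `same_circular_conv` and `direct_reference` are hypothetical, and the kernel is assumed to be zero-padded to length `L`.

```python
import numpy as np

def same_circular_conv(x, k):
    # Spatial-roll path: FFT product, inverse transform, then roll by `shift`.
    L, K = x.shape[-1], k.shape[-1]
    shift = -((K - 1) // 2)
    y = np.fft.irfft(np.fft.rfft(x, n=L) * np.fft.rfft(k, n=L), n=L)
    return np.roll(y, shift, axis=-1)

def direct_reference(x, k):
    # y[n] = sum_j k[j] * x[(n + c - j) mod L], with c = (K - 1) // 2.
    L, K = len(x), len(k)
    c = (K - 1) // 2
    return np.array([sum(k[j] * x[(n + c - j) % L] for j in range(K))
                     for n in range(L)])

x = np.arange(8, dtype=np.float64)

# A unit impulse at the centre tap leaves the input unchanged.
for K in (4, 5):
    delta = np.zeros(K)
    delta[(K - 1) // 2] = 1.0
    assert np.allclose(same_circular_conv(x, delta), x)

# A general kernel matches the direct periodic sum.
k = np.array([0.5, -1.0, 2.0, 0.25, 3.0])
assert np.allclose(same_circular_conv(x, k), direct_reference(x, k))
```

For even `K` the centre tap is the left of the two middle taps, since `(K - 1) // 2` rounds down.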
Shortcut#
Optional shortcut: [H] scales the input per-channel and is added to the convolution output: y += shortcut * x.
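For a `[H]` vector to scale a `[B, H, L]` tensor per channel, it has to broadcast along the hidden axis rather than the trailing length axis. A minimal NumPy sketch, assuming the reshape to `[1, H, 1]` is done explicitly (how the function itself arranges the broadcast is not stated here); the zero array stands in for the convolution output.

```python
import numpy as np

B, H, L = 2, 3, 4
x = np.ones((B, H, L), dtype=np.float32)
y = np.zeros((B, H, L), dtype=np.float32)  # stand-in for the convolution output
shortcut = np.array([0.5, 1.0, 2.0], dtype=np.float32)  # one scale per channel

# Insert singleton batch and length axes so the scale applies along H.
y = y + shortcut[None, :, None] * x

assert y.shape == (B, H, L)
assert np.allclose(y[:, 0, :], 0.5)
assert np.allclose(y[:, 2, :], 2.0)
```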
Caching#
The phase ramp is retrieved from a global, module-level LRU cache shared across all layers/callers within the same process.
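One way such a process-wide cache could look is sketched below with Python's `functools.lru_cache`. The function name `cached_phase_ramp`, the cache size, and the cache key `(L, shift)` are assumptions for illustration; a tensor-library version would likely also key on device and dtype.

```python
from functools import lru_cache

import numpy as np

@lru_cache(maxsize=32)  # module-level, so every caller in the process shares it
def cached_phase_ramp(L, shift):
    f = np.arange(L // 2 + 1)
    ramp = np.exp(-2j * np.pi * f * shift / L).astype(np.complex64)
    # The same array object is handed to all callers, so make it read-only.
    ramp.setflags(write=False)
    return ramp

a = cached_phase_ramp(1024, -3)
b = cached_phase_ramp(1024, -3)  # served from the cache
assert a is b
assert a.shape == (513,)
```

Since the cached object is shared, callers should treat it as immutable; the read-only flag makes accidental in-place edits fail loudly.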
- param x: [B, H, L], any dtype (internally cast to float32).
- type x: Tensor
- param kernel: [1|B, H, K], any dtype (internally cast to float32).
- type kernel: Tensor
- param shortcut: Optional [H] per-channel residual scale. Never cast; the multiply auto-upcasts.
- type shortcut: Tensor | None
- param use_phase_shift: Use frequency-domain shift if True; else spatial roll.
- type use_phase_shift: bool
- returns: [B, H, L], in the original dtype of x.
- rtype: Tensor
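The dtype contract (compute in float32, return in the dtype of `x`) can be illustrated with a small NumPy sketch. The function name `conv_keep_dtype` is hypothetical, and the alignment shift and shortcut are omitted to keep the focus on the casts. Depending on the NumPy version, the FFT itself may run in higher precision than float32; the final cast is what restores the input dtype.

```python
import numpy as np

def conv_keep_dtype(x, kernel):
    out_dtype = x.dtype              # remember the caller's dtype
    L = x.shape[-1]
    x32 = x.astype(np.float32)       # upcast inputs before the transforms
    k32 = kernel.astype(np.float32)
    y = np.fft.irfft(
        np.fft.rfft(x32, axis=-1) * np.fft.rfft(k32, n=L, axis=-1),
        n=L, axis=-1,
    )
    return y.astype(out_dtype)       # cast back to the original dtype of x

x = np.ones((2, 3, 8), dtype=np.float16)
kernel = np.ones((1, 3, 3), dtype=np.float16)
y = conv_keep_dtype(x, kernel)

assert y.dtype == np.float16
assert y.shape == (2, 3, 8)
```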