circular_fftconv1d_fp32_bhl

circular_fftconv1d_fp32_bhl(x, kernel, shortcut=None, use_phase_shift=True)

1D circular FFT convolution with optional shortcut (BHL layout).

Circular convolution

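  • Spatially, the circular (periodic) convolution, aligned as described under Alignment and shifts, is circular_conv(x, k) = roll(irfft(rfft(x) * rfft(k, n=L), n=L), shift), where the kernel is zero-padded to length L before its transform.

  • To avoid the explicit spatial roll, we apply an equivalent frequency-domain phase ramp: circular_conv(x, k) = irfft(rfft(x) * rfft(k, n=L) * phase_ramp, n=L), where phase_ramp has shape [L//2 + 1] and, by the DFT shift theorem, phase_ramp[f] = exp(-2πi · f · shift / L).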
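A minimal sketch of the two formulas above, using NumPy as a stand-in for the Tensor implementation; `phase_ramp`, `circ_conv_roll`, and `circ_conv_phase` are hypothetical helper names. It checks numerically that multiplying by the phase ramp reproduces the spatial roll for both odd and even L:

```python
import numpy as np


def phase_ramp(L: int, shift: int) -> np.ndarray:
    # DFT shift theorem: np.roll(y, shift) multiplies rfft bin f by
    # exp(-2j*pi*f*shift/L). One complex factor per rfft bin: [L//2 + 1].
    f = np.arange(L // 2 + 1)
    return np.exp(-2j * np.pi * f * shift / L)


def circ_conv_roll(x: np.ndarray, k: np.ndarray, shift: int) -> np.ndarray:
    L = x.shape[-1]
    # Zero-pad k to length L so both spectra have L//2 + 1 bins; pass n=L
    # to irfft so odd lengths are recovered exactly.
    y = np.fft.irfft(np.fft.rfft(x, n=L) * np.fft.rfft(k, n=L), n=L)
    return np.roll(y, shift, axis=-1)


def circ_conv_phase(x: np.ndarray, k: np.ndarray, shift: int) -> np.ndarray:
    L = x.shape[-1]
    spec = np.fft.rfft(x, n=L) * np.fft.rfft(k, n=L) * phase_ramp(L, shift)
    return np.fft.irfft(spec, n=L)


rng = np.random.default_rng(0)
for L in (15, 16):  # odd and even lengths
    x = rng.standard_normal(L)
    k = rng.standard_normal(5)
    shift = -((k.shape[-1] - 1) // 2)
    print(L, np.allclose(circ_conv_roll(x, k, shift), circ_conv_phase(x, k, shift)))
```

For even L the Nyquist bin's ramp value is exp(-πi · shift) = ±1, which is real, so the two paths agree on every bin.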

Layout and shapes

  • Layout: BHL ([batch, hidden, length])

  • Inputs:
      - x: [B, H, L]
      - kernel: [1|B, H, K]; a leading 1 shares one kernel across the whole batch

  • Output:
      - y: [B, H, L]
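To illustrate the BHL shapes and how a kernel with batch dimension 1 broadcasts, here is a NumPy sketch of the unaligned spectral product along the last (L) axis. `fftconv_bhl` is a hypothetical name, and the K <= L check is an assumption of this sketch rather than a documented constraint:

```python
import numpy as np


def fftconv_bhl(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Circular convolution along L for x: [B, H, L], kernel: [1|B, H, K]."""
    B, H, L = x.shape
    kb, kh, K = kernel.shape
    if kh != H or kb not in (1, B):
        raise ValueError(f"kernel must be [1|{B}, {H}, K], got {kernel.shape}")
    if K > L:
        raise ValueError("this sketch assumes K <= L")
    x_f = np.fft.rfft(x, n=L, axis=-1)       # [B, H, L//2 + 1]
    k_f = np.fft.rfft(kernel, n=L, axis=-1)  # [1|B, H, L//2 + 1]
    # A leading kernel dimension of 1 broadcasts across the batch.
    return np.fft.irfft(x_f * k_f, n=L, axis=-1)  # [B, H, L]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 8))
k = rng.standard_normal((1, 3, 3))
print(fftconv_bhl(x, k).shape)  # (2, 3, 8)
```

Passing a kernel of shape [B, H, K] that repeats the same filters gives the same result as the [1, H, K] form.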

Alignment and shifts

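  • We align the output to “same” convolution by shifting with shift = -((K - 1) // 2), so that kernel tap (K - 1) // 2 (the center tap when K is odd) lines up with each output position.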

  • If use_phase_shift=True, we multiply the kernel spectrum by the cached phase ramp in the frequency domain. Otherwise, we roll the spatial output by shift after the inverse transform.
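The effect of shift can be checked against a direct circular sum. In this NumPy sketch (hypothetical helper names; spatial-roll path only, i.e. use_phase_shift=False), the aligned output satisfies y[n] = sum_j k[j] · x[(n + c - j) mod L] with c = (K - 1) // 2. As a consequence, a kernel whose only nonzero tap is its center tap returns the input unchanged:

```python
import numpy as np


def same_circular_conv(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    L, K = x.shape[-1], k.shape[-1]
    shift = -((K - 1) // 2)  # the documented "same" alignment shift
    y = np.fft.irfft(np.fft.rfft(x, n=L) * np.fft.rfft(k, n=L), n=L)
    return np.roll(y, shift, axis=-1)


def direct_same_circular_conv(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    # y[n] = sum_j k[j] * x[(n + c - j) mod L], with c = (K - 1) // 2
    L, K = len(x), len(k)
    c = (K - 1) // 2
    return np.array(
        [sum(k[j] * x[(n + c - j) % L] for j in range(K)) for n in range(L)]
    )


x = np.arange(8, dtype=np.float64)
centered = np.array([0.0, 1.0, 0.0])  # only the center tap is nonzero
print(np.allclose(same_circular_conv(x, centered), x))  # True
```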

Shortcut

  • Optional shortcut of shape [H]: it scales the input per channel and is added to the convolution output, i.e. y[b, h, l] += shortcut[h] * x[b, h, l].
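Because shortcut has shape [H] while x is [B, H, L], the per-channel scale needs a trailing axis to broadcast over the right dimension. A small NumPy illustration, where `conv_out` is a hypothetical zero stand-in for the convolution result:

```python
import numpy as np

B, H, L = 2, 3, 4
x = np.arange(B * H * L, dtype=np.float32).reshape(B, H, L)
conv_out = np.zeros_like(x)  # stand-in for the convolution output
shortcut = np.array([0.0, 1.0, 2.0], dtype=np.float32)  # one scale per channel

# [H] -> [H, 1] broadcasts over both the batch and length axes.
y = conv_out + shortcut[:, None] * x
print(np.array_equal(y[:, 2], 2 * x[:, 2]))  # True

# Without the extra axis, NumPy aligns [H] with the L axis instead:
# shortcut * x raises here (3 vs. 4), and it would silently scale the
# wrong axis whenever H == L.
```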

Caching

  • The phase ramp is retrieved from a global, module-level LRU cache shared across all layers/callers within the same process.
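A process-wide LRU cache of this kind can be sketched with functools.lru_cache on a module-level function. The key (L, shift), the maxsize, and the complex64 dtype are assumptions of this sketch; the real cache may also key on device or dtype. Marking the array read-only guards the shared value against in-place modification by any one caller:

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=128)  # hypothetical size; module-level, so shared process-wide
def cached_phase_ramp(L: int, shift: int) -> np.ndarray:
    f = np.arange(L // 2 + 1)
    ramp = np.exp(-2j * np.pi * f * shift / L).astype(np.complex64)
    ramp.setflags(write=False)  # every caller receives this same object
    return ramp


a = cached_phase_ramp(16, -2)
b = cached_phase_ramp(16, -2)
print(a is b, cached_phase_ramp.cache_info().hits)  # True 1
```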

Parameters:

  • x (Tensor): [B, H, L], any dtype (internally cast to float32).

  • kernel (Tensor): [1|B, H, K], any dtype (internally cast to float32).

  • shortcut (Tensor | None): Optional [H] per-channel residual scale. Never cast; type promotion in the multiplication handles mixed dtypes.

  • use_phase_shift (bool): If True, shift in the frequency domain with the phase ramp; otherwise, roll the spatial output.

Returns:

[B, H, L], in the original dtype of x.

Return type:

Tensor
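Putting the parameters together, here is a NumPy reference sketch of the documented semantics, which may be handy as a test oracle. `circular_fftconv1d_ref` is a hypothetical name. It mirrors the float32 compute, the two shift paths, the uncast shortcut, and the return in x's original dtype, but it is not the Tensor implementation. It also assumes that the shortcut multiplies the float32 copy of x:

```python
import numpy as np


def circular_fftconv1d_ref(x, kernel, shortcut=None, use_phase_shift=True):
    orig_dtype = x.dtype
    x32 = x.astype(np.float32)  # compute in float32, as documented
    k32 = kernel.astype(np.float32)
    L, K = x32.shape[-1], k32.shape[-1]
    shift = -((K - 1) // 2)
    spec = np.fft.rfft(x32, n=L, axis=-1) * np.fft.rfft(k32, n=L, axis=-1)
    if use_phase_shift:
        f = np.arange(L // 2 + 1)
        spec = spec * np.exp(-2j * np.pi * f * shift / L)  # phase ramp
        y = np.fft.irfft(spec, n=L, axis=-1)
    else:
        y = np.roll(np.fft.irfft(spec, n=L, axis=-1), shift, axis=-1)
    y = y.astype(np.float32)
    if shortcut is not None:
        # shortcut is not cast; NumPy type promotion handles mixed dtypes.
        y = y + shortcut[:, None] * x32
    return y.astype(orig_dtype)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 16)).astype(np.float16)
kernel = rng.standard_normal((1, 4, 5))
shortcut = np.ones(4, dtype=np.float32)
y = circular_fftconv1d_ref(x, kernel, shortcut)
print(y.shape, y.dtype)  # (2, 4, 16) float16
```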