circular_fftconv2d_fp32_bhl
circular_fftconv2d_fp32_bhl(x, kernel, shortcut=None, use_phase_shift=True)
2D circular FFT convolution with optional shortcut (BHL layout).
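With the alignment applied spatially, the circular convolution is

    circular_conv(x, kernel) = roll(ifft2(fft2(x) * fft2(kernel)))

where fft2(kernel) is taken at the input size (X_in, Y_in), i.e. with the kernel zero-padded, and roll applies the alignment shift. Equivalently, the shift can be applied in the frequency domain by multiplying with a phase ramp:

    circular_conv(x, kernel) = ifft2(fft2(x) * fft2(kernel) * phase_ramp)

where phase_ramp is a complex tensor of shape (X_in, Y_in // 2 + 1) encoding the alignment shift (the Y_in // 2 + 1 width is the half spectrum produced by a real-input FFT). This avoids the spatial roll, making the convolution faster and more memory-efficient.
- Parameters: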
x (Tensor) – Tensor of shape (B, H, X_in, Y_in), any dtype (internally cast to float32).
kernel (Tensor) – Tensor of shape (1|B, H, K_x, K_y), where the leading dimension is either 1 (one kernel shared across the batch) or B; any dtype (internally cast to float32).
shortcut (Tensor | None) – Optional per-channel tensor of shape (H,). It is never cast; the multiplication with it follows PyTorch type promotion.
use_phase_shift (bool) – If True, apply the alignment shift via the frequency-domain phase ramp. If False, align via a spatial torch.roll after the inverse FFT.
- Returns:
Tensor of shape (B, H, X_in, Y_in), in the original dtype of x.
- Return type:
Tensor
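To make the two alignment paths concrete, here is a minimal NumPy sketch of the computation described above; it is not the library's implementation. It assumes real FFTs (suggested by the (X_in, Y_in // 2 + 1) ramp shape), zero-pads the kernel to the input size, and uses a hypothetical centring shift of (-(K_x // 2), -(K_y // 2)); the actual alignment shift is not documented here and may differ. The names phase_ramp and reference_circular_fftconv2d are hypothetical, and the shortcut term is omitted because how it combines with the output is not specified. In PyTorch, torch.fft.rfft2, torch.fft.irfft2 and torch.roll take analogous arguments (s, and dim/dims instead of axes/axis).

```python
import numpy as np


def phase_ramp(nx: int, ny: int, shift_x: int, shift_y: int) -> np.ndarray:
    """Hypothetical helper: complex ramp of shape (nx, ny // 2 + 1) that applies
    a circular shift of (shift_x, shift_y) in rfft2 space (Fourier shift theorem)."""
    fx = np.fft.fftfreq(nx)[:, None]   # k / nx for all nx frequency bins along X
    fy = np.fft.rfftfreq(ny)[None, :]  # k / ny for the ny // 2 + 1 bins along Y
    return np.exp(-2j * np.pi * (fx * shift_x + fy * shift_y))


def reference_circular_fftconv2d(x, kernel, use_phase_shift=True):
    """Hypothetical reference for (B, H, X_in, Y_in) inputs; shortcut omitted."""
    orig_dtype = x.dtype
    # Mirror the documented float32 cast (NumPy < 2.0 still computes FFTs in float64).
    x32 = x.astype(np.float32)
    k32 = kernel.astype(np.float32)
    nx, ny = x.shape[-2:]
    kx, ky = kernel.shape[-2:]
    # ASSUMPTION: centre the kernel; the real function's shift may differ.
    sx, sy = -(kx // 2), -(ky // 2)

    xf = np.fft.rfft2(x32)              # (B, H, nx, ny // 2 + 1)
    kf = np.fft.rfft2(k32, s=(nx, ny))  # zero-pads the kernel; batch dim 1 broadcasts
    prod = xf * kf

    if use_phase_shift:
        # Shift applied in the frequency domain: no spatial roll needed.
        y = np.fft.irfft2(prod * phase_ramp(nx, ny, sx, sy), s=(nx, ny))
    else:
        # Shift applied spatially after the inverse FFT.
        y = np.roll(np.fft.irfft2(prod, s=(nx, ny)), shift=(sx, sy), axis=(-2, -1))
    return y.astype(orig_dtype)  # return in the caller's original dtype
```

Both branches agree to floating-point tolerance because multiplying the spectrum by exp(-2πi(k_x s_x / X_in + k_y s_y / Y_in)) is, by the Fourier shift theorem, the same as circularly rolling the spatial output by (s_x, s_y).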
Notes
When use_phase_shift=True, the phase ramp is retrieved from a global, module-level LRU cache shared by all layers and callers within the same Python process. This avoids recomputing the ramp for repeated combinations of size and shift on the same device and dtype.
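The caching behaviour described in the note can be sketched with functools.lru_cache on a module-level function, which yields exactly one cache per Python process. This is a hypothetical illustration: the name cached_phase_ramp, the maxsize of 128, and keying on a dtype name string are assumptions, and NumPy stands in for PyTorch, where the cache key would also need to include the device.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=128)  # ASSUMPTION: cache size chosen for illustration
def cached_phase_ramp(nx: int, ny: int, shift_x: int, shift_y: int,
                      dtype_name: str = "complex64") -> np.ndarray:
    """Hypothetical module-level cache: one ramp per (size, shift, dtype) key."""
    fx = np.fft.fftfreq(nx)[:, None]   # k / nx along X (full spectrum)
    fy = np.fft.rfftfreq(ny)[None, :]  # k / ny along Y (half spectrum)
    ramp = np.exp(-2j * np.pi * (fx * shift_x + fy * shift_y)).astype(dtype_name)
    ramp.flags.writeable = False       # shared object: block in-place edits
    return ramp
```

Marking the cached array read-only matters because every caller receives the same object; an in-place update by one layer would otherwise silently change the ramp for every other layer using that key.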