circular_fftconv2d_fp32_bhl

circular_fftconv2d_fp32_bhl(x, kernel, shortcut=None, use_phase_shift=True)

2D circular FFT convolution with optional shortcut (BHL layout).

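Spatially, the circular convolution is circular_conv(x, kernel) = roll(ifft2(fft2(x) * fft2(kernel))). Here the kernel is zero-padded to (X_in, Y_in) before the transform, and roll applies a fixed alignment shift over the two spatial dimensions. By the DFT shift theorem, the same shift can instead be applied as a frequency-domain phase ramp: circular_conv(x, kernel) = ifft2(fft2(x) * fft2(kernel) * phase_ramp), where phase_ramp is a complex tensor of shape (X_in, Y_in // 2 + 1) encoding the alignment shift. This shape matches the output of a real-input 2D FFT such as torch.fft.rfft2, whose last dimension is halved. Folding the shift into the existing spectral multiply removes the separate spatial roll, which would otherwise copy the full output tensor, so this path can be faster and use less memory.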
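The equivalence rests on the DFT shift theorem. For illustration, write y for the unshifted result irfft2(rfft2(x) * rfft2(kernel)) and (s_x, s_y) for the integer alignment shift. These symbols are introduced here and are not part of the documented API, and the exact shift the function uses is not stated on this page. If roll means torch.roll(y, shifts=(s_x, s_y), dims=(-2, -1)), then:

```latex
\tilde{y}[m, n] = y\big[(m - s_x) \bmod X_{\text{in}},\; (n - s_y) \bmod Y_{\text{in}}\big]
\;\Longleftrightarrow\;
\hat{\tilde{y}}[u, v] = \hat{y}[u, v]\,
\underbrace{\exp\!\Big(-2\pi i\Big(\tfrac{u\,s_x}{X_{\text{in}}} + \tfrac{v\,s_y}{Y_{\text{in}}}\Big)\Big)}_{\text{phase\_ramp}[u,\,v]},
\qquad u = 0,\dots,X_{\text{in}}-1,\quad v = 0,\dots,\lfloor Y_{\text{in}}/2 \rfloor
```

The v range has Y_in // 2 + 1 entries, which gives the documented ramp shape (X_in, Y_in // 2 + 1). Because s_x and s_y are integers, the ramp preserves the Hermitian symmetry of a real signal's spectrum, so the shifted spectrum still corresponds to a real output.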

Parameters:
  • x (Tensor) – Tensor of shape (B, H, X_in, Y_in), any dtype (internally cast to float32).

  • kernel (Tensor) – Tensor of shape (1, H, K_x, K_y) or (B, H, K_x, K_y), any dtype (internally cast to float32). A leading dimension of 1 broadcasts the same kernel across the batch.

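  • shortcut (Tensor | None) – Optional tensor of shape (H,). It is never cast to float32; the multiplication that applies it follows PyTorch's automatic type promotion, so the lower-precision operand is upcast.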

  • use_phase_shift (bool) – If True, apply the alignment shift via the frequency-domain phase ramp. If False, align via a spatial torch.roll after the inverse FFT. Both paths compute the same result up to floating-point rounding.
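The two alignment paths can be compared with a minimal PyTorch sketch. It assumes the forward and inverse transforms are torch.fft.rfft2 and torch.fft.irfft2 and that the alignment is a fixed integer shift. The names circular_fftconv2d_sketch and phase_ramp are hypothetical, not the library's API, and the shortcut term is omitted because this page does not specify how it is combined with the convolution.

```python
import math

import torch


def phase_ramp(X, Y, shift, device=None):
    """Hypothetical helper: complex ramp of shape (X, Y // 2 + 1) that applies
    torch.roll(..., shifts=shift, dims=(-2, -1)) in the rfft2 domain."""
    sx, sy = shift
    u = torch.arange(X, device=device)[:, None]           # rows: full frequency range
    v = torch.arange(Y // 2 + 1, device=device)[None, :]  # cols: halved rfft range
    # Reduce the integer products modulo the length so the float32 angles stay small.
    angle = -2.0 * math.pi * (((u * sx) % X) / X + ((v * sy) % Y) / Y)
    return torch.polar(torch.ones_like(angle), angle)


def circular_fftconv2d_sketch(x, kernel, shift, use_phase_shift=True):
    """Hypothetical reference: x is (B, H, X, Y), kernel is (1|B, H, Kx, Ky)."""
    orig_dtype = x.dtype
    X, Y = x.shape[-2:]
    # Cast to float32; `s=(X, Y)` zero-pads the kernel to the input size.
    xf = torch.fft.rfft2(x.float(), s=(X, Y))
    kf = torch.fft.rfft2(kernel.float(), s=(X, Y))  # a batch dim of 1 broadcasts over B
    prod = xf * kf
    if use_phase_shift:
        # Apply the alignment shift as a frequency-domain phase ramp.
        y = torch.fft.irfft2(prod * phase_ramp(X, Y, shift, x.device), s=(X, Y))
    else:
        # Apply the same shift spatially after the inverse FFT.
        y = torch.roll(torch.fft.irfft2(prod, s=(X, Y)), shifts=shift, dims=(-2, -1))
    return y.to(orig_dtype)  # return in the caller's dtype


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 10)
k = torch.randn(1, 3, 3, 3)
a = circular_fftconv2d_sketch(x, k, shift=(-1, -1), use_phase_shift=True)
b = circular_fftconv2d_sketch(x, k, shift=(-1, -1), use_phase_shift=False)
print(torch.allclose(a, b, atol=1e-4))
```

With a delta kernel at the centre of a 3×3 window and shift (-1, -1), both paths return x unchanged, which is one way an alignment shift can centre the kernel. Whether the documented function uses this particular shift is an assumption.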

Returns:

Tensor of shape (B, H, X_in, Y_in), in the original dtype of x.

Return type:

Tensor

Notes

When use_phase_shift=True, the phase ramp is retrieved from a global, module-level LRU cache shared by all layers and callers within the same Python process. Entries are reused for repeated input sizes and shifts on the same device and dtype, so the ramp is not recomputed on every call.
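A process-wide ramp cache of the kind described above can be sketched with functools.lru_cache. The function name cached_phase_ramp, the maxsize, and the exact cache key (sizes, integer shifts, device, complex dtype) are assumptions for illustration, not the library's implementation.

```python
import functools
import math

import torch


@functools.lru_cache(maxsize=128)
def cached_phase_ramp(X, Y, sx, sy, device, dtype):
    """Hypothetical module-level cache: one ramp per (size, shift, device, dtype) key.

    The returned tensor is shared by every caller with the same key, so it must
    be treated as read-only.
    """
    u = torch.arange(X, device=device)[:, None]
    v = torch.arange(Y // 2 + 1, device=device)[None, :]
    # Build the angles in float64, then cast the complex ramp to the requested dtype.
    angle = -2.0 * math.pi * (((u * sx) % X).double() / X + ((v * sy) % Y).double() / Y)
    return torch.polar(torch.ones_like(angle), angle).to(dtype)


cpu = torch.device("cpu")
r1 = cached_phase_ramp(8, 10, -1, -1, cpu, torch.complex64)
r2 = cached_phase_ramp(8, 10, -1, -1, cpu, torch.complex64)  # cache hit, no recompute
print(r1 is r2, cached_phase_ramp.cache_info())
```

Because every caller receives the same cached tensor, an in-place update by one layer would silently affect all others; returning the tensor unmodified and multiplying out-of-place avoids that.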