circular_fftconv3d_fp32_bhl

circular_fftconv3d_fp32_bhl(x, kernel, shortcut=None, use_phase_shift=True)

3D circular FFT convolution with optional shortcut (BHL layout).

Circular convolution

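  • Circular (periodic) convolution is computed with the convolution theorem. The kernel is zero-padded to the input's spatial size: circular_conv(x, k) = irfftn(rfftn(x) * rfftn(k, s=(X, Y, Z))). To center the output (see Alignment and shifts), the result is then rolled: y = roll(circular_conv(x, k), shifts).

  • To avoid the explicit spatial roll, an equivalent frequency-domain phase ramp is applied instead: y = irfftn(rfftn(x) * rfftn(k, s=(X, Y, Z)) * phase_ramp). Here phase_ramp has shape [X, Y, Z//2 + 1] and encodes the integer shifts as phase_ramp[kx, ky, kz] = exp(-2πi · (kx·shift_x/X + ky·shift_y/Y + kz·shift_z/Z)), which follows from the DFT shift theorem.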
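The sketch below shows the two formulations side by side, using NumPy in place of the tensor library. `phase_ramp` is a hypothetical helper, not this module's API. It builds exp(-2πi (kx·sx/X + ky·sy/Y + kz·sz/Z)) on the rfftn frequency grid, and the final check confirms that the ramp reproduces the spatial roll.

```python
import numpy as np

def phase_ramp(shape, shifts):
    """Hypothetical helper: phase ramp on the rfftn grid [X, Y, Z//2 + 1].

    Multiplying a spectrum by this ramp is equivalent to applying
    np.roll(..., shifts) to the spatial signal (DFT shift theorem).
    """
    X, Y, Z = shape
    sx, sy, sz = shifts
    kx = np.fft.fftfreq(X)[:, None, None]   # k/X on the full-FFT axis X
    ky = np.fft.fftfreq(Y)[None, :, None]   # k/Y on the full-FFT axis Y
    kz = np.fft.rfftfreq(Z)[None, None, :]  # k/Z on the half-spectrum axis Z
    return np.exp(-2j * np.pi * (kx * sx + ky * sy + kz * sz))

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 5, 8))
k = rng.standard_normal((3, 3, 2))
s = x.shape
# Zero-pad the kernel to x's spatial size so the two spectra have matching shapes.
spec = np.fft.rfftn(x) * np.fft.rfftn(k, s=s)
shifts = tuple(-((K - 1) // 2) for K in k.shape)

via_roll = np.roll(np.fft.irfftn(spec, s=s), shifts, axis=(0, 1, 2))
via_ramp = np.fft.irfftn(spec * phase_ramp(s, shifts), s=s)
assert np.allclose(via_roll, via_ramp)
```

The ramp uses only the frequencies that rfftn keeps. This is enough because the shifted signal is still real, so its half spectrum fully determines it.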

Layout and shapes

  • Layout: BHL ([batch, hidden, X, Y, Z])

  • Inputs:
      - x: [B, H, X, Y, Z]
      - kernel: [1|B, H, Kx, Ky, Kz] (a leading batch size of 1 is broadcast across B)

  • Output:
      - y: [B, H, X, Y, Z]
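This sketch traces the BHL shapes through the real FFT, again with NumPy standing in for the tensor library. Transforming only the last three axes leaves [B, H] intact, and a kernel with a leading batch size of 1 broadcasts across B. The sizes are arbitrary and chosen only for illustration.

```python
import numpy as np

B, H, X, Y, Z = 2, 4, 8, 6, 10
Kx, Ky, Kz = 3, 3, 3
rng = np.random.default_rng(0)
x = rng.standard_normal((B, H, X, Y, Z)).astype(np.float32)
kernel = rng.standard_normal((1, H, Kx, Ky, Kz)).astype(np.float32)

axes = (-3, -2, -1)                                    # spatial axes X, Y, Z
x_f = np.fft.rfftn(x, axes=axes)                       # [B, H, X, Y, Z//2 + 1]
k_f = np.fft.rfftn(kernel, s=(X, Y, Z), axes=axes)     # [1, H, X, Y, Z//2 + 1]
y = np.fft.irfftn(x_f * k_f, s=(X, Y, Z), axes=axes)   # batch axis broadcasts 1 -> B

print(x_f.shape, k_f.shape, y.shape)
# (2, 4, 8, 6, 6) (1, 4, 8, 6, 6) (2, 4, 8, 6, 10)
```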

Alignment and shifts

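  • To match “same”-mode alignment (kernel centered on each output position), the output is shifted by shift_x = -((Kx - 1)//2), shift_y = -((Ky - 1)//2), and shift_z = -((Kz - 1)//2).

  • If use_phase_shift=True, the kernel spectrum is multiplied by the cached 3D phase ramp in the frequency domain. Otherwise, the spatial output is rolled after the inverse transform. The two paths are mathematically equivalent and differ only by floating-point rounding.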
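The NumPy sketch below checks the shift formula on a single [X, Y, Z] volume. It compares the FFT path, followed by a roll of -((K - 1)//2) per axis, against a direct centered circular convolution written as a sum of rolled copies of x. Both function names are hypothetical.

```python
import numpy as np

def fft_same_circular(x, k):
    """Hypothetical: FFT circular convolution, rolled to 'same' alignment."""
    s = x.shape
    shifts = tuple(-((K - 1) // 2) for K in k.shape)
    y = np.fft.irfftn(np.fft.rfftn(x) * np.fft.rfftn(k, s=s), s=s)
    return np.roll(y, shifts, axis=(0, 1, 2))

def direct_same_circular(x, k):
    """Hypothetical: y[n] = sum_m k[m] * x[(n + c - m) mod N], c = (K - 1)//2."""
    cx, cy, cz = ((K - 1) // 2 for K in k.shape)
    y = np.zeros_like(x)
    for a in range(k.shape[0]):
        for b in range(k.shape[1]):
            for c in range(k.shape[2]):
                # np.roll(x, t)[n] == x[n - t], so t = a - cx picks x[n + cx - a]
                y += k[a, b, c] * np.roll(x, (a - cx, b - cy, c - cz), axis=(0, 1, 2))
    return y

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 5, 7))
k = rng.standard_normal((3, 4, 2))
print(np.allclose(fft_same_circular(x, k), direct_same_circular(x, k)))

# A delta at the kernel center acts as the identity once the shift is applied.
delta = np.zeros((3, 3, 3))
delta[1, 1, 1] = 1.0
print(np.allclose(fft_same_circular(x, delta), x))
```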

Shortcut

  • Optional shortcut of shape [H]: it scales the input per channel, and the result is added to the convolution output (y += shortcut * x). The shortcut is broadcast over the batch and spatial axes.
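The shortcut term is a broadcast multiply-add, sketched here in NumPy with `add_shortcut` as a hypothetical helper. The [H] vector is reshaped to [1, H, 1, 1, 1] so that each hidden channel of x gets its own scale. NumPy's type promotion upcasts a float16 shortcut to float32 without an explicit cast, which is analogous to the no-cast behavior described for the shortcut parameter.

```python
import numpy as np

def add_shortcut(y, x, shortcut):
    """Hypothetical helper: y + shortcut * x, with shortcut of shape [H]."""
    scale = shortcut.reshape(1, -1, 1, 1, 1)  # [H] -> [1, H, 1, 1, 1]
    return y + scale * x                       # float16 * float32 promotes to float32

B, H, X, Y, Z = 2, 3, 4, 4, 4
x = np.ones((B, H, X, Y, Z), dtype=np.float32)
y = np.zeros_like(x)
shortcut = np.array([0.5, 1.0, 2.0], dtype=np.float16)
out = add_shortcut(y, x, shortcut)
print(out.dtype, out[0, :, 0, 0, 0])
```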

Caching

  • The 3D phase ramp is retrieved from a global, module-level LRU cache shared across all layers/callers within the same process.
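One way to build a process-wide LRU cache is to apply `functools.lru_cache` to a module-level function, as sketched below. The function name, the `maxsize`, and the cache key (spatial size plus shifts) are all assumptions, and the real cache may also key on device or dtype. The returned array is marked read-only so that one caller cannot mutate an entry that every other caller shares.

```python
import functools
import numpy as np

@functools.lru_cache(maxsize=32)  # hypothetical cache size
def get_phase_ramp(X, Y, Z, sx, sy, sz):
    """Hypothetical cached phase ramp of shape [X, Y, Z//2 + 1].

    Arguments must be hashable (plain ints) because they form the cache key.
    """
    kx = np.fft.fftfreq(X)[:, None, None]
    ky = np.fft.fftfreq(Y)[None, :, None]
    kz = np.fft.rfftfreq(Z)[None, None, :]
    ramp = np.exp(-2j * np.pi * (kx * sx + ky * sy + kz * sz))
    ramp.flags.writeable = False  # shared across callers: forbid in-place edits
    return ramp

r1 = get_phase_ramp(8, 8, 8, -1, -1, -1)
r2 = get_phase_ramp(8, 8, 8, -1, -1, -1)  # second call is served from the cache
print(r1 is r2, get_phase_ramp.cache_info().hits)
```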

param x: [B, H, X, Y, Z], any dtype (internally cast to float32).
type x: Tensor

param kernel: [1|B, H, Kx, Ky, Kz], any dtype (internally cast to float32).
type kernel: Tensor

param shortcut: Optional [H] per-channel residual scale. It is never cast explicitly; type promotion in the multiply upcasts it as needed.
type shortcut: Tensor | None

param use_phase_shift: If True, apply the shift in the frequency domain; otherwise, roll the spatial output.
type use_phase_shift: bool

returns: [B, H, X, Y, Z], in the original dtype of x.
rtype: Tensor
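The hypothetical NumPy reference below combines these pieces for the full call and could be used to test the real kernel. It is a sketch under stated assumptions, not this function's implementation. It casts x and kernel to float32, computes the phase ramp inline instead of reading it from a cache, applies the shortcut to the float32 copy of x, and casts the result back to x's dtype. Both `use_phase_shift` settings should agree to within rounding.

```python
import numpy as np

AXES = (-3, -2, -1)  # spatial axes X, Y, Z in the BHL layout

def _phase_ramp(shape, shifts):
    # exp(-2*pi*i*(kx*sx/X + ky*sy/Y + kz*sz/Z)) on the rfftn grid [X, Y, Z//2 + 1]
    X, Y, Z = shape
    kx = np.fft.fftfreq(X)[:, None, None]
    ky = np.fft.fftfreq(Y)[None, :, None]
    kz = np.fft.rfftfreq(Z)[None, None, :]
    return np.exp(-2j * np.pi * (kx * shifts[0] + ky * shifts[1] + kz * shifts[2]))

def circular_fftconv3d_reference(x, kernel, shortcut=None, use_phase_shift=True):
    """Hypothetical NumPy reference; not the library implementation."""
    x32 = x.astype(np.float32)
    k32 = kernel.astype(np.float32)
    spatial = x32.shape[-3:]
    shifts = tuple(-((K - 1) // 2) for K in k32.shape[-3:])
    # Kernel is zero-padded to (X, Y, Z); a kernel batch of 1 broadcasts over B.
    spec = np.fft.rfftn(x32, axes=AXES) * np.fft.rfftn(k32, s=spatial, axes=AXES)
    if use_phase_shift:
        y = np.fft.irfftn(spec * _phase_ramp(spatial, shifts), s=spatial, axes=AXES)
    else:
        y = np.roll(np.fft.irfftn(spec, s=spatial, axes=AXES), shifts, axis=AXES)
    if shortcut is not None:
        y = y + shortcut.reshape(1, -1, 1, 1, 1) * x32  # per-channel residual
    return y.astype(x.dtype)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 6, 5, 4)).astype(np.float16)
kernel = rng.standard_normal((1, 3, 3, 3, 3)).astype(np.float16)
y = circular_fftconv3d_reference(x, kernel, shortcut=np.ones(3, dtype=np.float32))
print(y.shape, y.dtype)
```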
