mixed_fftconv2d_fp32_bhl

mixed_fftconv2d_fp32_bhl(x, kernel, periodic, shortcut=None, use_phase_shift=True)

2D FFT convolution with mixed boundary conditions (BC), in BHL layout (batch B and channel H first, spatial axes last).

Each of the two spatial axes (X and Y) independently uses either periodic (circular) or non-periodic (zero-padded, “same”-size output) boundary handling, as selected by the periodic argument.
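The per-axis boundary handling can be illustrated with a plain PyTorch sketch. This is not the library's implementation: mixed_fftconv2d_reference and _layout_kernel_axis are hypothetical names, the sketch assumes true convolution (not cross-correlation) with the kernel centre at tap (K - 1) // 2, and it does not model use_phase_shift. On a periodic axis the FFT length equals the axis length, so multiplying spectra gives a circular convolution; on a non-periodic axis both operands are zero-padded to N + K - 1 points so nothing wraps, and the centred N-sample window is cropped afterwards.

```python
import torch
import torch.nn.functional as F


def _layout_kernel_axis(k, dim, n, periodic):
    """Hypothetical helper: arrange kernel axis `dim` for an FFT along a
    signal axis of length n. Returns the laid-out kernel and FFT length."""
    k = k.movedim(dim, -1)
    K = k.shape[-1]
    c = (K - 1) // 2  # assumption: centre tap, numpy-style "same" alignment
    if periodic:
        # Fold taps that wrap around the length-n ring (also covers K > n)...
        m = -(-K // n)  # ceil(K / n)
        k = F.pad(k, (0, m * n - K))
        k = k.reshape(*k.shape[:-1], m, n).sum(dim=-2)
        # ...then rotate so the centre tap sits at index 0.
        k = torch.roll(k, shifts=-c, dims=-1)
        fft_len = n
    else:
        # Linear convolution needs n + K - 1 points to avoid wrap-around.
        fft_len = n + K - 1
        k = F.pad(k, (0, fft_len - K))
    return k.movedim(-1, dim), fft_len


def mixed_fftconv2d_reference(x, kernel, periodic):
    """Hypothetical fp32 sketch of a mixed-BC 2D FFT convolution.
    x: [B, H, X, Y]; kernel: [1|B, H, Kx, Ky]; periodic: (px, py)."""
    X, Y = x.shape[-2:]
    Kx, Ky = kernel.shape[-2:]
    k, Lx = _layout_kernel_axis(kernel.float(), -2, X, periodic[0])
    k, Ly = _layout_kernel_axis(k, -1, Y, periodic[1])
    # fftn zero-pads x up to (Lx, Ly); a periodic axis gets no padding.
    xs = torch.fft.fftn(x.float(), s=(Lx, Ly), dim=(-2, -1))
    ks = torch.fft.fftn(k, dim=(-2, -1))  # broadcasts a size-1 batch dim
    y = torch.fft.ifftn(xs * ks, dim=(-2, -1)).real
    # Keep the centred "same" window on zero-padded axes only.
    cx = 0 if periodic[0] else (Kx - 1) // 2
    cy = 0 if periodic[1] else (Ky - 1) // 2
    return y[..., cx:cx + X, cy:cy + Y].to(x.dtype)
```

Because the 2D FFT factors into one FFT per axis, each axis can pick its own transform length independently, which is what makes mixing circular and zero-padded behaviour in a single transform possible.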

Parameters:
  • x (Tensor) – Input tensor of shape [B, H, X, Y] (any dtype, internally cast to fp32).

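  • kernel (Tensor) – Kernel tensor of shape [1|B, H, K_x, K_y] (any dtype, cast to fp32). A leading dimension of 1 shares one kernel across the batch; a leading dimension of B gives each sample its own kernel.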

  • periodic (Sequence[bool]) – Length-2 sequence (periodic_x, periodic_y). True selects circular convolution along that axis; False selects zero-padded “same” convolution.

  • shortcut (Tensor | None) – Optional per-channel scale of shape [H]. If given, it is applied as y += shortcut * x, broadcast over the B, X and Y dimensions.

  • use_phase_shift (bool) – See mixed_fftconv1d_fp32_bhl().
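A direct-space sketch can make the parameter semantics above concrete: each axis is padded according to its periodic flag (circular or zeros), the kernel is applied per channel with a depthwise F.conv2d (a [1, ...] kernel is shared across the batch, a [B, ...] kernel is applied per sample), and the optional shortcut adds shortcut * x per channel. mixed_conv2d_direct is a hypothetical name, not part of the library. The convolution convention and centre tap (K - 1) // 2 are assumptions, use_phase_shift is not modelled, and because F.pad's circular mode cannot wrap more than once, this version assumes each kernel extent is at most about the length of the periodic axis it wraps. It costs O(X·Y·Kx·Ky) per channel, so it is best treated as a slow cross-check rather than a substitute.

```python
import torch
import torch.nn.functional as F


def mixed_conv2d_direct(x, kernel, periodic, shortcut=None):
    """Hypothetical direct-space reference (not the library's API).
    Assumes true convolution centred at tap (K - 1) // 2 and circular
    padding no wider than the axis it wraps."""
    B, H, X, Y = x.shape
    Kx, Ky = kernel.shape[-2:]
    cx, cy = (Kx - 1) // 2, (Ky - 1) // 2
    xf = x.float()  # compute in fp32, like the documented op
    # Pad X, then Y; F.pad order is (y_left, y_right, x_left, x_right).
    xf = F.pad(xf, (0, 0, Kx - 1 - cx, cx),
               mode="circular" if periodic[0] else "constant")
    xf = F.pad(xf, (Ky - 1 - cy, cy, 0, 0),
               mode="circular" if periodic[1] else "constant")
    # conv2d is cross-correlation, so flip the kernel; expand [1|B] -> [B].
    w = kernel.float().flip(-2, -1).expand(B, H, Kx, Ky)
    # Fold batch into channels so each (sample, channel) gets its own filter.
    y = F.conv2d(xf.reshape(1, B * H, X + Kx - 1, Y + Ky - 1),
                 w.reshape(B * H, 1, Kx, Ky), groups=B * H)
    y = y.reshape(B, H, X, Y)
    if shortcut is not None:
        # Per-channel skip: y += shortcut * x, broadcast over B, X, Y.
        y = y + shortcut.float().view(1, H, 1, 1) * x.float()
    return y.to(x.dtype)  # return in the caller's dtype
```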

Returns:

Tensor of shape [B, H, X, Y] in the original dtype of x.

Return type:

Tensor

Example

>>> import torch
>>> from nvsubquadratic.ops.mixed_fftconv import mixed_fftconv2d_fp32_bhl
>>> B, H, X, Y, Kx, Ky = 2, 64, 32, 64, 63, 127
>>> x = torch.randn(B, H, X, Y)
>>> kernel = torch.randn(1, H, Kx, Ky)
>>> # x-axis periodic, y-axis zero-padded
>>> y = mixed_fftconv2d_fp32_bhl(x, kernel, periodic=(True, False))
>>> y.shape
torch.Size([2, 64, 32, 64])