fftconv3d_fp32_bhl

fftconv3d_fp32_bhl(x, kernel, shortcut=None)

3D FFT convolution with optional shortcut, for inputs with layout (batch, hidden, depth, height, width).

When shortcut is provided, the output is shortcut(x) + conv(x, kernel); otherwise it is conv(x, kernel).

Accepts any input dtype. Internally casts x and kernel to float32 for numerical stability and returns the result in the original dtype of x.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch_size, hidden_dim, X_in, Y_in, Z_in).

  • kernel (torch.Tensor) – Kernel tensor of shape (1, hidden_dim, K_x, K_y, K_z).

  • shortcut (torch.Tensor | None, optional) – Shortcut tensor of shape (hidden_dim,), one value per hidden channel. Defaults to None.

Returns:

Output tensor of shape (batch_size, hidden_dim, X_in, Y_in, Z_in), in the original dtype of x.

Return type:

torch.Tensor
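
Example:

The following is a minimal reference sketch of the behavior described above, not the library's implementation. The name fftconv3d_reference is hypothetical. It rests on two assumptions the documentation does not confirm: that the convolution is zero-padded and cropped to the leading (X_in, Y_in, Z_in) block of the full linear convolution, and that shortcut(x) means scaling each hidden channel of x by the matching entry of shortcut.

```python
import torch


def fftconv3d_reference(x, kernel, shortcut=None):
    """Hypothetical reference sketch of a 3D FFT convolution with a shortcut."""
    orig_dtype = x.dtype

    # Compute in float32, as the documentation describes.
    x32 = x.to(torch.float32)
    k32 = kernel.to(torch.float32)

    spatial = x32.shape[-3:]
    dims = (-3, -2, -1)

    # Zero-pad every spatial axis to N + K - 1 so the circular product
    # computed by the FFT equals a linear convolution.
    fft_size = [n + k - 1 for n, k in zip(spatial, k32.shape[-3:])]

    x_f = torch.fft.rfftn(x32, s=fft_size, dim=dims)
    k_f = torch.fft.rfftn(k32, s=fft_size, dim=dims)  # broadcasts over batch
    y = torch.fft.irfftn(x_f * k_f, s=fft_size, dim=dims)

    # Assumption: keep the leading block so the output matches the input shape.
    y = y[..., : spatial[0], : spatial[1], : spatial[2]]

    if shortcut is not None:
        # Assumption: shortcut(x) is a per-channel scaling of x.
        y = y + shortcut.to(torch.float32).view(1, -1, 1, 1, 1) * x32

    # Return in the original dtype of x.
    return y.to(orig_dtype)


# Usage: a half-precision input keeps its shape and dtype.
x = torch.randn(2, 3, 4, 5, 6, dtype=torch.float16)
kernel = torch.randn(1, 3, 2, 2, 2)
shortcut = torch.ones(3)
out = fftconv3d_reference(x, kernel, shortcut)
print(out.shape, out.dtype)
```

Under these assumptions, a kernel that is 1 at index (0, 0, 0) and 0 elsewhere acts as the identity, which gives a quick sanity check against the real function if its cropping convention matches.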