fftconv3d_fp32_blh
- fftconv3d_fp32_blh(x, kernel, shortcut=None)
3D FFT convolution with an optional shortcut. When shortcut is provided, the output is shortcut(x) + conv(x, kernel).
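The computation above can be sketched in plain PyTorch as a reference. This is a minimal sketch under assumptions that this docstring does not confirm: the convolution is depthwise (each hidden channel uses its own kernel slice), zero-padded rather than circular, and cropped to the first X_in x Y_in x Z_in outputs, and shortcut(x) is read as the per-channel product x * shortcut, since shortcut has shape (hidden_dim). The name fftconv3d_reference is hypothetical and is not part of the library.

```python
import torch


def fftconv3d_reference(x, kernel, shortcut=None):
    """Hypothetical reference for a depthwise 3D FFT convolution in (B, X, Y, Z, H) layout.

    Assumptions (not confirmed by the documented API): linear (zero-padded)
    convolution per hidden channel, output cropped to the input's spatial size,
    and shortcut(x) interpreted as x * shortcut (per-channel scale).
    """
    orig_dtype = x.dtype
    # Compute in float32, mirroring the casting behavior the docstring describes.
    xf = x.to(torch.float32)
    kf = kernel.to(torch.float32)

    X, Y, Z = xf.shape[1:4]
    KX, KY, KZ = kf.shape[1:4]
    # Padding each spatial axis to N + K - 1 makes the FFT product a linear
    # (non-wrapping) convolution instead of a circular one.
    size = (X + KX - 1, Y + KY - 1, Z + KZ - 1)
    dims = (1, 2, 3)

    x_hat = torch.fft.rfftn(xf, s=size, dim=dims)
    k_hat = torch.fft.rfftn(kf, s=size, dim=dims)  # leading dim 1 broadcasts over batch
    y = torch.fft.irfftn(x_hat * k_hat, s=size, dim=dims)
    y = y[:, :X, :Y, :Z, :]  # crop back to the input's spatial extent

    if shortcut is not None:
        # shortcut has shape (hidden_dim,), so it broadcasts over the last axis.
        y = y + xf * shortcut.to(torch.float32)

    # Return in the caller's original dtype.
    return y.to(orig_dtype)
```

Padding each spatial axis to X_in + K_x - 1 (and likewise for Y and Z) is what turns the FFT's circular product into a linear convolution; comparing the output against a direct shift-and-add loop on small random inputs is a quick way to check whether these assumptions match the real kernel's behavior.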
Accepts any input dtype. Internally casts x and kernel to float32 for numerical stability and returns the result in the original dtype of x.
- Parameters:
x (torch.Tensor) – Input tensor of shape (batch_size, X_in, Y_in, Z_in, hidden_dim).
kernel (torch.Tensor) – Kernel tensor of shape (1, K_x, K_y, K_z, hidden_dim).
shortcut (torch.Tensor | None, optional) – Optional shortcut tensor of shape (hidden_dim). Defaults to None.
- Returns:
Output tensor of shape (batch_size, X_in, Y_in, Z_in, hidden_dim), in the original dtype of x.
- Return type: