causal_fftconv1d_fp32_blh

causal_fftconv1d_fp32_blh(x, kernel, shortcut=None)

Causal 1D FFT convolution (BLH layout, channels-last) with optional shortcut.

Computes \(y[n] = \sum_{m=0}^{\min(n,\,K-1)} x[n-m]\, k[m]\) for \(n = 0, \dots, L-1\), independently per channel (\(L\) = seq_len, \(K\) = kernel_len), via the FFT path:
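As a reference for the formula, here is a minimal direct O(L·K) evaluation for a single channel. It is a NumPy sketch, not the library's kernel, and the helper name causal_conv_direct is hypothetical.

```python
import numpy as np

def causal_conv_direct(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Hypothetical direct evaluation of y[n] = sum_{m=0}^{min(n, K-1)} x[n-m] * k[m].

    x: shape [L], one channel of the input.
    k: shape [K], the matching kernel channel.
    """
    L, K = x.shape[0], k.shape[0]
    y = np.zeros(L, dtype=np.result_type(x, k))
    for n in range(L):
        # Only taps m <= n see x[0..n]; taps m >= K do not exist (treated as zero).
        for m in range(min(n, K - 1) + 1):
            y[n] += x[n - m] * k[m]
    return y

x = np.array([1.0, 2.0, 3.0, 4.0])
k = np.array([1.0, 0.5])
y = causal_conv_direct(x, k)  # -> [1.0, 2.5, 4.0, 5.5]
```

Each output mixes the current sample with earlier ones only, e.g. y[1] = 2·1 + 1·0.5 = 2.5.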

\[y = \mathcal{F}_F^{-1}\bigl(\mathcal{F}_F(x) \odot \mathcal{F}_F(k)\bigr)[\,:L\,]\]

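where \(\mathcal{F}_F\) is the length-\(F\) DFT (inputs zero-padded to \(F\) samples) and \(F = \min(L + K, 2L)\) is the zero-pad length that prevents wrap-around. Only taps \(m < L\) can reach the first \(L\) outputs, so the effective kernel length is \(\min(K, L)\), and \(F \ge L + \min(K, L) - 1\) is enough for the circular convolution computed by the FFT to equal the linear convolution on the retained indices. Causality follows from this: because \(k[m]\) is supported on \(m \ge 0\), each retained \(y[n]\) depends only on \(x[0], \dots, x[n]\), and slicing to the leading \(L\) samples discards the tail of the linear convolution.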
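A minimal single-channel NumPy sketch of the FFT path, assuming the kernel is truncated to its leading L taps before padding (later taps cannot affect the kept outputs). The name causal_fftconv1d_single is hypothetical; this illustrates the math, not the library implementation.

```python
import numpy as np

def causal_fftconv1d_single(x: np.ndarray, k: np.ndarray) -> np.ndarray:
    """Hypothetical FFT-path causal convolution for one channel.

    x: shape [L]; k: shape [K]. Returns the leading L samples of the linear convolution.
    """
    L, K = x.shape[0], k.shape[0]
    F = min(L + K, 2 * L)          # zero-pad length from the docstring
    k = k[:L]                      # taps m >= L cannot reach y[0..L-1]
    X = np.fft.rfft(x, n=F)        # rfft zero-pads its input to length F
    Kf = np.fft.rfft(k, n=F)
    y_full = np.fft.irfft(X * Kf, n=F)
    return y_full[:L]              # keep the leading L (alias-free, causal) samples

rng = np.random.default_rng(0)
x = rng.standard_normal(16)
k = rng.standard_normal(5)
ref = np.convolve(x, k)[:16]       # direct linear convolution, first L outputs
y = causal_fftconv1d_single(x, k)

# Contrast: with F = L (no padding) the FFT computes a circular convolution,
# so the tail y_lin[L..L+K-2] wraps onto y[0..K-2] and generally differs from ref.
y_wrapped = np.fft.irfft(np.fft.rfft(x) * np.fft.rfft(k, n=16), n=16)
```

The wrapped result agrees with the reference only from index K-1 onward, which is exactly the contamination the padding to F avoids.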

When shortcut is provided, a per-channel residual is added, with shortcut broadcast over the batch and sequence axes:

\[y \leftarrow y + \text{shortcut} \odot x\]
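A small NumPy illustration of the residual broadcast in BLH layout, assuming shortcut behaves like an ordinary [hidden_dim] array under trailing-axis broadcasting (the framework behind Tensor is not specified here). The array y is a random stand-in for the convolution output.

```python
import numpy as np

B, L, H = 2, 8, 4
rng = np.random.default_rng(0)
x = rng.standard_normal((B, L, H)).astype(np.float32)
y = rng.standard_normal((B, L, H)).astype(np.float32)  # stand-in for the FFT-conv output
shortcut = np.array([1.0, 0.0, -1.0, 0.5], dtype=np.float32)  # shape [H]

# Channels-last layout: a [H] vector aligns with the trailing axis of [B, L, H],
# so broadcasting applies shortcut[h] to channel h at every (batch, position).
y_out = y + shortcut * x
# Equivalent explicit form with the broadcast axes spelled out.
y_explicit = y + shortcut[None, None, :] * x
```

Because channels are the last axis, no reshaping is needed for the per-channel scale to line up.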
Parameters:
  • x (Tensor) – Input tensor of shape [batch_size, seq_len, hidden_dim].

  • kernel (Tensor) – Kernel tensor of shape [1 | batch_size, kernel_len, hidden_dim]. The leading dim is 1 for a kernel shared across the batch, or batch_size for FiLM-style per-sample kernels.

  • shortcut (Tensor | None) – Optional per-channel residual scale of shape [hidden_dim]. Default: None.

Returns:

Output tensor of shape [batch_size, seq_len, hidden_dim] in the original dtype of x.

Return type:

Tensor
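Putting the pieces together, the following is a hypothetical NumPy reference for the documented semantics, handy for checking outputs in tests. It is not the library code: the name causal_fftconv1d_blh_reference, the shape check, and truncating the kernel to L taps are assumptions of this sketch.

```python
import numpy as np
from typing import Optional

def causal_fftconv1d_blh_reference(
    x: np.ndarray, kernel: np.ndarray, shortcut: Optional[np.ndarray] = None
) -> np.ndarray:
    """Hypothetical NumPy reference (not the library implementation).

    x: [B, L, H]; kernel: [1 or B, K, H]; shortcut: [H] or None.
    """
    B, L, H = x.shape
    # Assumed validation; the docs do not say what happens on a shape mismatch.
    if kernel.ndim != 3 or kernel.shape[0] not in (1, B) or kernel.shape[2] != H:
        raise ValueError(f"kernel shape {kernel.shape} incompatible with x shape {x.shape}")
    K = kernel.shape[1]
    F = min(L + K, 2 * L)
    # Cast to float32 before the FFT (NumPy may compute in higher precision internally).
    xf = x.astype(np.float32)
    kf = kernel[:, :L, :].astype(np.float32)  # taps beyond L cannot reach y[0..L-1]
    # Transform along the sequence axis (axis=1); channels stay last.
    X = np.fft.rfft(xf, n=F, axis=1)          # [B, F//2+1, H]
    Kx = np.fft.rfft(kf, n=F, axis=1)         # [1 or B, F//2+1, H]; a leading 1 broadcasts
    y = np.fft.irfft(X * Kx, n=F, axis=1)[:, :L, :]
    if shortcut is not None:
        y = y + shortcut.astype(np.float32) * xf  # [H] broadcasts over B and L
    return y.astype(x.dtype)                  # return in the original dtype of x
```

Transforming along axis=1 keeps the channels-last layout intact, so this sketch needs no transpose to [B, H, L]; whether the actual implementation transposes internally is not stated here.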