CKConvND

class CKConvND(
data_dim,
hidden_dim,
kernel_cfg,
mask_cfg,
grid_type,
fft_padding,
is_causal=False,
use_chunked_fftconv=False,
fft_backend='torch_fft',
)

Bases: Module

N-dimensional Continuous Kernel Convolution (CKConv) operator.

CKConvND implements the CKConv operator (Romero et al., arXiv:2102.02611) generalised to spatial rank D ∈ {1, 2, 3}. The convolutional kernel is not stored as an explicit lookup table; instead, it is produced on the fly by a small MLP evaluated on a continuous positional grid:

k_θ(p) = MLP_θ(pos_enc(p)),   p ∈ [-1, 1]^D

The convolution with the input signal x is then computed in the frequency domain:

y = IFFT( FFT(x) ⊙ FFT(k_θ) )  +  shortcut ⊙ x

at O(N log N) cost per channel, where N = prod(spatial_dims).
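As a sanity check of the frequency-domain formula above, the sketch below is a NumPy stand-in for the periodic 1D case. It assumes a channels-last (N, C) layout and uses plain numpy.fft rather than the library's fftconv ops; both function names are hypothetical. It compares the O(N log N) route against the O(N²) direct circular sum.

```python
import numpy as np

def fft_circular_conv(x, k, shortcut):
    """Channels-last periodic depthwise conv: x, k of shape (N, C), shortcut of shape (C,)."""
    n = x.shape[0]
    # Real FFT along the spatial axis; each channel is transformed independently.
    y = np.fft.irfft(np.fft.rfft(x, axis=0) * np.fft.rfft(k, axis=0), n=n, axis=0)
    return y + shortcut * x  # shortcut broadcasts over the spatial axis

def direct_circular_conv(x, k, shortcut):
    """O(N^2) reference: y[i, c] = sum_m x[m, c] * k[(i - m) mod N, c] + shortcut[c] * x[i, c]."""
    n = x.shape[0]
    y = np.zeros_like(x)
    for i in range(n):
        for m in range(n):
            y[i] += x[m] * k[(i - m) % n]
    return y + shortcut * x

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 3))
k = rng.standard_normal((16, 3))
s = rng.standard_normal(3)
print(np.allclose(fft_circular_conv(x, k, s), direct_circular_conv(x, k, s)))  # True
```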

The MLP and its positional encoding are provided through kernel_cfg (typically a SIRENKernelND or RandomFourierKernelND lazy config from nvsubquadratic.modules.kernels_nd). An optional attenuation mask mask_cfg (e.g. GaussianModulationND) is applied to the kernel values after the MLP forward pass to restrict the effective receptive field at initialisation.
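A minimal sketch of the kernel pipeline, assuming a one-hidden-layer sine-activated MLP (SIREN-style, with the omega_0 = 30 scaling from the SIREN paper), raw coordinates as the positional encoding, and a Gaussian envelope as the attenuation mask. The helper names and the exact envelope are illustrative assumptions, not the SIRENKernelND or GaussianModulationND implementations. What it shows is that one set of MLP parameters yields a kernel at any grid resolution.

```python
import numpy as np

def siren_like_kernel(grid, w1, b1, w2, b2, omega_0=30.0):
    """Evaluate a tiny sine-activated MLP at each grid point.

    grid: (L, 1) coordinates in [-1, 1]; returns kernel values of shape (L, C).
    """
    hidden = np.sin(omega_0 * (grid @ w1 + b1))  # SIREN-style first layer
    return hidden @ w2 + b2                       # linear read-out to C channels

def gaussian_envelope(grid, kernel, sigma=0.3):
    """Hypothetical attenuation mask: scale each tap by exp(-p^2 / (2 sigma^2))."""
    envelope = np.exp(-(grid ** 2) / (2.0 * sigma ** 2))  # (L, 1), broadcasts over C
    return kernel * envelope

rng = np.random.default_rng(0)
hidden, channels = 32, 4
w1 = rng.uniform(-1.0, 1.0, (1, hidden))
b1 = rng.uniform(-1.0, 1.0, hidden)
w2 = rng.standard_normal((hidden, channels)) / np.sqrt(hidden)
b2 = np.zeros(channels)

# The same parameters yield kernels at any resolution: 9 taps or 33 taps.
for num_taps in (9, 33):
    grid = np.linspace(-1.0, 1.0, num_taps)[:, None]
    k = gaussian_envelope(grid, siren_like_kernel(grid, w1, b1, w2, b2))
    print(num_taps, k.shape)  # prints 9 (9, 4), then 33 (33, 4)
```

The envelope leaves the centre tap (p = 0) untouched and shrinks taps towards the grid edges, which is the "restricted effective receptive field at initialisation" effect described above.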

Boundary conditions are controlled jointly by fft_padding and grid_type:

  • fft_padding="zero", grid_type="double": standard linear convolution with “same” output size. The kernel spans the full input (double grid: N grid points → kernel of length 2*N - 1) and is zero-padded before the FFT.

  • fft_padding="circular", grid_type="single": periodic convolution, kernel size == input size (single-grid: (N+1)//2 grid points → kernel of length N after the MLP).

  • fft_padding=["circular", "zero"], grid_type=None: per-axis mixed boundary conditions; grid_type is auto-derived per axis.
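The three accepted forms reduce to one per-axis representation. The sketch below is a hypothetical pure-Python helper (resolve_axes, plus grid_lens for the grid sizes used in forward) that follows the rules stated in this document: single strings need an explicit grid_type, per-axis lists need grid_type=None, are case-insensitive and whitespace-stripped, and auto-derive "single" on periodic axes and "double" elsewhere. It is a sketch, not the constructor's actual code.

```python
from typing import Optional, Sequence, Tuple, Union

MODES = ("zero", "circular")

def resolve_axes(
    fft_padding: Union[str, Sequence[str]], grid_type: Optional[str], data_dim: int
) -> Tuple[Tuple[bool, ...], Tuple[str, ...]]:
    """Return (periodic_per_axis, grid_type_per_axis) for the accepted fft_padding forms."""
    if isinstance(fft_padding, bool) or not isinstance(fft_padding, (str, list, tuple)):
        raise ValueError(f"unsupported fft_padding {fft_padding!r}")
    if isinstance(fft_padding, str):
        # Single mode string: same boundary condition and grid type on every axis.
        if fft_padding not in MODES:  # also rejects "zero,circular"
            raise ValueError(f"unknown fft_padding {fft_padding!r}")
        if grid_type is None:
            raise ValueError("grid_type is required for a single fft_padding string")
        return (fft_padding == "circular",) * data_dim, (grid_type,) * data_dim
    # Per-axis list: case-insensitive, whitespace-stripped, one entry per axis.
    modes = [str(m).strip().lower() for m in fft_padding]
    if len(modes) != data_dim or any(m not in MODES for m in modes):
        raise ValueError(f"expected {data_dim} modes from {MODES}, got {fft_padding!r}")
    if grid_type is not None:
        raise ValueError("grid_type must be None for per-axis fft_padding")
    periodic = tuple(m == "circular" for m in modes)
    # Auto-derived grid: single on periodic axes, double on zero-padded axes.
    return periodic, tuple("single" if p else "double" for p in periodic)

def grid_lens(spatial_dims: Sequence[int], grid_types: Sequence[str]) -> list:
    """Grid points per axis, as in forward: (s + 1) // 2 for single, s for double."""
    return [(s + 1) // 2 if g == "single" else s for s, g in zip(spatial_dims, grid_types)]

periodic, grids = resolve_axes(["Circular ", "zero"], None, 2)
print((periodic, grids))               # ((True, False), ('single', 'double'))
print(grid_lens((64, 48), grids))      # [32, 48]
```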

Context parallelism: when cp_group is supplied to forward, the kernel is sliced along the channel dimension to match the local channel slice of the input, and the shortcut parameter is sliced accordingly. Causal mode has not been verified under CP, and forward raises a ValueError if cp_group is combined with is_causal=True.
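The channel slicing can be sketched as follows, assuming C divides evenly across ranks and that rank r owns a contiguous block of C / world_size channels. The helper name and the contiguous-block layout are assumptions; no torch.distributed calls are made, so rank and world_size are passed in explicitly.

```python
import numpy as np

def local_channel_slice(kernel, shortcut, rank, world_size):
    """Slice a channels-last kernel (..., C) and a shortcut (C,) to this rank's channels.

    Assumes C divides evenly and rank r owns the contiguous block [r*C/W, (r+1)*C/W).
    """
    channels = shortcut.shape[0]
    if channels % world_size != 0:
        raise ValueError(f"{channels} channels cannot be split evenly across {world_size} ranks")
    per_rank = channels // world_size
    sl = slice(rank * per_rank, (rank + 1) * per_rank)
    return kernel[..., sl], shortcut[sl]

kernel = np.arange(5 * 8, dtype=float).reshape(1, 5, 8)  # (1, kernel_len, C=8)
shortcut = np.arange(8, dtype=float)
k_loc, s_loc = local_channel_slice(kernel, shortcut, rank=1, world_size=4)
print(k_loc.shape, s_loc)  # (1, 5, 2) [2. 3.]
```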

Parameters:
  • data_dim (int)

  • hidden_dim (int)

  • kernel_cfg (LazyConfig)

  • mask_cfg (LazyConfig)

  • grid_type (Literal['double', 'single'] | None)

  • fft_padding (Literal['zero', 'circular'] | str | Sequence[str])

  • is_causal (bool)

  • use_chunked_fftconv (bool)

  • fft_backend (Literal['torch_fft', 'subq_ops'])

data_dim

Spatial rank of the input (1 for sequences, 2 for images, 3 for volumes).

Type:

int

hidden_dim

Number of channels C processed by this operator.

Type:

int

fft_padding

Boundary condition specification as supplied by the caller. The normalised per-axis representation is in _periodic_per_axis.

Type:

str or Sequence[str]

is_causal

Whether the operator enforces causal (past-only) convolution. Only valid when data_dim=1.

Type:

bool

use_chunked_fftconv

Whether to process channels in chunks to reduce peak GPU memory.

Type:

bool
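Chunking changes peak memory, not the result. The sketch below assumes a 1D zero-padded (full-length, linear) depthwise convolution in NumPy and a hypothetical chunk_size argument; it checks that processing channels in groups reproduces the unchunked output, while only chunk_size channels' complex spectra exist at any one time.

```python
import numpy as np

def fft_conv_linear(x, k):
    """Zero-padded (linear) depthwise conv along axis 0; x: (N, C), k: (K, C) -> (N + K - 1, C)."""
    n = x.shape[0] + k.shape[0] - 1  # FFT length that avoids circular wrap-around
    spec = np.fft.rfft(x, n=n, axis=0) * np.fft.rfft(k, n=n, axis=0)
    return np.fft.irfft(spec, n=n, axis=0)

def fft_conv_linear_chunked(x, k, chunk_size):
    """Same result, computed chunk_size channels at a time to bound the complex intermediates."""
    n = x.shape[0] + k.shape[0] - 1
    out = np.empty((n, x.shape[1]), dtype=np.result_type(x, k))
    for start in range(0, x.shape[1], chunk_size):
        sl = slice(start, start + chunk_size)  # the last chunk may be smaller
        out[:, sl] = fft_conv_linear(x[:, sl], k[:, sl])
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((20, 10))
k = rng.standard_normal((7, 10))
print(np.allclose(fft_conv_linear(x, k), fft_conv_linear_chunked(x, k, chunk_size=4)))  # True
```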

fft_backend

FFT backend identifier, "torch_fft" or "subq_ops".

Type:

str

grid_type

Kernel grid size mode ("single", "double", or None for per-axis auto-derivation).

Type:

str or None

kernel

Implicit kernel generator (produces (kernel_values, grid) on each forward call).

Type:

nn.Module

mask

Attenuation mask applied to kernel values after generation. nn.Identity when no mask is configured.

Type:

nn.Module

shortcut

Learnable per-channel skip-connection scale of shape (hidden_dim,). Fused into the FFT convolution op. Initialised with uniform(-1/√hidden_dim, 1/√hidden_dim) (Kaiming-uniform scale).

Type:

nn.Parameter
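The stated initialisation is a symmetric uniform draw with bound 1/√hidden_dim. A NumPy sketch (the helper name is hypothetical; in PyTorch the equivalent draw would be torch.nn.init.uniform_(param, -bound, bound)):

```python
import numpy as np

def init_shortcut(hidden_dim, rng):
    """Draw a per-channel skip scale from uniform(-1/sqrt(C), 1/sqrt(C))."""
    bound = 1.0 / np.sqrt(hidden_dim)
    return rng.uniform(-bound, bound, size=hidden_dim)

s = init_shortcut(256, np.random.default_rng(0))  # bound = 1/16 for C = 256
```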

fftconv_fn

Selected FFT convolution function for channels-last (BLH) input with internal reshape. Signature: (x, kernel, shortcut) → output.

Type:

callable

fftconv_fn_bhl_input

Selected FFT convolution function for channels-first (BHL) input. Signature: (x, kernel, shortcut) → output.

Type:

callable

_periodic_per_axis

Per-axis periodicity flags of length data_dim, derived from fft_padding.

Type:

tuple[bool, …]

_is_tuple_mode

True when fft_padding was supplied as a sequence of mode strings (mixed-BC path).

Type:

bool

__init__(
data_dim,
hidden_dim,
kernel_cfg,
mask_cfg,
grid_type,
fft_padding,
is_causal=False,
use_chunked_fftconv=False,
fft_backend='torch_fft',
)

Construct a CKConvND operator.

Validates the combination of fft_padding, grid_type, is_causal, and fft_backend, normalises the per-axis boundary-condition representation, adjusts kernel_cfg and mask_cfg to match the resolved kernel grid geometry, and selects the appropriate FFT convolution function pair.

Parameters:
  • data_dim (int) – Spatial rank of the input signal. 1 for 1D sequences, 2 for 2D images (H, W), 3 for 3D volumes (D, H, W).

  • hidden_dim (int) – Number of channels C. Determines the size of the learnable shortcut parameter and the channel dimension of every intermediate tensor.

  • kernel_cfg (LazyConfig) – Lazy config that instantiates the implicit kernel generator when resolved. Typically points to SIRENKernelND or RandomFourierKernelND. The kernel’s out_dim must equal hidden_dim; a mismatch will produce incorrect tensor shapes at runtime. If present, the L_cache field is adjusted to match the resolved kernel grid size (single or double grid) before instantiation.

  • mask_cfg (LazyConfig) – Lazy config for the attenuation mask applied to the generated kernel values. Use torch.nn.Identity (or an empty identity config) for no masking. If the mask class accepts a grid_size parameter, it is set automatically based on the largest per-axis kernel size.

  • grid_type (Literal['double', 'single'] | None) –

    Relationship between the SIREN coordinate grid and the input spatial size on each axis.

    • "single": grid spans (N+1)//2 points → kernel size equals input size N (for periodic / circular conv).

    • "double": grid spans N points → kernel size is 2*N - 1 (for zero-padded conv).

    • None: required when fft_padding is a per-axis list. The grid type is auto-derived per axis: "single" on periodic axes, "double" on non-periodic axes.

    Must not be None when fft_padding is a single mode string ("zero" or "circular").

  • fft_padding (Literal['zero', 'circular'] | str | Sequence[str]) –

    Boundary-condition mode. Accepted forms:

    • "zero": all axes zero-padded (linear “same” conv).

    • "circular": all axes periodic (wrap-around conv). Requires grid_type="single" and use_chunked_fftconv=False.

    • ["circular", "zero"] (list/tuple of mode strings, one per spatial axis, length must equal data_dim): per-axis mixed boundary conditions. Requires grid_type=None and fft_backend="torch_fft". Mode names are case-insensitive and whitespace-stripped.

    Must be "zero" when is_causal=True; the per-axis list form is rejected in causal mode.

  • is_causal (bool) – If True, enforce causal (past-only) convolution so that the output at position n only depends on inputs at positions 0, …, n. Only valid for data_dim=1. Incompatible with periodic fft_padding and with the per-axis list form of fft_padding. Default: False.

  • use_chunked_fftconv (bool) – If True, process channels in groups to reduce peak GPU memory from complex FFT intermediates. Typical savings: ~26% memory at ~11% compute overhead. Not supported with fft_padding="circular". Default: False.

  • fft_backend (Literal['torch_fft', 'subq_ops']) –

    Which FFT convolution backend to use.

    • "torch_fft" (default): torch.fft-based implementations in nvsubquadratic.ops.fftconv and related modules.

    • "subq_ops": optimised CUDA kernels from subquadratic_ops_torch. Supported configurations:

      • data_dim=2, is_causal=False, fft_padding="zero" (2D non-causal zero-padded conv). Per-sample (FiLM) batched kernel weights are supported on this path.

      • data_dim=1, is_causal=True (1D causal conv). The 1D causal CUDA kernel does not accept batched per-sample weights; FiLM conditioning is unsupported on this path.

      Does not support fp16 FFT, per-axis fft_padding, or data_dim=3.
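The constraints in the parameter list above can be collected into one validation pass. The sketch below is a hypothetical helper (check_ckconv_config) that raises ValueError for every violation, whereas the documented constructor raises AssertionError for some of them; it does not check per-axis list length, mode spelling, or dtype restrictions such as fp16 FFT.

```python
from typing import Optional, Sequence, Union

def check_ckconv_config(
    data_dim: int,
    grid_type: Optional[str],
    fft_padding: Union[str, Sequence[str]],
    is_causal: bool = False,
    use_chunked_fftconv: bool = False,
    fft_backend: str = "torch_fft",
) -> None:
    """Raise ValueError for option combinations the parameter list rules out."""
    if fft_backend not in ("torch_fft", "subq_ops"):
        raise ValueError(f"unknown fft_backend {fft_backend!r}")
    if data_dim not in (1, 2, 3):
        raise ValueError("data_dim must be 1, 2 or 3")
    if not isinstance(fft_padding, str):
        # Per-axis list form: mixed boundary conditions, torch_fft only, never causal.
        if grid_type is not None:
            raise ValueError("per-axis fft_padding requires grid_type=None")
        if fft_backend != "torch_fft":
            raise ValueError("per-axis fft_padding requires fft_backend='torch_fft'")
        if is_causal:
            raise ValueError("is_causal=True cannot be combined with per-axis fft_padding")
        return
    if grid_type is None:
        raise ValueError("grid_type must be set when fft_padding is a single string")
    if fft_padding == "circular":
        if grid_type != "single":
            raise ValueError("fft_padding='circular' requires grid_type='single'")
        if use_chunked_fftconv:
            raise ValueError("chunked FFT conv is not supported with circular padding")
    if is_causal and (data_dim != 1 or fft_padding != "zero"):
        raise ValueError("is_causal=True requires data_dim=1 and fft_padding='zero'")
    if fft_backend == "subq_ops":
        supported = (data_dim == 2 and not is_causal and fft_padding == "zero") or (
            data_dim == 1 and is_causal
        )
        if not supported:
            raise ValueError("subq_ops covers 2D non-causal zero padding and 1D causal only")

check_ckconv_config(1, "double", "zero", is_causal=True, fft_backend="subq_ops")  # accepted
check_ckconv_config(2, None, ["circular", "zero"])                                 # accepted
```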

Raises:
  • AssertionError – If fft_backend is not one of the recognised values, or if a constraint between grid_type, fft_padding, is_causal, and fft_backend is violated.

  • ValueError – If fft_padding is invalid (wrong type, wrong length, comma-separated string, boolean), if is_causal is combined with a per-axis padding list or periodic padding, or if the resolved (fft_padding, data_dim) combination has no registered FFT function.

extra_repr()

Return a concise summary string for print(module) and repr(module).

Returns:

A human-readable string listing the key hyperparameters: data_dim, hidden_dim, fft_padding, periodic_per_axis (only in per-axis list mode), grid_type, is_causal, use_chunked_fftconv, and fft_backend.

Return type:

str

flop_count(spatial_dims, inference=False)

Count FLOPs for CKConv: kernel generation + FFT convolution.

The count has two phases.

Phase 1: kernel generation (via the kernel MLP, e.g. SIREN). Delegated to self.kernel.flop_count(grid_lens, inference). At inference=True without FiLM, the kernel is input-independent and can be precomputed, so this phase returns 0.

Phase 2: FFT-based depthwise convolution with C = self.hidden_dim. The convolution runs in the frequency domain. Padded signal sizes Np_i depend on the padding mode:

  • "zero" non-causal (“same” mode): Np_i = min(s_i + (k_i + 1) // 2, 2 * s_i). Only half the kernel width of extra padding is needed because the output is centre-cropped back to the input size. Matches fftconv.py lines 624-628.

  • "zero" causal (1D only): Np_i = min(s_i + k_i, 2 * s_i). Full linear convolution length; the output is tail-cropped.

  • "circular": Np_i = s_i. Wrap-around, no extra padding.
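The per-axis padded lengths above translate directly into code. A sketch with a hypothetical helper name, covering one padding mode for all axes (mixed-BC inputs would apply the rule axis by axis):

```python
def padded_sizes(spatial_dims, kernel_dims, mode, is_causal=False):
    """Per-axis FFT length Np_i used by the FLOP model for a single padding mode."""
    if mode == "circular":
        return list(spatial_dims)  # wrap-around, no extra padding
    if is_causal:
        # Full linear convolution length, capped at 2 * s; the tail is cropped afterwards.
        return [min(s + k, 2 * s) for s, k in zip(spatial_dims, kernel_dims)]
    # "Same" output: half the kernel width of padding suffices before the centre crop.
    return [min(s + (k + 1) // 2, 2 * s) for s, k in zip(spatial_dims, kernel_dims)]

print(padded_sizes((32, 48), (63, 95), "zero"))  # [64, 96]
```

With a double-grid kernel (k_i = 2*s_i - 1) the non-causal rule gives exactly 2*s_i, so the cap only matters for shorter kernels.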

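A separable N-D FFT on a grid of size (Np_1, ..., Np_d) costs 5 * prod(Np) * sum(log2(Np_i)) real FLOPs per channel. This follows from the radix-2 Cooley-Tukey decomposition: a length-n transform has log2(n) stages of n/2 butterflies, and each butterfly costs about 10 real FLOPs (one complex multiply = 4 real muls + 2 real adds, plus two complex adds = 4 real adds), i.e. ≈ 5 real FLOPs per point per stage. Applying this along each axis of the separable transform gives prod(Np) * log2(Np_i) point-stages per axis. The implementation uses rfft (real-to-complex), which is ~2x cheaper than a full complex FFT, so the 5N log N formula is a conservative upper bound consistent with vision-paper conventions.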

Three FFTs are needed (forward of input, forward of kernel, inverse of the product). At inference=True without FiLM the kernel FFT is precomputed and cached, reducing to two FFTs.

Pointwise complex multiply in the frequency domain costs 6 * C * prod(Np) (4 real muls + 2 real adds for (a + bi)(c + di)). The shortcut (skip connection) costs C * prod(spatial_dims) elementwise multiplies.
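Putting the Phase 2 terms together, the sketch below takes the padded sizes as given. kernel_fft_cached is a hypothetical flag standing in for inference=True without FiLM; Phase 1 (self.kernel.flop_count) is not included.

```python
import math

def fft_conv_flops(padded, spatial_dims, channels, kernel_fft_cached=False):
    """Phase-2 FLOP estimate: FFTs + pointwise complex multiply + shortcut."""
    n_pad = math.prod(padded)
    fft_cost = 5 * n_pad * sum(math.log2(n) for n in padded)  # per channel, per FFT
    num_ffts = 2 if kernel_fft_cached else 3                   # input, kernel, inverse
    pointwise = 6 * channels * n_pad                           # (a + bi)(c + di)
    shortcut = channels * math.prod(spatial_dims)              # shortcut ⊙ x
    return int(num_ffts * channels * fft_cost + pointwise + shortcut)

print(fft_conv_flops((8,), (4,), channels=2))                          # 824
print(fft_conv_flops((8,), (4,), channels=2, kernel_fft_cached=True))  # 584
```

For Np = (8,), s = (4,), C = 2 each FFT costs 5 * 8 * 3 = 120 FLOPs per channel, so the total is 3 * 2 * 120 + 6 * 2 * 8 + 2 * 4 = 824.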

Parameters:
  • spatial_dims (tuple[int, ...]) – Spatial dimensions of the input signal, e.g. (H, W) for a 2D image or (L,) for a 1D sequence. Must have length equal to self.data_dim.

  • inference (bool) – If True and the kernel has no FiLM conditioning, skip the kernel generation and kernel FFT FLOPs (both can be precomputed and cached at inference time).

Returns:

Total estimated FLOPs as an integer.

Return type:

int

apply_convolution(x, conv_kernel, shortcut, is_bhl_input)

Apply the FFT-based depthwise convolution.

Dispatches to the pre-selected fftconv_fn or fftconv_fn_bhl_input depending on the memory layout of x. When is_bhl_input=True the kernel is first transposed from channels-last (B, *spatial, C) to channels-first (B, C, *spatial) to match the BHL-native FFT op.

The output y is computed as:

y = IFFT( FFT(x) ⊙ FFT(conv_kernel) ) + shortcut ⊙ x

The shortcut term is fused inside the FFT op (no extra kernel launch).
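The layout dispatch can be sketched in NumPy for the periodic 1D case. The function names (conv_blh, conv_bhl, apply_convolution_sketch) are hypothetical stand-ins for fftconv_fn and fftconv_fn_bhl_input; the key step mirrored from the description is that the kernel always arrives channels-last and is moved to channels-first only on the BHL path.

```python
import numpy as np

def conv_blh(x, kernel, shortcut):
    """Periodic 1D depthwise conv, channels-last: x (B, L, C), kernel (1, L, C)."""
    n = x.shape[1]
    spec = np.fft.rfft(x, axis=1) * np.fft.rfft(kernel, axis=1)
    return np.fft.irfft(spec, n=n, axis=1) + shortcut * x

def conv_bhl(x, kernel, shortcut):
    """Same op, channels-first: x (B, C, L), kernel (1, C, L)."""
    n = x.shape[-1]
    spec = np.fft.rfft(x, axis=-1) * np.fft.rfft(kernel, axis=-1)
    return np.fft.irfft(spec, n=n, axis=-1) + shortcut[:, None] * x

def apply_convolution_sketch(x, conv_kernel, shortcut, is_bhl_input):
    """Dispatch on layout; conv_kernel is always channels-last, (1, L, C)."""
    if is_bhl_input:
        # (1, L, C) -> (1, C, L) so that it lines up with a channels-first x.
        return conv_bhl(x, np.moveaxis(conv_kernel, -1, 1), shortcut)
    return conv_blh(x, conv_kernel, shortcut)

rng = np.random.default_rng(3)
x = rng.standard_normal((2, 16, 3))
k = rng.standard_normal((1, 16, 3))
s = rng.standard_normal(3)
y_blh = apply_convolution_sketch(x, k, s, is_bhl_input=False)
y_bhl = apply_convolution_sketch(np.moveaxis(x, -1, 1), k, s, is_bhl_input=True)
```

Both paths compute the same values; only the memory layout of the output follows the input.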

Parameters:
  • x (Tensor) –

    Input signal.

    • BLH layout (is_bhl_input=False): shape (B, *spatial, C) where C = self.hidden_dim.

    • BHL layout (is_bhl_input=True): shape (B, C, *spatial).

  • conv_kernel (Tensor) – Kernel values produced by self.kernel and optionally masked by self.mask. Always in channels-last (BLH) format on entry: shape (1_or_B, *kernel_spatial, C). kernel_spatial equals spatial on single-grid (circular) axes and 2*N - 1 on double-grid (zero-padded) axes, where N is the corresponding input spatial size. Transposed internally when is_bhl_input=True.

  • shortcut (Tensor) – Per-channel skip-connection scale, shape (C,). Typically self.shortcut or a CP-sliced view thereof.

  • is_bhl_input (bool) – If True, treat x as channels-first (B, C, *spatial) and use self.fftconv_fn_bhl_input. If False, treat x as channels-last (B, *spatial, C) and use self.fftconv_fn (which handles the reshape internally).

Returns:

Output tensor in the same memory layout as the input x: (B, *spatial, C) when is_bhl_input=False, or (B, C, *spatial) when is_bhl_input=True.

Return type:

Tensor

forward(x, is_bhl_input=False, cp_group=None, **mixer_kwargs)

Run the CKConv forward pass.

Generates the implicit kernel from the positional grid, optionally applies the attenuation mask, crops the kernel for causal mode, handles context-parallel channel slicing, and applies the FFT convolution with the shortcut term.

Computation (non-causal, no CP):

grid_lens = [(s+1)//2  if single-grid axis  else  s  for s in spatial_dims]
conditioning = mixer_kwargs.get("conditioning")                # None unless FiLM conditioning is passed
k_θ, grid = self.kernel(grid_lens, conditioning=conditioning)  # (1_or_B, *kernel_spatial, C)
k_θ       = self.mask(grid=grid, x=k_θ)                        # attenuation
y         = IFFT(FFT(x) ⊙ FFT(k_θ)) + shortcut ⊙ x           # FFT conv
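A runnable NumPy version of the non-causal zero-padded path in 1D, assuming random values in place of the MLP output on a double grid (2*N - 1 taps) with lag 0 at index K // 2. The centring convention and function name are assumptions. The FFT length N + (K + 1) // 2 follows the flop_count description; it is enough to keep circular wrap-around out of the centre-cropped window.

```python
import numpy as np

def same_conv_fft(x, k, shortcut):
    """Linear 'same' depthwise conv; x: (N, C), k: (K, C), K odd, lag 0 at index K // 2."""
    n, taps = x.shape[0], k.shape[0]
    n_pad = n + (taps + 1) // 2  # equals min(N + (K + 1) // 2, 2 * N) when K <= 2 * N - 1
    spec = np.fft.rfft(x, n=n_pad, axis=0) * np.fft.rfft(k, n=n_pad, axis=0)
    full = np.fft.irfft(spec, n=n_pad, axis=0)
    start = taps // 2            # centre crop back to the input length
    return full[start:start + n] + shortcut * x

n, c = 12, 3
rng = np.random.default_rng(1)
x = rng.standard_normal((n, c))
k = rng.standard_normal((2 * n - 1, c))  # double grid: 2*N - 1 taps
s = rng.standard_normal(c)
y = same_conv_fft(x, k, s)               # (12, 3), same length as the input
```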

For causal mode (1D only), k_θ is cropped to its causal (positive-lag) half before the FFT convolution:

kernel_len = k_θ.shape[1]                # k_θ: (1_or_B, kernel_len, C) in 1D
k_θ = k_θ[..., kernel_len // 2 :, :]     # keep lag 0 and the positive lags

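The causal crop and the tail-cropped FFT convolution can be sketched in NumPy as follows, assuming a 1D channels-last layout and a (2*N - 1)-tap kernel with lag 0 at index N - 1 (the function name is hypothetical). The check at the end perturbs the input from position 6 onwards and confirms that earlier outputs do not change.

```python
import numpy as np

def causal_conv_fft(x, k_full, shortcut):
    """1D causal depthwise conv; x: (N, C), k_full: (2N - 1, C) with lag 0 at index N - 1."""
    n = x.shape[0]
    k = k_full[k_full.shape[0] // 2:]        # keep lags 0 .. N - 1 -> (N, C)
    n_pad = min(n + k.shape[0], 2 * n)      # holds the full 2N - 1 linear convolution
    spec = np.fft.rfft(x, n=n_pad, axis=0) * np.fft.rfft(k, n=n_pad, axis=0)
    y = np.fft.irfft(spec, n=n_pad, axis=0)[:n]  # drop the tail
    return y + shortcut * x

rng = np.random.default_rng(2)
n, c = 10, 2
x = rng.standard_normal((n, c))
k_full = rng.standard_normal((2 * n - 1, c))
s = rng.standard_normal(c)
y = causal_conv_fft(x, k_full, s)

x2 = x.copy()
x2[6:] += 1.0  # change only "future" inputs
print(np.allclose(causal_conv_fft(x2, k_full, s)[:6], y[:6]))  # True
```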
Parameters:
  • x (Tensor) –

    Input signal tensor. Two supported layouts:

    • Channels-last (is_bhl_input=False, default): shape (B, *spatial, hidden_dim) where spatial has length self.data_dim.

    • Channels-first (is_bhl_input=True): shape (B, hidden_dim, *spatial).

  • is_bhl_input (bool) – If True, x is in channels-first (B, C, *spatial) layout. Default: False (channels-last).

  • cp_group (ProcessGroup) – Context-parallel process group. When provided and cp_group.size() > 1, the kernel and shortcut are sliced along the channel dimension to match the local channel slice held by this rank. The caller is expected to have already distributed x across ranks before this call. Combining cp_group with is_causal=True is rejected (see Raises). Default: None (single-device / no CP).

  • **mixer_kwargs

    Additional keyword arguments forwarded to the kernel generator. The following key is recognised:

    • conditioning (torch.Tensor, shape (B, cond_dim)): conditioning vector for FiLM-enabled kernels such as SIRENKernelND with a film_cfg. Ignored (no-op) when the kernel has no FiLM generator.

Returns:

Output tensor in the same memory layout as x: (B, *spatial, hidden_dim) when is_bhl_input=False, or (B, hidden_dim, *spatial) when is_bhl_input=True.

Return type:

Tensor

Raises:

ValueError – If cp_group is provided together with is_causal=True. This combination is explicitly rejected because it has not been verified for correctness: the causal kernel crop and CP channel slicing may interact in ways that silently leak future positions. If a future version lifts this restriction, do not assume causal CP is correct without re-verifying it.