CKConvND
class CKConvND(data_dim, hidden_dim, kernel_cfg, mask_cfg, grid_type, fft_padding, is_causal=False, use_chunked_fftconv=False, fft_backend='torch_fft')
Bases: Module
N-dimensional Continuous Kernel Convolution (CKConv) operator.
CKConvND implements the CKConv operator (Romero et al., arXiv:2102.02611) generalised to spatial rank D ∈ {1, 2, 3}. The convolution kernel is not stored as an explicit lookup table; instead, it is produced on the fly by a small MLP evaluated on a continuous positional grid:
k_θ(p) = MLP_θ(pos_enc(p)), p ∈ [-1, 1]^D
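To make the implicit-kernel idea concrete, here is a minimal sketch of a kernel generator that evaluates a sine-activated MLP on a uniform [-1, 1]^D grid and returns (kernel_values, grid). The class name TinySineKernel, its layer sizes, and omega0=30.0 are illustrative assumptions; this is not SIRENKernelND and omits its positional encoding, initialisation scheme, and FiLM support.

```python
import torch
from torch import nn


class TinySineKernel(nn.Module):
    """Hypothetical implicit kernel: a sine-activated MLP on a [-1, 1]^D grid.

    Illustrative only; not the SIRENKernelND implementation.
    """

    def __init__(self, data_dim: int, hidden: int, out_dim: int, omega0: float = 30.0):
        super().__init__()
        self.omega0 = omega0  # assumed frequency scale for the first sine layer
        self.l1 = nn.Linear(data_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.out = nn.Linear(hidden, out_dim)

    def forward(self, grid_lens):
        # One coordinate axis per spatial dimension, each spanning [-1, 1].
        axes = [torch.linspace(-1.0, 1.0, n) for n in grid_lens]
        # Dense coordinate grid of shape (*grid_lens, D).
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)
        h = torch.sin(self.omega0 * self.l1(grid))
        h = torch.sin(self.l2(h))
        # Kernel values in channels-last layout: (1, *grid_lens, out_dim).
        return self.out(h).unsqueeze(0), grid


gen = TinySineKernel(data_dim=2, hidden=32, out_dim=8)
k, grid = gen((5, 7))
print(k.shape, grid.shape)  # torch.Size([1, 5, 7, 8]) torch.Size([5, 7, 2])
```

Because the kernel is a function of continuous coordinates, the same parameters can be evaluated on a grid of any length, which is what lets the kernel span the full input regardless of its size.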
The convolution with the input signal x is then computed in the frequency domain:
y = IFFT( FFT(x) ⊙ FFT(k_θ) ) + shortcut ⊙ x
at O(N log N) cost per channel, where N = prod(spatial_dims).
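The frequency-domain formula can be checked on a small 1D case. The sketch below assumes channels-last input of shape (B, L, C), an odd-length kernel of shape (K, C) centred at index (K - 1) // 2, and full linear-convolution padding (L + K - 1) rather than the tighter padding CKConvND uses; fft_conv_1d_same is a hypothetical helper name. It computes a "same"-size depthwise convolution with torch.fft, adds the per-channel shortcut, and compares against a direct grouped conv1d.

```python
import torch
import torch.nn.functional as F


def fft_conv_1d_same(x, k, shortcut):
    """Hypothetical helper: depthwise 'same' linear conv via FFT plus shortcut.

    x: (B, L, C) channels-last input
    k: (K, C) kernel with odd K, centred at index (K - 1) // 2
    shortcut: (C,) per-channel skip scale
    """
    L, K = x.shape[1], k.shape[0]
    n = L + K - 1  # full linear-conv length, so no circular wrap-around
    X = torch.fft.rfft(x, n=n, dim=1)   # FFT(x), zero-padded to n
    Kf = torch.fft.rfft(k, n=n, dim=0)  # FFT(k), zero-padded to n
    y_full = torch.fft.irfft(X * Kf, n=n, dim=1)
    start = (K - 1) // 2  # centre-crop back to length L ("same" output)
    return y_full[:, start:start + L, :] + shortcut * x


torch.manual_seed(0)
B, L, C, K = 2, 16, 3, 5
x = torch.randn(B, L, C, dtype=torch.float64)
k = torch.randn(K, C, dtype=torch.float64)
s = torch.randn(C, dtype=torch.float64)
y = fft_conv_1d_same(x, k, s)

# Direct reference: conv1d computes cross-correlation, so flip the kernel.
w = k.flip(0).T.unsqueeze(1).contiguous()  # (C, 1, K) depthwise weights
ref = F.conv1d(x.transpose(1, 2), w, padding=(K - 1) // 2, groups=C).transpose(1, 2)
print(torch.allclose(y, ref + s * x))  # True
```

The direct route costs O(L·K) per channel, while the FFT route stays O(n log n) even when K is as large as the input, which is the regime CKConv targets.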
The MLP and its positional encoding are provided through kernel_cfg (typically a SIRENKernelND or RandomFourierKernelND lazy config from nvsubquadratic.modules.kernels_nd). An optional attenuation mask, mask_cfg (e.g. GaussianModulationND), is applied to the kernel values after the MLP forward pass to restrict the effective receptive field at initialisation.
Boundary conditions are controlled jointly by fft_padding and grid_type:
- fft_padding="zero", grid_type="double": standard linear convolution with "same" output size. The kernel spans the full input (double grid: N grid points → kernel of length 2*N - 1) and is zero-padded before the FFT.
- fft_padding="circular", grid_type="single": periodic convolution with kernel size equal to input size (single grid: (N+1)//2 grid points → kernel of length N after the MLP).
- fft_padding=["circular", "zero"], grid_type=None: per-axis mixed boundary conditions; grid_type is auto-derived per axis.
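In the "circular" case the kernel has the same length as the input and the FFT is taken without padding, so the product of spectra implements wrap-around convolution. A minimal 1D sketch, assuming channels-last tensors, lag 0 at kernel index 0 (an assumption about layout, not necessarily how CKConvND centres its kernel), and a hypothetical helper name, checked against an explicit modular-index sum:

```python
import torch


def circular_fft_conv_1d(x, k):
    """Hypothetical helper: periodic depthwise conv, kernel length == input length.

    x: (B, N, C), k: (N, C). No zero-padding, so indices wrap modulo N.
    """
    n = x.shape[1]
    X = torch.fft.rfft(x, n=n, dim=1)
    Kf = torch.fft.rfft(k, n=n, dim=0)
    return torch.fft.irfft(X * Kf, n=n, dim=1)


def circular_direct(x, k):
    # y[b, i, c] = sum_j x[b, j, c] * k[(i - j) mod N, c]
    n = x.shape[1]
    y = torch.zeros_like(x)
    for i in range(n):
        for j in range(n):
            y[:, i, :] += x[:, j, :] * k[(i - j) % n, :]
    return y


torch.manual_seed(0)
x = torch.randn(2, 8, 3, dtype=torch.float64)
k = torch.randn(8, 3, dtype=torch.float64)
print(torch.allclose(circular_fft_conv_1d(x, k), circular_direct(x, k)))  # True
```

Compared with the zero-padded case, no padding is needed at all, which is why the "circular" FLOP count uses Np_i = s_i.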
Context parallelism: when cp_group is supplied in forward, the kernel is sliced along the channel dimension to match the local slice of the input, and the shortcut parameter is sliced accordingly. Causal mode is not verified to be correct under CP, so forward raises ValueError if cp_group is combined with is_causal=True.
- Parameters:
data_dim (int)
hidden_dim (int)
kernel_cfg (LazyConfig)
mask_cfg (LazyConfig)
grid_type (Literal['double', 'single'] | None)
fft_padding (Literal['zero', 'circular'] | str | Sequence[str])
is_causal (bool)
use_chunked_fftconv (bool)
fft_backend (Literal['torch_fft', 'subq_ops'])
- hidden_dim
Number of channels C processed by this operator.
- Type:
int
- fft_padding
Boundary-condition specification as supplied by the caller. The normalised per-axis representation is stored in _periodic_per_axis.
- is_causal
Whether the operator enforces causal (past-only) convolution. Only valid when data_dim=1.
- Type:
bool
- grid_type
Kernel grid size mode ("single", "double", or None for per-axis auto-derivation).
- Type:
str or None
- kernel
Implicit kernel generator (produces (kernel_values, grid) on each forward call).
- Type:
nn.Module
- mask
Attenuation mask applied to kernel values after generation; nn.Identity when no mask is configured.
- Type:
nn.Module
- shortcut
Learnable per-channel skip-connection scale of shape (hidden_dim,), fused into the FFT convolution op. Initialised with uniform(-1/√hidden_dim, 1/√hidden_dim) (Kaiming-uniform scale).
- Type:
nn.Parameter
- fftconv_fn
Selected FFT convolution function for channels-last (BLH) input, with an internal reshape. Signature: (x, kernel, shortcut) → output.
- Type:
callable
- fftconv_fn_bhl_input
Selected FFT convolution function for channels-first (BHL) input. Signature: (x, kernel, shortcut) → output.
- Type:
callable
- _periodic_per_axis
Per-axis periodicity flags of length data_dim, derived from fft_padding.
- _is_tuple_mode
True when fft_padding was supplied as a sequence of mode strings (mixed-BC path).
- Type:
bool
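How a per-axis representation like _periodic_per_axis, together with the grid_type auto-derivation described for the constructor, could be computed is sketched below as a standalone helper. The function name resolve_padding, its return values, and its error messages are hypothetical; the exception types follow the constructor's Raises list (ValueError for a malformed fft_padding, AssertionError for a violated grid_type constraint), and the real constructor performs further checks (is_causal, fft_backend, chunking) that are not reproduced here.

```python
from collections.abc import Sequence

_MODES = {"zero": False, "circular": True}  # mode name -> is_periodic


def resolve_padding(fft_padding, data_dim, grid_type, spatial_dims):
    """Hypothetical sketch: normalise fft_padding and derive per-axis grid lengths.

    Returns (periodic_per_axis, grid_types, grid_lens).
    """
    if isinstance(fft_padding, bool):
        raise ValueError("fft_padding must not be a boolean")
    if isinstance(fft_padding, str):
        mode = fft_padding.strip().lower()
        if "," in mode or mode not in _MODES:
            raise ValueError(f"invalid fft_padding: {fft_padding!r}")
        assert grid_type is not None, "grid_type is required for a single mode string"
        assert mode != "circular" or grid_type == "single", "circular needs grid_type='single'"
        periodic = [_MODES[mode]] * data_dim
        grid_types = [grid_type] * data_dim
    elif isinstance(fft_padding, Sequence):
        modes = [str(m).strip().lower() for m in fft_padding]
        if len(modes) != data_dim or any(m not in _MODES for m in modes):
            raise ValueError(f"invalid per-axis fft_padding: {fft_padding!r}")
        assert grid_type is None, "grid_type must be None for per-axis fft_padding"
        periodic = [_MODES[m] for m in modes]
        # Auto-derive: "single" on periodic axes, "double" on non-periodic axes.
        grid_types = ["single" if p else "double" for p in periodic]
    else:
        raise ValueError(f"unsupported fft_padding type: {type(fft_padding).__name__}")
    # Grid lengths as in forward: (s + 1) // 2 on single-grid axes, s otherwise.
    grid_lens = [(s + 1) // 2 if g == "single" else s for s, g in zip(spatial_dims, grid_types)]
    return periodic, grid_types, grid_lens


print(resolve_padding([" Circular", "ZERO "], 2, None, (15, 32)))
# ([True, False], ['single', 'double'], [8, 32])
```

The boolean check comes first because a bare True or False is otherwise easy to pass by mistake where a mode string was intended.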
- __init__(data_dim, hidden_dim, kernel_cfg, mask_cfg, grid_type, fft_padding, is_causal=False, use_chunked_fftconv=False, fft_backend='torch_fft')
Construct a CKConvND operator.
Validates the combination of fft_padding, grid_type, is_causal, and fft_backend; normalises the per-axis boundary-condition representation; adjusts kernel_cfg and mask_cfg to match the resolved kernel grid geometry; and selects the appropriate pair of FFT convolution functions.
- Parameters:
data_dim (int) – Spatial rank of the input signal: 1 for 1D sequences, 2 for 2D images (H, W), 3 for 3D volumes (D, H, W).
hidden_dim (int) – Number of channels C. Determines the size of the learnable shortcut parameter and the channel dimension of every intermediate tensor.
kernel_cfg (LazyConfig) – Lazy config (LazyConfig) that instantiates the implicit kernel generator when resolved. Typically points to SIRENKernelND or RandomFourierKernelND. The kernel's out_dim must equal hidden_dim; a mismatch produces incorrect tensor shapes at runtime. The L_cache field, if present, is adjusted to match the resolved kernel grid size (single or double grid) before instantiation.
mask_cfg (LazyConfig) – Lazy config for the attenuation mask applied to the generated kernel values. Use torch.nn.Identity (or an empty identity config) for no masking. If the mask class accepts a grid_size parameter, it is set automatically based on the largest per-axis kernel size.
grid_type (Literal['double', 'single'] | None) – Relationship between the SIREN coordinate grid and the input spatial size on each axis.
  - "single": the grid spans (N+1)//2 points → kernel size equals input size N (for periodic / circular conv).
  - "double": the grid spans N points → kernel size is 2*N - 1 ≈ 2*N (for zero-padded conv).
  - None: required when fft_padding is a per-axis list. The grid type is auto-derived per axis: "single" on periodic axes, "double" on non-periodic axes.
  Must not be None when fft_padding is a single mode string ("zero" or "circular").
fft_padding (Literal['zero', 'circular'] | str | Sequence[str]) – Boundary-condition mode. Accepted forms:
  - "zero": all axes zero-padded (linear "same" conv).
  - "circular": all axes periodic (wrap-around conv). Requires grid_type="single" and use_chunked_fftconv=False.
  - ["circular", "zero"] (list/tuple of mode strings, one per spatial axis; length must equal data_dim): per-axis mixed boundary conditions. Requires grid_type=None and fft_backend="torch_fft". Mode names are case-insensitive and whitespace-stripped.
  Must be "zero" (or an all-"zero" list) when is_causal=True.
is_causal (bool) – If True, enforce causal (past-only) convolution so that the output at position n depends only on inputs at positions 0, …, n. Only valid for data_dim=1. Incompatible with periodic fft_padding and with the per-axis list form of fft_padding. Default: False.
use_chunked_fftconv (bool) – If True, process channels in groups to reduce peak GPU memory from complex FFT intermediates. Typical savings: ~26% memory at ~11% compute overhead. Not supported with fft_padding="circular". Default: False.
fft_backend (Literal['torch_fft', 'subq_ops']) – Which FFT convolution backend to use.
  - "torch_fft" (default): torch.fft-based implementations in nvsubquadratic.ops.fftconv and related modules.
  - "subq_ops": optimised CUDA kernels from subquadratic_ops_torch. Supported configurations:
    - data_dim=2, is_causal=False, fft_padding="zero" (2D non-causal zero-padded conv). Per-sample (FiLM) batched kernel weights are supported on this path.
    - data_dim=1, is_causal=True (1D causal conv). The 1D causal CUDA kernel does not accept batched per-sample weights; FiLM conditioning is unsupported on this path.
  The subq_ops backend does not support fp16 FFT, per-axis fft_padding, or data_dim=3.
- Raises:
AssertionError – If fft_backend is not one of the recognised values, or if a constraint between grid_type, fft_padding, is_causal, and fft_backend is violated.
ValueError – If fft_padding is invalid (wrong type, wrong length, comma-separated string, boolean), if is_causal is combined with a per-axis padding list or periodic padding, or if the resolved (fft_padding, data_dim) combination has no registered FFT function.
- extra_repr()
Return a concise summary string for print(module) and repr(module).
- Returns:
A human-readable string listing the key hyperparameters data_dim, hidden_dim, fft_padding, periodic_per_axis (only in per-axis list mode), grid_type, is_causal, use_chunked_fftconv, and fft_backend.
- Return type:
str
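extra_repr is the standard torch.nn.Module hook whose return value appears inside the parentheses of repr(module). A minimal sketch of the pattern, using a hypothetical stand-in module with only a subset of the hyperparameters (the exact formatting CKConvND uses is not reproduced):

```python
from torch import nn


class ReprDemo(nn.Module):
    """Hypothetical stand-in showing how extra_repr feeds repr(module)."""

    def __init__(self, data_dim=2, hidden_dim=64, fft_padding="zero", grid_type="double"):
        super().__init__()
        self.data_dim, self.hidden_dim = data_dim, hidden_dim
        self.fft_padding, self.grid_type = fft_padding, grid_type

    def extra_repr(self) -> str:
        # Returned string is placed between "ReprDemo(" and ")" by nn.Module.__repr__.
        return (
            f"data_dim={self.data_dim}, hidden_dim={self.hidden_dim}, "
            f"fft_padding={self.fft_padding!r}, grid_type={self.grid_type!r}"
        )


print(ReprDemo())
# ReprDemo(data_dim=2, hidden_dim=64, fft_padding='zero', grid_type='double')
```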
- flop_count(spatial_dims, inference=False)
Count FLOPs for CKConv: kernel generation plus FFT convolution.
The count has two phases.
Phase 1: kernel generation (via the SIREN MLP). Delegated to self.kernel.flop_count(grid_lens, inference). At inference=True without FiLM, the kernel is input-independent and can be precomputed, so this phase returns 0.
Phase 2: FFT-based depthwise convolution with C = self.hidden_dim. The convolution runs in the frequency domain. The padded signal sizes Np_i (with s_i the input size and k_i the kernel size on axis i) depend on the padding mode:
- "zero" non-causal ("same" mode): Np_i = min(s_i + (k_i + 1) // 2, 2 * s_i). Only half the kernel width of extra padding is needed because the output is centre-cropped back to the input size. Matches fftconv.py lines 624-628.
- "zero" causal (1D only): Np_i = min(s_i + k_i, 2 * s_i). Full linear-convolution length; the output is tail-cropped.
- "circular": Np_i = s_i. Wrap-around, no extra padding.
A separable N-D FFT on a grid of size (Np_1, ..., Np_d) costs 5 * prod(Np) * sum(log2(Np_i)) real FLOPs per channel. This follows from the radix-2 Cooley-Tukey decomposition: a length-N transform has (N/2) * log2(N) butterflies, each ≈ 10 real FLOPs (one complex multiply = 4 real muls + 2 real adds, plus two complex adds = 4 real adds), ignoring savings from trivial twiddle factors. The implementation uses rfft (real-to-complex), which is ~2x cheaper than a full complex FFT, so the 5N log N formula is a conservative upper bound, consistent with vision-paper conventions.
Three FFTs are needed (forward of the input, forward of the kernel, inverse of the product). At inference=True without FiLM, the kernel FFT is precomputed and cached, reducing the count to two FFTs.
The pointwise complex multiply in the frequency domain costs 6 * C * prod(Np) (4 real muls + 2 real adds for (a + bi)(c + di)). The shortcut (skip connection) costs C * prod(spatial_dims) elementwise multiplies.
- Parameters:
spatial_dims (tuple[int, ...]) – Spatial dimensions of the input signal, e.g. (H, W) for a 2D image or (L,) for a 1D sequence. Must have length equal to self.data_dim.
inference (bool) – If True and the kernel has no FiLM conditioning, skip the kernel-generation and kernel-FFT FLOPs (both can be precomputed and cached at inference time).
- Returns:
Total estimated FLOPs as an integer.
- Return type:
int
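The Phase 2 arithmetic can be reproduced in a few lines. This is a sketch under stated assumptions: fft_conv_flops is a hypothetical name, it takes kernel sizes explicitly, and it covers only the FFT convolution (Phase 1, the kernel MLP, is left to the kernel module as described).

```python
import math


def fft_conv_flops(spatial_dims, kernel_dims, channels, mode="zero",
                   causal=False, cached_kernel_fft=False):
    """Hypothetical sketch of the Phase 2 count (FFT conv only, no kernel MLP)."""
    if mode == "circular":
        padded = list(spatial_dims)  # wrap-around: no extra padding
    elif causal:
        padded = [min(s + k, 2 * s) for s, k in zip(spatial_dims, kernel_dims)]
    else:  # zero-padded "same": half the kernel width suffices
        padded = [min(s + (k + 1) // 2, 2 * s) for s, k in zip(spatial_dims, kernel_dims)]
    n = math.prod(padded)
    one_fft = 5 * n * sum(math.log2(p) for p in padded)  # per channel
    num_ffts = 2 if cached_kernel_fft else 3  # kernel FFT cached at inference
    fft_flops = num_ffts * channels * one_fft
    pointwise = 6 * channels * n  # complex multiply of the two spectra
    shortcut = channels * math.prod(spatial_dims)  # skip-connection scale
    return int(fft_flops + pointwise + shortcut)


print(fft_conv_flops((8,), (15,), channels=1))                          # 1064
print(fft_conv_flops((8,), (15,), channels=1, cached_kernel_fft=True))  # 744
print(fft_conv_flops((32, 32), (63, 63), channels=64))                  # 48824320
```

For the 1D case: Np = min(8 + 8, 16) = 16, one FFT costs 5 * 16 * 4 = 320, three FFTs give 960, the spectral product adds 96 and the shortcut 8, for 1064 in total.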
- apply_convolution(x, conv_kernel, shortcut, is_bhl_input)
Apply the FFT-based depthwise convolution.
Dispatches to the pre-selected fftconv_fn or fftconv_fn_bhl_input depending on the memory layout of x. When is_bhl_input=True, the kernel is first transposed from channels-last (1_or_B, *kernel_spatial, C) to channels-first (1_or_B, C, *kernel_spatial) to match the BHL-native FFT op.
The output y is computed as:
y = IFFT( FFT(x) ⊙ FFT(conv_kernel) ) + shortcut ⊙ x
The shortcut term is fused inside the FFT op (no extra kernel launch).
- Parameters:
x (Tensor) – Input signal.
  - BLH layout (is_bhl_input=False): shape (B, *spatial, C), where C = self.hidden_dim.
  - BHL layout (is_bhl_input=True): shape (B, C, *spatial).
conv_kernel (Tensor) – Kernel values produced by self.kernel and optionally masked by self.mask. Always in channels-last (BLH) format on entry: shape (1_or_B, *kernel_spatial, C). kernel_spatial equals spatial on single-grid (circular) axes and 2*N - 1 on double-grid (zero-padded) axes, where N is the corresponding input spatial size. Transposed internally when is_bhl_input=True.
shortcut (Tensor) – Per-channel skip-connection scale, shape (C,). Typically self.shortcut or a CP-sliced view of it.
is_bhl_input (bool) – If True, treat x as channels-first (B, C, *spatial) and use self.fftconv_fn_bhl_input. If False, treat x as channels-last (B, *spatial, C) and use self.fftconv_fn (which handles the reshape internally).
- Returns:
Output tensor in the same memory layout as the input x: (B, *spatial, C) when is_bhl_input=False, or (B, C, *spatial) when is_bhl_input=True.
- Return type:
Tensor
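On the BHL path, the channels-last kernel has to be moved to channels-first before it can be multiplied against the spectrum of a (B, C, *spatial) input. A minimal sketch of that transpose with torch.movedim, assuming a 2D double-grid kernel with a leading batch axis of 1; the helper name is hypothetical, and whether the real op also makes the result contiguous is not stated in this documentation.

```python
import torch


def kernel_to_bhl(conv_kernel: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (1_or_B, *kernel_spatial, C) -> (1_or_B, C, *kernel_spatial)."""
    return conv_kernel.movedim(-1, 1)  # returns a view; call .contiguous() if an op needs it


k_blh = torch.randn(1, 63, 63, 16)  # e.g. double grid on both axes of a 32x32 input, C = 16
k_bhl = kernel_to_bhl(k_blh)
print(k_bhl.shape)  # torch.Size([1, 16, 63, 63])
```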
- forward(x, is_bhl_input=False, cp_group=None, **mixer_kwargs)
Run the CKConv forward pass.
Generates the implicit kernel from the positional grid, optionally applies the attenuation mask, crops the kernel for causal mode, handles context-parallel channel slicing, and applies the FFT convolution with the shortcut term.
Computation (non-causal, no CP):
    grid_lens = [(s + 1) // 2 if single-grid axis else s for s in spatial_dims]
    k_θ, grid = self.kernel(grid_lens, conditioning=conditioning)  # (1, *grid_lens, C)
    k_θ = self.mask(grid=grid, x=k_θ)                              # attenuation
    y = IFFT(FFT(x) ⊙ FFT(k_θ)) + shortcut ⊙ x                     # FFT conv
For causal mode (1D only), k_θ is cropped to its causal (positive-lag) half before the FFT convolution:
    k_θ = k_θ[..., kernel_len // 2 :, :]  # keep second half
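The causal crop and the causal padding rule from flop_count (Np = min(s + k, 2 * s), tail-cropped output) combine as in the following 1D sketch. It assumes the MLP produced a centred kernel of length 2*L - 1 with lag 0 at index L - 1, which is an assumption about the kernel layout; causal_fft_conv_1d is a hypothetical name. The check perturbs one input position and confirms that earlier outputs do not change.

```python
import torch


def causal_fft_conv_1d(x, k_full, shortcut):
    """Hypothetical sketch: x (B, L, C), k_full (2L - 1, C) centred kernel, shortcut (C,)."""
    L = x.shape[1]
    kernel_len = k_full.shape[0]
    k = k_full[kernel_len // 2:, :]  # keep second half: lags 0..L-1
    n = min(L + k.shape[0], 2 * L)   # Np = min(s + k, 2s) = 2L here
    X = torch.fft.rfft(x, n=n, dim=1)
    Kf = torch.fft.rfft(k, n=n, dim=0)
    y = torch.fft.irfft(X * Kf, n=n, dim=1)[:, :L, :]  # tail-crop: keep first L outputs
    return y + shortcut * x


torch.manual_seed(0)
B, L, C = 2, 12, 3
x = torch.randn(B, L, C, dtype=torch.float64)
k_full = torch.randn(2 * L - 1, C, dtype=torch.float64)
s = torch.randn(C, dtype=torch.float64)

x2 = x.clone()
x2[:, 7, :] += 1.0  # perturb a single time step
y1, y2 = causal_fft_conv_1d(x, k_full, s), causal_fft_conv_1d(x2, k_full, s)
print(torch.allclose(y1[:, :7], y2[:, :7]))  # True: positions before 7 are unaffected
```

Padding to 2L matters: with only L points, the product of spectra would wrap late lags around to early positions and break causality.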
- Parameters:
x (Tensor) – Input signal tensor. Two supported layouts:
  - Channels-last (is_bhl_input=False, default): shape (B, *spatial, hidden_dim), where spatial has length self.data_dim.
  - Channels-first (is_bhl_input=True): shape (B, hidden_dim, *spatial).
is_bhl_input (bool) – If True, x is in channels-first (B, C, *spatial) layout. Default: False (channels-last).
cp_group (ProcessGroup) – Context-parallel process group. When provided and cp_group.size() > 1, the kernel and shortcut are sliced along the channel dimension to match the local channel slice held by this rank. The spatial slice of x is expected to have already been distributed by the caller. Causal mode is not verified to be correct under CP, and combining cp_group with is_causal=True raises ValueError. Default: None (single device, no CP).
**mixer_kwargs – Additional keyword arguments forwarded to the kernel generator. The following key is recognised:
  - conditioning (torch.Tensor, shape (B, cond_dim)): conditioning vector for FiLM-enabled kernels such as SIRENKernelND with a film_cfg. Ignored (no-op) when the kernel has no FiLM generator.
- Returns:
Output tensor in the same memory layout as x: (B, *spatial, hidden_dim) when is_bhl_input=False, or (B, hidden_dim, *spatial) when is_bhl_input=True.
- Return type:
Tensor
- Raises:
ValueError – If cp_group is provided together with is_causal=True. This combination is explicitly rejected because it has not been verified for correctness: the causal kernel crop and CP channel slicing interact in ways that may silently leak future positions. If a future version removes this check, re-verify causality under CP before relying on it.
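Under context parallelism the kernel and shortcut are sliced along the channel dimension to match the local channel slice. A minimal sketch of that slicing, assuming hidden_dim is divisible by the group size and each rank holds one contiguous block of channels; both assumptions and the helper name are hypothetical. In a real run, rank and world_size would come from the process group, for example torch.distributed.get_rank(cp_group) and cp_group.size().

```python
import torch


def slice_for_cp_rank(conv_kernel, shortcut, rank, world_size):
    """Hypothetical: keep this rank's contiguous channel block.

    conv_kernel: (1_or_B, *kernel_spatial, C); shortcut: (C,).
    Assumes C % world_size == 0 and a contiguous block layout across ranks.
    """
    C = shortcut.shape[0]
    assert C % world_size == 0, "hidden_dim must be divisible by the CP group size"
    per_rank = C // world_size
    sl = slice(rank * per_rank, (rank + 1) * per_rank)
    return conv_kernel[..., sl], shortcut[sl]


k = torch.arange(2 * 5 * 8, dtype=torch.float32).reshape(1, 2, 5, 8)
s = torch.arange(8, dtype=torch.float32)
k_local, s_local = slice_for_cp_rank(k, s, rank=2, world_size=4)
print(k_local.shape, s_local.tolist())  # torch.Size([1, 2, 5, 2]) [4.0, 5.0]
```

Because the convolution is depthwise, each rank can convolve its channels independently once it holds the matching kernel and shortcut slices; no cross-rank communication is needed inside the FFT itself.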