ExponentialModulationND

class ExponentialModulationND(
data_dim,
num_channels,
fast_decay_pct=13.81,
slow_decay_pct=2.3,
)

Bases: Module

Fixed exponential-decay spatial window applied to implicit convolutional kernels.

Geometry

Given a coordinate grid normalised to [−1, 1] (center at 0) and a set of per-axis, per-channel decay rates \(w_{d,c} > 0\), the mask at spatial position \(p = (p_0, \ldots, p_{D-1})\) for channel \(c\) is:

\[m_c(p) = \prod_{d=0}^{D-1} \exp\!\bigl(-\lvert p_d \rvert \cdot \lvert w_{d,c} \rvert\bigr)\]

All values lie in (0, 1]. The mask equals 1 at the origin (displacement 0) and decays towards 0 as any coordinate moves away from the center. Channels with large w decay quickly (narrow receptive field); channels with small w decay slowly (broad receptive field).
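A minimal pure-Python sketch of the mask formula for a single channel, useful for checking values by hand. The helper name exp_mask is hypothetical and not part of the module's API. The two rates (ln 10 and ln 10^6) are chosen purely for illustration.

```python
import math
from typing import List, Sequence


def exp_mask(p: Sequence[float], w: Sequence[float]) -> float:
    """Hypothetical helper: mask value m_c(p) for a single channel c.

    p: coordinates (p_0, ..., p_{D-1}), each in [-1, 1]
    w: that channel's per-axis rates (w_{0,c}, ..., w_{D-1,c})
    """
    if len(p) != len(w):
        raise ValueError("p and w need one entry per axis")
    value = 1.0
    for p_d, w_d in zip(p, w):
        # One exponential factor per axis; only |w| enters, as in the formula.
        value *= math.exp(-abs(p_d) * abs(w_d))
    return value


broad: List[float] = [math.log(10)]    # small rate: broad window
narrow: List[float] = [math.log(1e6)]  # large rate: narrow window
for x in (0.0, 0.5, 1.0):
    print(f"|p| = {x}: broad {exp_mask([x], broad):.3g}, narrow {exp_mask([x], narrow):.3g}")
```

At |p| = 0, 0.5 and 1 the broad channel takes the values 1, 10^(-1/2) ≈ 0.32 and 0.1, while the narrow channel takes 1, 10^-3 and 10^-6.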

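The ND generalisation follows directly from the product structure. For a 2D image grid the mask is a separable surface with a sharp peak at the center, and for a 3D volume it is the analogous 3D window. Because the exponent is a weighted sum of the \(\lvert p_d \rvert\), its level sets are diamonds in 2D (octahedra in 3D), not circles or spheres.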
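The product structure can be checked numerically. The sketch below uses a hypothetical mask_2d helper to evaluate a 5×5 grid. Equal rates on both axes are assumed here for illustration only.

```python
import math


def mask_2d(p0: float, p1: float, w0: float, w1: float) -> float:
    # Hypothetical helper: product of two 1D exponential windows (separable mask).
    return math.exp(-abs(p0) * abs(w0)) * math.exp(-abs(p1) * abs(w1))


coords = [-1.0, -0.5, 0.0, 0.5, 1.0]
rates = (2.0, 2.0)  # illustrative: equal rates on both axes
for p0 in coords:
    print("  ".join(f"{mask_2d(p0, p1, *rates):.3f}" for p1 in coords))
```

The printed table peaks at 1.000 in the center. With equal rates the value depends only on |p_0| + |p_1|. For example, (0.5, 0) and (0.25, −0.25) receive the same mask value, which is why the contours are diamonds.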

Decay rates are not learnable. They are registered as a parameter so that they travel with the module and appear in state_dict, but they are not updated by the optimizer. They are also marked _no_weight_decay = True, which excludes them from weight decay. The rates are initialised on a linear ramp from slow_decay_pct to fast_decay_pct and divided by data_dim, so that the product across axes has a consistent magnitude regardless of the number of dimensions. At a grid corner, where \(\lvert p_d \rvert = 1\) on every axis, the data_dim factors multiply back to \(\exp(-r)\) for a ramp value \(r\).
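A minimal sketch of the initialisation described above. It assumes a single linear ramp over channels (slow to fast) that is divided by data_dim and shared by every axis. The function names init_decay_rates and corner_value are hypothetical, and the module's actual tensor layout may differ.

```python
import math
from typing import List


def init_decay_rates(data_dim: int, num_channels: int,
                     fast_decay_pct: float = 13.81,
                     slow_decay_pct: float = 2.3) -> List[List[float]]:
    """Hypothetical re-creation of the initial rates, indexed w[d][c].

    Assumption: one linear ramp over channels (slow -> fast), divided by
    data_dim and shared by every axis d.
    """
    if num_channels == 1:
        ramp = [slow_decay_pct]
    else:
        step = (fast_decay_pct - slow_decay_pct) / (num_channels - 1)
        ramp = [slow_decay_pct + c * step for c in range(num_channels)]
    per_axis = [r / data_dim for r in ramp]
    return [list(per_axis) for _ in range(data_dim)]


def corner_value(w: List[List[float]], c: int) -> float:
    # Mask for channel c at a grid corner, where |p_d| = 1 on every axis.
    return math.prod(math.exp(-abs(w_d[c])) for w_d in w)


for D in (1, 2, 3):
    w = init_decay_rates(D, num_channels=4)
    print(D, [round(corner_value(w, c), 6) for c in range(4)])
```

Under these assumptions each printed row should contain the same four values. A point on the boundary of only one axis, such as p = (1, 0, 0), sees just one factor, exp(−r/D), so that value does depend on data_dim.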

Role in CKConvND

CKConvND optionally passes the output of its implicit kernel network through this module before the FFT convolution step. The resulting masked kernel is then convolved with the input signal via fftconvNd.
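The following 1D, single-channel sketch shows where the mask sits in the pipeline. It rests on several assumptions. First, toy_kernel_net is a hypothetical stand-in for CKConvND's implicit kernel network. Second, the convolution is a direct zero-padded convolution with the same length as the input, swapped in for the FFT-based fftconvNd, whose padding and alignment conventions are not described here.

```python
import math
from typing import List


def toy_kernel_net(p: float) -> float:
    # Hypothetical stand-in for the implicit kernel network.
    return math.cos(6.0 * p)


def masked_kernel(size: int, rate: float) -> List[float]:
    # Kernel sampled on a grid normalised to [-1, 1], then windowed by the mask.
    grid = [-1.0 + 2.0 * i / (size - 1) for i in range(size)]
    return [toy_kernel_net(p) * math.exp(-abs(p) * abs(rate)) for p in grid]


def conv_same(signal: List[float], kernel: List[float]) -> List[float]:
    # Direct zero-padded convolution, centred kernel, output length == input length.
    half = len(kernel) // 2
    out = []
    for n in range(len(signal)):
        acc = 0.0
        for k, kv in enumerate(kernel):
            j = n - (k - half)
            if 0 <= j < len(signal):
                acc += kv * signal[j]
        out.append(acc)
    return out


signal = [0.0] * 9
signal[4] = 1.0                     # unit impulse in the middle
kernel = masked_kernel(5, rate=2.3)
response = conv_same(signal, kernel)
print([round(v, 4) for v in response])
```

Feeding a unit impulse through conv_same reproduces the masked kernel, centred on the impulse, so the printout shows the kernel after the window has damped its tails.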

param data_dim:

Number of spatial/temporal dimensions (1 for sequences, 2 for images, 3 for videos).

param num_channels:

Number of feature channels C. Each channel receives a distinct decay rate.

param fast_decay_pct:

Upper end of the decay-rate ramp (fastest / narrowest channel). Default 13.81 (≈ \(\ln(10^6)\), so the narrowest channel falls to ≈ \(10^{-6}\) at the edge of a 1D grid and, more generally, at the grid corners, where \(\lvert p_d \rvert = 1\) on every axis).

param slow_decay_pct:

Lower end of the decay-rate ramp (slowest / broadest channel). Default 2.3 (≈ \(\ln(10)\), so the broadest channel retains ≈ 10% of its value at the edge of a 1D grid and, more generally, at the grid corners).
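As a worked check of the two defaults in the 1D case (data_dim = 1, so each channel's rate equals its ramp value), the mask values at the edge of the grid (\(\lvert p \rvert = 1\)) and halfway to it (\(\lvert p \rvert = 0.5\)) are:

\[
e^{-13.81} \approx 1.01 \times 10^{-6}, \qquad e^{-2.3} \approx 0.100, \qquad e^{-6.905} \approx 1.0 \times 10^{-3}, \qquad e^{-1.15} \approx 0.32
\]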

__init__(
data_dim,
num_channels,
fast_decay_pct=13.81,
slow_decay_pct=2.3,
)

Initialise the exponential modulation module.

Parameters:
  • data_dim (int) – Number of spatial/temporal dimensions.

  • num_channels (int) – Number of feature channels to modulate.

  • fast_decay_pct (float) – Upper end of the per-channel decay-rate ramp (fastest channel).

  • slow_decay_pct (float) – Lower end of the per-channel decay-rate ramp (slowest channel).

extra_repr()

Return extra information to include in the printed representation of an ExponentialModulationND instance.

forward(grid, x)

Apply exponential decay modulation element-wise to kernel features.

For each spatial position \(p\) and channel \(c\), the method computes:

\[\text{out}[\ldots, c] = x[\ldots, c] \cdot \prod_{d} \exp\!\bigl(-\lvert p_d \rvert \cdot \lvert w_{d,c} \rvert\bigr)\]

The product is over all data_dim spatial axes. The mask value is 1 at the origin and decreases monotonically towards 0 as the displacement from the origin grows.
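A pure-Python sketch of this computation follows. It flattens the spatial dimensions to one list entry per position and drops the batch dimension. The function modulate and its list-based layout are hypothetical: they mirror the formula above, not the tensor implementation.

```python
import math
from typing import List, Sequence


def modulate(grid: Sequence[Sequence[float]],
             x: Sequence[Sequence[float]],
             w: Sequence[Sequence[float]]) -> List[List[float]]:
    """Hypothetical list-based sketch of forward() for one batch element.

    grid: one coordinate tuple (p_0, ..., p_{D-1}) per spatial position
    x:    one feature vector of length C per spatial position
    w:    decay rates indexed as w[d][c]
    """
    if len(grid) != len(x):
        raise ValueError("grid and x must cover the same spatial positions")
    out = []
    for p, feats in zip(grid, x):
        row = []
        for c, value in enumerate(feats):
            # exp(-sum_d |p_d| |w_dc|) equals the product of per-axis factors.
            decay = sum(abs(p_d) * abs(w[d][c]) for d, p_d in enumerate(p))
            row.append(value * math.exp(-decay))
        out.append(row)
    return out


grid = [(-1.0,), (0.0,), (1.0,)]     # 1D grid with three positions
x = [[2.0, 2.0] for _ in grid]       # two channels, constant features
w = [[math.log(10), math.log(1e6)]]  # data_dim = 1, so one row of rates
out = modulate(grid, x, w)
print(out)
```

The features are unchanged at the origin, scaled by 0.1 (broad channel) or 10^-6 (narrow channel) at the ends, and the output has the same shape as x.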

Parameters:
  • grid (Tensor) – Coordinate grid of shape [1, *spatial_dims, data_dim] with values in [−1, 1]. Each entry grid[..., d] contains the normalised coordinate along axis d. Must be torch.float32; lower-precision dtypes can round nearby coordinates to the same value.

  • x (Tensor) – Kernel feature tensor of shape [B, *spatial_dims, num_channels] to be modulated. B is the batch size; *spatial_dims must match the spatial shape of grid.
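One way to build a coordinate grid with the documented layout is sketched below in pure Python. The helpers axis_coords and make_grid are hypothetical. Evenly spaced points that include both endpoints ±1 are an assumption, since the documentation only requires values in [−1, 1]. In PyTorch the same row-major layout can be produced with torch.meshgrid(..., indexing="ij"), then torch.stack(..., dim=-1), then unsqueeze(0). torch.linspace yields float32 under PyTorch's default floating dtype.

```python
import itertools
from typing import List, Sequence, Tuple


def axis_coords(n: int) -> List[float]:
    # Hypothetical helper: n evenly spaced coordinates from -1 to 1 inclusive.
    if n < 2:
        raise ValueError("each axis needs at least two points")
    return [-1.0 + 2.0 * i / (n - 1) for i in range(n)]


def make_grid(spatial_dims: Sequence[int]) -> List[Tuple[float, ...]]:
    """Hypothetical helper: one coordinate tuple per position, row-major order.

    Reshaping the result to [1, *spatial_dims, len(spatial_dims)] gives the
    layout documented for ``grid``.
    """
    axes = [axis_coords(n) for n in spatial_dims]
    return list(itertools.product(*axes))  # last axis varies fastest


g = make_grid([3, 5])
print(len(g), g[0], g[7], g[-1])
```

An odd number of points per axis places a sample exactly at the center 0, where the mask equals 1.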

Returns:

Modulated features with the same shape and dtype as x.

Return type:

torch.Tensor

Raises:

AssertionError – If grid.dtype is not torch.float32.
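The rounding effect behind this requirement can be illustrated without PyTorch. Python's struct module packs floats as IEEE half ('e') or single ('f') precision, so rounding a fine 1D grid through each format shows how many coordinates stay distinct. This only illustrates the precision issue. It is not the module's own check.

```python
import struct
from typing import List


def round_trip(values: List[float], fmt: str) -> List[float]:
    # Round each float through a binary format: 'e' = float16, 'f' = float32.
    return [struct.unpack(fmt, struct.pack(fmt, v))[0] for v in values]


n = 10_001  # a fine 1D grid with spacing 2e-4
coords = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
distinct_f32 = len(set(round_trip(coords, "f")))
distinct_f16 = len(set(round_trip(coords, "e")))
print(f"float32 keeps {distinct_f32} distinct values, float16 keeps {distinct_f16}")
```

The float32 round trip keeps every coordinate distinct. float16, whose spacing is about 4.9 × 10^-4 just below 1, merges neighbouring coordinates near the edges of the grid.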
