ExponentialModulationND#
- class ExponentialModulationND(data_dim, num_channels, fast_decay_pct=13.81, slow_decay_pct=2.3)
Bases: Module
Fixed exponential-decay spatial window applied to implicit convolutional kernels.
Geometry#
Given a coordinate grid normalised to [−1, 1] (center = 0) and a set of per-axis, per-channel decay rates \(w_{d,c} > 0\), the mask at spatial position \(p = (p_0, \ldots, p_{D-1})\) for channel \(c\) is:
\[m_c(p) = \prod_{d=0}^{D-1} \exp\!\bigl(-\lvert p_d \rvert \cdot \lvert w_{d,c} \rvert\bigr)\]
All values lie in (0, 1]. The mask equals 1 at the origin (displacement 0) and decays towards 0 as any coordinate moves away from the center. Channels with large \(w\) decay quickly (narrow receptive field); channels with small \(w\) decay slowly (broad receptive field).
The ND generalisation follows automatically from the product structure: for a 2D image grid the mask is a 2D tent-like surface; for a 3D volume it is a 3D “tent”-shaped object.
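The mask formula above can be checked by hand for a single channel. The following is a minimal sketch in plain Python, not the module's own code; the helper name `mask_value` and the example rates are hypothetical and chosen only for illustration.

```python
import math

def mask_value(p, w):
    """Evaluate m_c(p) = prod_d exp(-|p_d| * |w_d|) for one channel.

    p: coordinates in [-1, 1], one per axis.
    w: decay rates for this channel, one per axis.
    """
    if len(p) != len(w):
        raise ValueError("p and w must have one entry per axis")
    return math.prod(math.exp(-abs(pd) * abs(wd)) for pd, wd in zip(p, w))

# Hypothetical 2D rates for a single channel (not taken from the library).
w = (1.0, 1.0)
centre = mask_value((0.0, 0.0), w)   # exactly 1 at the origin
edge = mask_value((1.0, 0.0), w)     # exp(-1): one axis at its boundary
corner = mask_value((1.0, 1.0), w)   # exp(-2): both axes at their boundary
```

Because the mask is a product of per-axis factors, moving away from the center along any additional axis can only shrink the value, which is why the corner is smaller than the edge.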
Decay rates are not learnable (they are registered as a parameter so they travel with the module and appear in state_dict, but they are marked _no_weight_decay = True and are not updated by the optimizer). The rates are initialised on a linear ramp from slow_decay_pct to fast_decay_pct, divided by data_dim so that the product across axes has a consistent magnitude regardless of the number of dimensions.
Role in CKConvND#
CKConvND optionally passes the output of its implicit kernel network through this module before the FFT convolution step. The resulting masked kernel is then convolved with the input signal via fftconvNd.
- param data_dim:
Number of spatial/temporal dimensions (1 for sequences, 2 for images, 3 for videos).
- param num_channels:
Number of feature channels C. Each channel receives a distinct decay rate.
- param fast_decay_pct:
Upper end of the decay-rate ramp (fastest / narrowest channel). Default 13.81 (≈ \(\ln(10^6)\), so the narrowest channel decays to near zero within a small fraction of the grid).
- param slow_decay_pct:
Lower end of the decay-rate ramp (slowest / broadest channel). Default 2.3 (≈ \(\ln(10)\), so the broadest channel retains ≈ 10 % of its value at the grid boundary).
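To make the default endpoints concrete, the sketch below builds a linear ramp and evaluates the mask where every coordinate sits at its boundary. It is an illustration under stated assumptions, not the module's initialisation code: the helper names `ramp_rates` and `corner_mask` are hypothetical, and the sketch assumes that the ramp includes both endpoints, that channel 0 is the slowest, and that every axis shares the same ramp.

```python
import math

def ramp_rates(num_channels, data_dim, slow=2.3, fast=13.81):
    """Linear ramp from slow to fast (inclusive), scaled by 1 / data_dim.

    Assumption: channel 0 is the slowest and the same ramp is used on every axis.
    """
    if num_channels == 1:
        values = [slow]
    else:
        step = (fast - slow) / (num_channels - 1)
        values = [slow + i * step for i in range(num_channels)]
    return [v / data_dim for v in values]

def corner_mask(rate, data_dim):
    """Mask value where every coordinate has |p_d| = 1, same rate on each axis."""
    return math.exp(-rate * data_dim)

rates_1d = ramp_rates(num_channels=4, data_dim=1)
rates_3d = ramp_rates(num_channels=4, data_dim=3)

broadest = corner_mask(rates_3d[0], data_dim=3)    # exp(-2.3), about 0.10
narrowest = corner_mask(rates_3d[-1], data_dim=3)  # exp(-13.81), about 1e-6
```

Under these assumptions, dividing by data_dim cancels the number of axes in the exponent, so the corner value is the same in 1D and in 3D.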
- __init__(data_dim, num_channels, fast_decay_pct=13.81, slow_decay_pct=2.3)
Initialise the exponential modulation module.
- Parameters:
- extra_repr()#
Additional printing for the ExponentialModulationND class.
- forward(grid, x)#
Apply exponential decay modulation element-wise to kernel features.
For each spatial position \(p\) and channel \(c\) computes:
\[\text{out}[\ldots, c] = x[\ldots, c] \cdot \prod_{d} \exp\!\bigl(-\lvert p_d \rvert \cdot \lvert w_{d,c} \rvert\bigr)\]
The product is over all data_dim spatial axes. The mask value is 1 at the origin and decreases monotonically towards 0 as the displacement from the origin grows.
- Parameters:
grid (Tensor) – Coordinate grid of shape [1, *spatial_dims, data_dim] with values in [−1, 1]. Each entry grid[..., d] contains the normalised coordinate along axis d. Must be torch.float32 (lower precision collapses nearby coordinates together).
x (Tensor) – Kernel feature tensor of shape [B, *spatial_dims, num_channels] to be modulated. B is the batch size; *spatial_dims must match the spatial shape of grid.
- Returns:
Modulated features with the same shape and dtype as x.
- Return type:
- Raises:
AssertionError – If grid.dtype is not torch.float32.
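The documented forward behaviour can be approximated with a few lines of PyTorch. This is a standalone sketch rather than the class's actual implementation: the function name `exp_modulation` is hypothetical, and the `[data_dim, num_channels]` layout of the rate tensor `w` is an assumption, since the source does not state how the rates are stored.

```python
import torch

def exp_modulation(grid, x, w):
    """Sketch of out[..., c] = x[..., c] * prod_d exp(-|p_d| * |w[d, c]|).

    grid: [1, *spatial_dims, data_dim], float32, values in [-1, 1]
    x:    [B, *spatial_dims, num_channels]
    w:    [data_dim, num_channels] decay rates (assumed layout)
    """
    assert grid.dtype == torch.float32, "grid must be torch.float32"
    # [1, *spatial, D, 1] * [D, C] broadcasts to [1, *spatial, D, C].
    exponent = grid.abs().unsqueeze(-1) * w.abs()
    # Summing exponents over D equals multiplying the per-axis factors.
    mask = torch.exp(-exponent.sum(dim=-2))  # [1, *spatial, C]
    return x * mask.to(x.dtype)

# Hypothetical 2D example: 5x5 grid, 3 channels, batch of 2.
axis = torch.linspace(-1.0, 1.0, 5)
grid = torch.stack(torch.meshgrid(axis, axis, indexing="ij"), dim=-1).unsqueeze(0)
w = torch.tensor([[0.5, 1.0, 2.0], [0.5, 1.0, 2.0]])
x = torch.ones(2, 5, 5, 3)
out = exp_modulation(grid, x, w)
```

With an all-ones input, `out` is the mask itself: it equals 1 at the central grid point and falls to `exp(-2 * w_c)` at the corners for each channel `c`.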