GaussianModulationND

class GaussianModulationND(
data_dim,
num_channels,
grid_size,
min_attenuation_at_step=0.1,
max_attenuation_at_limit=0.95,
init_extent=1.0,
parametrization='direct',
)

Bases: Module

Learnable Gaussian-window spatial mask for ND convolutional kernels.

Geometry

For a coordinate grid normalised to [−1, 1] (center = 0) and per-axis, per-channel standard deviations \(\sigma_{d,c} > 0\), the mask value at position \(p = (p_0, \ldots, p_{D-1})\) for channel c is the product of per-axis Gaussians:

\[m_c(p) = \prod_{d=0}^{D-1} \exp\!\Bigl(-\tfrac{1}{2}\bigl(p_d / \sigma_{d,c}\bigr)^2\Bigr)\]

All values lie in (0, 1]. The mask equals 1 at the origin and decays symmetrically in all directions; the level set at value v is an axis-aligned ellipsoid with semi-axes \(\sigma_{d,c}\sqrt{-2\ln v}\).

Mask convention: 1 = fully included, 0 = fully excluded. A narrow Gaussian (small σ) concentrates the effective kernel around the origin, making the operator local; a wide Gaussian (large σ) lets the full grid contribute, approaching a global convolution.
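As a rough illustration of the geometry above, the following pure-Python sketch evaluates the product of per-axis Gaussians for a single channel and checks the level-set semi-axis formula. The function name `gaussian_mask_value` and the σ values are illustrative assumptions, not part of the library.

```python
import math


def gaussian_mask_value(p, sigma):
    """Product of per-axis Gaussians for one channel.

    p     : position, one coordinate per axis, each in [-1, 1]
    sigma : per-axis standard deviations for that channel (all > 0)
    """
    return math.prod(
        math.exp(-0.5 * (p_d / s_d) ** 2) for p_d, s_d in zip(p, sigma)
    )


sigma = (0.5, 0.25)  # illustrative 2D stds, chosen only for this example

# The mask equals 1 at the origin.
origin_value = gaussian_mask_value((0.0, 0.0), sigma)

# Semi-axis of the level set at value v along axis 0: sigma_0 * sqrt(-2 ln v).
v = 0.5
semi_axis_0 = sigma[0] * math.sqrt(-2.0 * math.log(v))
on_level_set = gaussian_mask_value((semi_axis_0, 0.0), sigma)  # close to v
```

Moving along axis 1 instead would reach the same level at half the distance, since that axis has half the σ in this example.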

ND generalisation

The mask factorises over axes. In 2D the mask surface is a 2D Gaussian bell whose level sets are axis-aligned ellipses rather than circles, because each axis has an independent σ. In 3D it is a trivariate axis-aligned Gaussian. Because the mask is a product, its value at a diagonal point (p, p, …, p) is the product of the per-axis Gaussian values at p; when σ is the same on every axis this equals single_axis_mask_value ** data_dim. The attenuation parameters (min_attenuation_at_step, max_attenuation_at_limit) are therefore defined as single-axis (1D) measurements: at the grid corner the exponent is data_dim times the single-axis exponent, so the effective ND attenuation there is stricter.
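A small numeric check of the corner effect, assuming the same σ on every axis (the value 0.8 is arbitrary and `axis_value` is a hypothetical helper):

```python
import math


def axis_value(position, sigma):
    """1D Gaussian mask value at `position` for standard deviation `sigma`."""
    return math.exp(-0.5 * (position / sigma) ** 2)


sigma = 0.8                            # arbitrary, shared by all axes here
single_axis = axis_value(1.0, sigma)   # 1D value at the grid boundary

corner_value = {}
for data_dim in (1, 2, 3):
    # Product over axes at the corner (1, 1, ..., 1).
    corner_value[data_dim] = math.prod(
        axis_value(1.0, sigma) for _ in range(data_dim)
    )
# corner_value[D] matches single_axis ** D up to rounding
```

A single-axis target such as max_attenuation_at_limit = 0.95 would thus correspond to roughly 0.95 ** 3 at the corner of a 3D grid under the same equal-σ assumption.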

Parametrisation and clamping

The learned parameter std_param of shape [data_dim, num_channels] stores raw values that are mapped to strictly-positive std values via parametrization:

  • 'direct': std_param is the std itself; a register_forward_pre_hook clamps it into [min_std, max_std] in-place (under torch.no_grad), so the clamp is not recorded by autograd and gradients at the boundary are preserved.

  • 'log': std = exp(std_param); a hard clamp is applied afterwards, which breaks gradients at the boundary (see the inline warning).

  • 'softplus': std = softplus(std_param); a hard clamp is applied afterwards.
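The difference in boundary gradients can be sketched with plain PyTorch. This is not the library's code: `DirectStd`, `MIN_STD`, `MAX_STD`, and the sample values are hypothetical, chosen only to contrast an in-place clamp in a forward pre-hook with a hard clamp inside the graph.

```python
import torch

MIN_STD, MAX_STD = 0.05, 2.0               # illustrative bounds, not library values
raw_std = torch.tensor([0.01, 0.5, 5.0])   # below, inside, above the bounds


class DirectStd(torch.nn.Module):
    """Hypothetical stand-in for the 'direct' mode: the parameter is the std."""

    def __init__(self, init):
        super().__init__()
        self.std_param = torch.nn.Parameter(init.clone())
        self.register_forward_pre_hook(self._clamp_in_place)

    @staticmethod
    def _clamp_in_place(module, args):
        # Runs before forward(); the clamp is not recorded by autograd.
        with torch.no_grad():
            module.std_param.clamp_(MIN_STD, MAX_STD)

    def forward(self):
        return self.std_param


direct = DirectStd(raw_std)
direct_std = direct()                      # parameter is clamped, then returned
direct_std.sum().backward()
direct_grad = direct.std_param.grad        # non-zero everywhere, also at the bounds

# 'log'-style mapping: exp followed by a hard clamp inside the graph.
log_param = torch.nn.Parameter(raw_std.log())
log_std = log_param.exp().clamp(MIN_STD, MAX_STD)
log_std.sum().backward()
log_grad = log_param.grad                  # zero wherever the clamp was active
```

In the hard-clamp case a parameter that has drifted outside the bounds receives no gradient, so it cannot move back on its own; the in-place clamp resets the stored value instead.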

Initialisation

min_attenuation_at_step and max_attenuation_at_limit define the clamp bounds [min_std, max_std]. The initial std_param is a logspace ramp from min_std to init_std_high_unit on every axis, scaled per-axis by init_extent:

  • init_std_low[d]  = clamp(min_std            * extent[d], min_std, max_std)

  • init_std_high[d] = clamp(init_std_high_unit * extent[d], min_std, max_std)

where init_std_high_unit ≈ 0.4724 is the std at which a 1D Gaussian reaches 0.1 at position 1. See init_extent below.

All attenuation values are single-axis (1D) measurements; see the ND generalisation note above.
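One plausible way to turn the attenuation targets into clamp bounds is to invert the 1D Gaussian, σ = p / sqrt(−2 ln v). The sketch below does that and builds a logspace ramp. The helpers `std_for_attenuation` and `logspace_ramp`, the grid and channel sizes, and the assumption that the ramp runs across channels are all illustrative; the library's exact computation may differ.

```python
import math

INIT_STD_HIGH_UNIT = 0.4724   # value quoted in the text above


def std_for_attenuation(position, value):
    """Std at which exp(-0.5 * (position / std) ** 2) equals `value`."""
    return position / math.sqrt(-2.0 * math.log(value))


def logspace_ramp(low, high, num):
    """`num` geometrically spaced values from `low` to `high` (inclusive)."""
    if num == 1:
        return [low]
    return [low * (high / low) ** (i / (num - 1)) for i in range(num)]


grid_size, num_channels = 33, 8                 # illustrative sizes
min_step = 2.0 / (grid_size - 1)                # smallest grid step on [-1, 1]
min_std = std_for_attenuation(min_step, 0.1)    # from min_attenuation_at_step
max_std = std_for_attenuation(1.0, 0.95)        # from max_attenuation_at_limit

# Assumption: the ramp runs across channels, narrowest channel first.
ramp = logspace_ramp(min_std, INIT_STD_HIGH_UNIT, num_channels)
```

With these defaults the narrowest channel is attenuated to 0.1 one grid step from the origin, while a channel at max_std would still pass 0.95 at the grid boundary.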

param data_dim:

Number of spatial/temporal dimensions (1 for sequences, 2 for images, 3 for volumes).

param num_channels:

Number of feature channels C to modulate.

param grid_size:

Number of grid points per spatial dimension. Used to compute the size of the smallest grid step (min_step = 2 / (grid_size - 1)), which sets min_std. Auto-injected by CKConvND.

param min_attenuation_at_step:

Target 1D mask value at the first grid step from the origin for the narrowest channel. Smaller values → narrower minimum std → more localised narrowest channel. Default 0.1.

param max_attenuation_at_limit:

Target 1D mask value at the grid boundary (position = 1) for the widest channel. Larger values → wider maximum std → less attenuation at the boundary. Default 0.95.

param init_extent:

Scalar or per-axis sequence controlling the initial bandwidth scale on each axis. Must be strictly > 0; defaults to 1.0, which applies the reference ramp on every axis.

Examples for an anisotropic L_cache = (8, 64, 64) cache with grid_size = 127:

  • init_extent = 1.0 — all axes use the reference ramp. On a short depth axis the bottom of the ramp can be unusably narrow.

  • init_extent = (max_std/min_std, 1.0, 1.0) — depth ramp saturates at max_std (axis effectively unmasked at init).

  • init_extent = (1.0, 0.25, 0.25) — H/W are 4× narrower than the reference (extreme localisation on spatial axes).
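The three settings above can be traced numerically with the clamp formulas from the Initialisation section. This is a sketch under assumptions: the bounds are derived by inverting the 1D Gaussian with the default attenuation targets, and `std_for_attenuation`, `clamp`, and `init_endpoints` are hypothetical helpers rather than library functions.

```python
import math

INIT_STD_HIGH_UNIT = 0.4724   # value quoted in the Initialisation section


def std_for_attenuation(position, value):
    """Std at which exp(-0.5 * (position / std) ** 2) equals `value`."""
    return position / math.sqrt(-2.0 * math.log(value))


def clamp(x, low, high):
    return max(low, min(x, high))


grid_size = 127
min_std = std_for_attenuation(2.0 / (grid_size - 1), 0.1)   # assumed derivation
max_std = std_for_attenuation(1.0, 0.95)                    # assumed derivation


def init_endpoints(extent):
    """Per-axis (init_std_low, init_std_high) after scaling and clamping."""
    return [
        (
            clamp(min_std * e, min_std, max_std),
            clamp(INIT_STD_HIGH_UNIT * e, min_std, max_std),
        )
        for e in extent
    ]


reference = init_endpoints((1.0, 1.0, 1.0))
depth_open = init_endpoints((max_std / min_std, 1.0, 1.0))
narrow_hw = init_endpoints((1.0, 0.25, 0.25))
```

In `depth_open` both ends of the depth ramp land on max_std, and in `narrow_hw` the low end of the H/W ramps stays at min_std because of the lower clamp while the high end shrinks by 4×.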

param parametrization:

One of 'direct', 'log', 'softplus'. Controls the mapping from std_param to std values. Default 'direct'.

__init__(
data_dim,
num_channels,
grid_size,
min_attenuation_at_step=0.1,
max_attenuation_at_limit=0.95,
init_extent=1.0,
parametrization='direct',
)

Initialise the Gaussian modulation module.

Parameters:
  • data_dim (int) – Number of spatial/temporal dimensions.

  • num_channels (int) – Number of feature channels to modulate.

  • grid_size (int) – Number of grid points per spatial dimension.

  • min_attenuation_at_step (float) – 1D mask value at the first grid step from the origin for the narrowest channel (sets min_std).

  • max_attenuation_at_limit (float) – 1D mask value at the grid boundary for the widest channel (sets max_std).

  • init_extent (float | Sequence[float]) – Per-axis bandwidth scale for initialisation (> 0). Pass a float to broadcast, or a sequence of length data_dim.

  • parametrization (str) – Mapping from std_param to std values. One of 'direct', 'log', 'softplus'.

extra_repr()

Return the extra information string shown when a GaussianModulationND instance is printed.

forward(grid, x)

Apply Gaussian spatial modulation element-wise to kernel features.

For each spatial position \(p\) and channel \(c\) computes:

\[\text{out}[\ldots, c] = x[\ldots, c] \cdot \exp\!\Bigl(-\tfrac{1}{2} \sum_{d} \bigl(p_d / \sigma_{d,c}\bigr)^2\Bigr)\]

The exponent is computed as a single einsum (sum over axes) for efficiency, avoiding the intermediate prod(exp) formulation.

Mask convention: 1 = fully included (origin), 0 = fully excluded (large displacement).

Parameters:
  • grid (Tensor) – Coordinate grid of shape [1, *spatial_dims, data_dim] with values in [−1, 1]. Must be torch.float32.

  • x (Tensor) – Kernel feature tensor of shape [B, *spatial_dims, num_channels]. *spatial_dims must match the spatial shape of grid.

Returns:

Modulated features with the same shape and dtype as x. The internal Gaussian computation is always done in float32 and then cast to x.dtype.

Return type:

torch.Tensor

Raises:

AssertionError – If grid.dtype is not torch.float32.
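The einsum formulation described for forward can be sketched as follows. This is a minimal stand-alone function, not the module's implementation: `gaussian_modulate` and the explicit `std` argument of shape [data_dim, num_channels] are assumptions made for illustration.

```python
import torch


def gaussian_modulate(grid, x, std):
    """Hypothetical sketch of the modulation.

    grid : [1, *spatial_dims, D], float32, values in [-1, 1]
    x    : [B, *spatial_dims, C]
    std  : [D, C], strictly positive
    """
    assert grid.dtype == torch.float32
    inv_var = std.float().reciprocal().square()               # [D, C]
    # sum_d p_d^2 / sigma_{d,c}^2, contracted over the axis index d
    exponent = torch.einsum("...d,dc->...c", grid.square(), inv_var)
    mask = torch.exp(-0.5 * exponent)                         # [1, *spatial_dims, C]
    return x * mask.to(x.dtype)                               # broadcasts over B


# Small 2D example: 5x5 grid, 4 channels, batch of 3.
axis = torch.linspace(-1.0, 1.0, 5)
grid = torch.stack(torch.meshgrid(axis, axis, indexing="ij"), dim=-1).unsqueeze(0)
std = torch.tensor([[0.2, 0.4, 0.8, 1.6],
                    [0.3, 0.6, 0.9, 1.2]])                    # [D=2, C=4]
x = torch.ones(3, 5, 5, 4)

out = gaussian_modulate(grid, x, std)

# Reference: explicit product of per-axis Gaussians.
reference_mask = torch.exp(-0.5 * (grid.unsqueeze(-1) / std) ** 2).prod(dim=-2)
```

Summing the exponents before a single `exp` gives the same mask as the product of per-axis exponentials, with one fewer intermediate tensor of shape [1, *spatial_dims, D, C].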

Parameters: