GaussianModulationND#
- class GaussianModulationND(data_dim, num_channels, grid_size, min_attenuation_at_step=0.1, max_attenuation_at_limit=0.95, init_extent=1.0, parametrization='direct')
Bases: Module
Learnable Gaussian-window spatial mask for ND convolutional kernels.
Geometry#
For a coordinate grid normalised to [−1, 1] (center = 0) and per-axis, per-channel standard deviations \(\sigma_{d,c} > 0\), the mask value at position \(p = (p_0, \ldots, p_{D-1})\) for channel c is the product of per-axis Gaussians:
\[m_c(p) = \prod_{d=0}^{D-1} \exp\!\Bigl(-\tfrac{1}{2}\bigl(p_d / \sigma_{d,c}\bigr)^2\Bigr)\]
All values lie in (0, 1]. The mask equals 1 at the origin and decays symmetrically in all directions; the level set at value v is an axis-aligned ellipsoid with semi-axes \(\sigma_{d,c}\sqrt{-2\ln v}\).
Mask convention: 1 = fully included, 0 = fully excluded. A narrow Gaussian (small σ) concentrates the effective kernel around the origin, making the operator local; a wide Gaussian (large σ) lets the full grid contribute, approaching a global convolution.
ND generalisation#
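To make the factorised mask concrete, the following is a minimal plain-Python sketch of the product-of-Gaussians formula from the Geometry section. It is not the library's implementation: the helper name `gaussian_mask` and the sample σ values are assumptions chosen for illustration.

```python
import math

def gaussian_mask(p, sigma):
    """Mask value for one channel: product of per-axis Gaussians.

    p and sigma are equal-length sequences with one entry per axis.
    (Hypothetical helper, not part of GaussianModulationND.)
    """
    return math.prod(math.exp(-0.5 * (p_d / s_d) ** 2) for p_d, s_d in zip(p, sigma))

# Anisotropic example: each axis has its own sigma.
sigma = (0.5, 0.25, 1.0)
origin_value = gaussian_mask((0.0, 0.0, 0.0), sigma)  # mask is 1 at the origin

# Level set at value v crosses axis 0 at sigma_0 * sqrt(-2 ln v).
v = 0.5
r0 = sigma[0] * math.sqrt(-2.0 * math.log(v))
on_axis = gaussian_mask((r0, 0.0, 0.0), sigma)  # approximately v

# With the same sigma on every axis, the corner value is the
# single-axis value raised to the number of axes.
shared = (0.5, 0.5, 0.5)
single_axis = gaussian_mask((1.0,), shared[:1])
corner = gaussian_mask((1.0, 1.0, 1.0), shared)
```

With a shared σ, `corner` matches `single_axis ** 3` up to rounding, which is why the grid corner is attenuated more strongly than any single axis.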
The mask factorises over axes. In 2D the mask surface is a 2D Gaussian bell whose level sets are axis-aligned ellipses rather than circles, because each axis has an independent σ. In 3D it is a trivariate axis-aligned Gaussian. Because the mask is a product, the corner value at position (position, position, …, position) is the product of the individual per-axis Gaussian values; when every axis shares the same σ, this equals single_axis_mask_value ** data_dim. The attenuation parameters (min_attenuation_at_step, max_attenuation_at_limit) are therefore defined as single-axis (1D) measurements: the effective ND attenuation at the grid corner is stricter, with the exponent multiplied by data_dim.
Parametrisation and clamping#
The learned parameter std_param of shape [data_dim, num_channels] stores raw values that are mapped to strictly positive std values via parametrization:
- 'direct': std_param IS the std; a register_forward_pre_hook clamps it into [min_std, max_std] in-place (torch.no_grad), so gradients at the boundary are preserved through the activation.
- 'log': std = exp(std_param); hard clamp applied after (breaks boundary gradients; see inline warning).
- 'softplus': std = softplus(std_param); hard clamp applied after.
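The three mappings can be sketched on a single scalar as follows. This is an illustrative plain-Python approximation, not the module's code: `std_from_param` and `softplus` are hypothetical helpers, and the sketch does not reproduce where the clamp happens (on the parameter itself for 'direct', on the mapped output for 'log' and 'softplus') or its effect on gradients.

```python
import math

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def std_from_param(raw, parametrization, min_std, max_std):
    """Map a raw parameter value to a clamped std (hypothetical helper)."""
    if parametrization == "direct":
        std = raw                      # the raw value is the std
    elif parametrization == "log":
        std = math.exp(raw)            # always positive
    elif parametrization == "softplus":
        std = softplus(raw)            # always positive, linear for large raw
    else:
        raise ValueError(f"unknown parametrization: {parametrization!r}")
    # Clamp into [min_std, max_std] in every mode.
    return min(max(std, min_std), max_std)

log_std = std_from_param(0.0, "log", 0.01, 3.0)            # exp(0) = 1
softplus_std = std_from_param(0.0, "softplus", 0.01, 3.0)  # log(2)
clamped_high = std_from_param(5.0, "direct", 0.01, 3.0)    # clamped to max_std
clamped_low = std_from_param(-10.0, "log", 0.01, 3.0)      # clamped to min_std
```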
Initialisation#
min_attenuation_at_step and max_attenuation_at_limit define the clamp bounds [min_std, max_std]. The initial std_param is a logspace ramp from min_std to init_std_high_unit on every axis, scaled per-axis by init_extent:
- init_std_low[d] = clamp(min_std * extent[d], min_std, max_std)
- init_std_high[d] = clamp(init_std_high_unit * extent[d], min_std, max_std)
where init_std_high_unit ≈ 0.4724 is the std at which a 1D Gaussian reaches 0.1 at position 1. See init_extent below.
All attenuation values are single-axis (1D) measurements; see the ND generalisation note above.
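As a rough illustration of how the clamp bounds and the initial ramp relate, the sketch below inverts the 1D Gaussian exp(−½(p/σ)²) = v to get σ = p / √(−2 ln v). Several things here are assumptions rather than documented behaviour: that the library derives min_std and max_std exactly this way, that the ramp runs across channels, and the helper names `std_bounds` and `init_ramp`. The constant 0.4724 is simply the value quoted above, not re-derived.

```python
import math

def std_bounds(grid_size, min_attenuation_at_step=0.1, max_attenuation_at_limit=0.95):
    """Assumed derivation of [min_std, max_std] from the 1D attenuation targets."""
    min_step = 2.0 / (grid_size - 1)  # smallest grid step on [-1, 1]
    min_std = min_step / math.sqrt(-2.0 * math.log(min_attenuation_at_step))
    max_std = 1.0 / math.sqrt(-2.0 * math.log(max_attenuation_at_limit))
    return min_std, max_std

def init_ramp(num_channels, extent, min_std, max_std, init_std_high_unit=0.4724):
    """Logspace ramp between the clamped low and high endpoints for one axis."""
    def clamp(value):
        return min(max(value, min_std), max_std)
    low = clamp(min_std * extent)
    high = clamp(init_std_high_unit * extent)
    if num_channels == 1:
        return [low]
    ratio = (high / low) ** (1.0 / (num_channels - 1))  # constant multiplicative step
    return [low * ratio ** i for i in range(num_channels)]

min_std, max_std = std_bounds(127)
ramp = init_ramp(8, 1.0, min_std, max_std)
```

Under these assumptions, a channel at `min_std` has mask value 0.1 one grid step from the origin, and a channel at `max_std` has mask value 0.95 at position 1.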
- param data_dim:
Number of spatial/temporal dimensions (1 for sequences, 2 for images, 3 for volumes).
- param num_channels:
Number of feature channels C to modulate.
- param grid_size:
Number of grid points per spatial dimension. Used to compute the size of the smallest grid step (min_step = 2 / (grid_size - 1)), which sets min_std. Auto-injected by CKConvND.
- param min_attenuation_at_step:
Target 1D mask value at the first grid step from the origin for the narrowest channel. Smaller values → narrower minimum std → more local minimum channel. Default 0.1.
- param max_attenuation_at_limit:
Target 1D mask value at the grid boundary (position = 1) for the widest channel. Larger values → wider maximum std → less attenuation at the boundary. Default 0.95.
- param init_extent:
Scalar or per-axis sequence controlling the initial bandwidth scale on each axis. Must be strictly > 0; defaults to 1.0 on every axis (reference ramp on every axis).
Examples for an anisotropic L_cache = (8, 64, 64) cache with grid_size = 127:
  - init_extent = 1.0: all axes use the reference ramp. On a short depth axis the bottom of the ramp can be unusably narrow.
  - init_extent = (max_std/min_std, 1.0, 1.0): depth ramp saturates at max_std (axis effectively unmasked at init).
  - init_extent = (1.0, 0.25, 0.25): H/W are 4× narrower than the reference (extreme localisation on spatial axes).
- param parametrization:
One of 'direct', 'log', 'softplus'. Controls the mapping from std_param to std values. Default 'direct'.
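The three init_extent settings listed for init_extent can be compared by computing the clamped ramp endpoints per axis. This sketch applies the init_std_low / init_std_high formulas from the Initialisation section. The bounds are derived by inverting the 1D Gaussian, which is an assumption about the library, and `init_endpoints` is a hypothetical helper, not part of the class.

```python
import math

GRID_SIZE = 127
INIT_STD_HIGH_UNIT = 0.4724  # value quoted in the documentation

# Assumed bounds: sigma = p / sqrt(-2 ln v) for the documented default targets.
min_std = (2.0 / (GRID_SIZE - 1)) / math.sqrt(-2.0 * math.log(0.1))
max_std = 1.0 / math.sqrt(-2.0 * math.log(0.95))

def init_endpoints(init_extent, data_dim):
    """Return [(init_std_low, init_std_high), ...] with one pair per axis."""
    if isinstance(init_extent, (int, float)):
        extent = [float(init_extent)] * data_dim  # broadcast a scalar
    else:
        extent = [float(e) for e in init_extent]
    if len(extent) != data_dim or any(e <= 0 for e in extent):
        raise ValueError("init_extent must be > 0 with one entry per axis")
    def clamp(value):
        return min(max(value, min_std), max_std)
    return [(clamp(min_std * e), clamp(INIT_STD_HIGH_UNIT * e)) for e in extent]

reference = init_endpoints(1.0, 3)
unmasked_depth = init_endpoints((max_std / min_std, 1.0, 1.0), 3)
narrow_hw = init_endpoints((1.0, 0.25, 0.25), 3)
```

In this sketch the low endpoint cannot drop below `min_std`, so an extent of 0.25 only pulls the high end of the ramp down; whether the library behaves identically is not confirmed here.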
- __init__(data_dim, num_channels, grid_size, min_attenuation_at_step=0.1, max_attenuation_at_limit=0.95, init_extent=1.0, parametrization='direct')
Initialise the Gaussian modulation module.
- Parameters:
  - data_dim (int) – Number of spatial/temporal dimensions.
  - num_channels (int) – Number of feature channels to modulate.
  - grid_size (int) – Number of grid points per spatial dimension.
  - min_attenuation_at_step (float) – 1D mask value at the first grid step from the origin for the narrowest channel (sets min_std).
  - max_attenuation_at_limit (float) – 1D mask value at the grid boundary for the widest channel (sets max_std).
  - init_extent (float | Sequence[float]) – Per-axis bandwidth scale for initialisation (> 0). Pass a float to broadcast, or a sequence of length data_dim.
  - parametrization (str) – Mapping from std_param to std values. One of 'direct', 'log', 'softplus'.
- extra_repr()#
Additional printing for the GaussianModulationND class.
- forward(grid, x)#
Apply Gaussian spatial modulation element-wise to kernel features.
For each spatial position \(p\) and channel \(c\) computes:
\[\text{out}[\ldots, c] = x[\ldots, c] \cdot \exp\!\Bigl(-\tfrac{1}{2} \sum_{d} \bigl(p_d / \sigma_{d,c}\bigr)^2\Bigr)\]
The exponent is computed as a single einsum (sum over axes) for efficiency, avoiding the intermediate prod(exp) formulation.
Mask convention: 1 = fully included (origin), 0 = fully excluded (large displacement).
- Parameters:
- Returns:
Modulated features with the same shape and dtype as x. The internal Gaussian computation is always done in float32 and then cast to x.dtype.
- Return type:
- Raises:
AssertionError – If grid.dtype is not torch.float32.
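The equivalence between the sum-in-the-exponent form used by forward and the product-of-exponentials form can be checked with a small plain-Python sketch. The data layout here (a list of positions, per-position feature lists, and `sigma[d][c]`) and the function names are assumptions for illustration; the real method operates on tensors and computes in float32.

```python
import math

def modulate_sum(positions, x, sigma):
    """out[i][c] = x[i][c] * exp(-0.5 * sum_d (p_d / sigma[d][c])**2)."""
    out = []
    for p, feats in zip(positions, x):
        row = []
        for c, value in enumerate(feats):
            exponent = sum((p[d] / sigma[d][c]) ** 2 for d in range(len(p)))
            row.append(value * math.exp(-0.5 * exponent))  # one exp per element
        out.append(row)
    return out

def modulate_prod(positions, x, sigma):
    """Same result via a product of per-axis Gaussians (one exp per axis)."""
    out = []
    for p, feats in zip(positions, x):
        row = []
        for c, value in enumerate(feats):
            mask = math.prod(
                math.exp(-0.5 * (p[d] / sigma[d][c]) ** 2) for d in range(len(p))
            )
            row.append(value * mask)
        out.append(row)
    return out

positions = [(0.0, 0.0), (1.0, -1.0)]  # two 2D grid points
x = [[2.0, 3.0], [2.0, 3.0]]           # two channels per point
sigma = [[0.5, 1.0], [0.5, 2.0]]       # sigma[d][c]

out_sum = modulate_sum(positions, x, sigma)
out_prod = modulate_prod(positions, x, sigma)
```

Features at the origin pass through unchanged, and both forms agree up to rounding; the summed form needs only one exponential per output element.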