SIRENPositionalEmbeddingND
- class SIRENPositionalEmbeddingND(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)
Bases: Module

N-dimensional positional embedding using a SIREN first layer.
Mathematical form
Given a coordinate grid x of shape [1, *spatial_dims, data_dim] with values normalised to [-1, 1] per axis, the embedding is:

phi(x) = sin( W x + b )    shape [..., embedding_dim]

where:
- W is a learned weight matrix of shape [embedding_dim, data_dim], initialised from U(-2*pi*omega_0/d, +2*pi*omega_0/d) (first-layer SIREN bound, see _init_siren_weights). Unlike the RFF counterpart, this weight is trainable.
- b is an optional bias vector, zero-initialised.
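As a rough illustration of the formula above, the following plain-Python sketch draws W from the stated uniform bound and evaluates sin(W x + b) for a single coordinate. It assumes that d in the bound is data_dim, which the text does not state explicitly. The helper names init_siren_first_layer and siren_embed are hypothetical and are not part of the library.

```python
import math
import random


def init_siren_first_layer(data_dim, embedding_dim, omega_0, seed=0):
    """Draw W ~ U(-bound, +bound) with bound = 2*pi*omega_0/d, plus a zero bias.

    Assumption: d is taken to be data_dim.
    """
    rng = random.Random(seed)
    bound = 2.0 * math.pi * omega_0 / data_dim
    weight = [[rng.uniform(-bound, bound) for _ in range(data_dim)]
              for _ in range(embedding_dim)]
    bias = [0.0] * embedding_dim
    return weight, bias, bound


def siren_embed(x, weight, bias):
    """Evaluate phi(x) = sin(W x + b) for one coordinate x of length data_dim."""
    return [math.sin(sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i)
            for row, b_i in zip(weight, bias)]


weight, bias, bound = init_siren_first_layer(data_dim=2, embedding_dim=4, omega_0=10.0)
phi = siren_embed([0.5, -0.25], weight, bias)    # one value per embedding channel
origin = siren_embed([0.0, 0.0], weight, bias)   # zero bias gives sin(0) at the origin
```

With a zero bias, every channel is zero at the origin at initialisation. A larger omega_0 widens the range of W, so the sine oscillates faster across the [-1, 1] grid.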
The omega_0 parameter controls the frequency content at init: higher values bias the embedding toward higher spatial frequencies, giving the downstream MLP a head-start in representing rapid kernel variations. During training the weight can drift away from the init distribution.

Grid caching
Identical to RandomFourierPositionalEmbeddingND: a coordinate tensor of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim] is pre-computed in float32 and cached as a non-persistent buffer. The forward pass slices the central 2*seq_len_i - 1 entries per axis, calling _maybe_extend_grid_cache when any requested axis is larger than the current cache.

Note: the linear projection is forced to float32 internally (even under autocast) to avoid quantisation errors in the SIREN’s high-frequency sine. The output is cast back to the weight’s dtype before return.
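The central slicing described above can be sketched for a single axis in plain Python. This is only an illustration under stated assumptions: the helpers axis_coords and central_slice are hypothetical, the sketch raises an error where the module would extend its cache, and how the per-axis coordinates are combined into the N-dimensional grid is not shown.

```python
def axis_coords(L):
    """Coordinates of one cached axis: 2*L - 1 points spanning [-1, 1] (needs L >= 2)."""
    step = 1.0 / (L - 1)  # per-axis step, matching the documented 1/(L_i - 1)
    return [(i - (L - 1)) * step for i in range(2 * L - 1)]


def central_slice(coords, seq_len):
    """Return the central 2*seq_len - 1 entries of a cached axis."""
    L = (len(coords) + 1) // 2
    if seq_len > L:
        # The module is documented to extend its cache in this case;
        # this sketch simply refuses.
        raise ValueError("seq_len exceeds the cached extent")
    start = L - seq_len
    return coords[start:start + 2 * seq_len - 1]


cache = axis_coords(L=5)          # 9 points: -1.0, -0.75, ..., 0.75, 1.0
window = central_slice(cache, 3)  # 5 points centred on 0.0
```

Because the step is fixed by the cache extent, a request shorter than the cache covers a proportionally narrower interval around zero in this sketch.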
- linear
Trainable SIREN first-layer projection, shape [embedding_dim, data_dim].
- Type:
- grid_cache
Non-persistent float32 buffer of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim].
- Type:
- step_sizes
Per-axis grid step 1/(L_i - 1) at construction; kept frozen for consistent cache extension.
- __init__(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)
Initialize the SIRENPositionalEmbeddingND class.
- Parameters:
data_dim (int) – Number of coordinate dimensions of the input data (one per spatial axis).
embedding_dim (int) – Dimensionality of the positional embedding.
L_cache (int | Sequence[int]) – Per-axis cache extents. Either a scalar int (broadcast to every axis, isotropic grid) or a sequence of length data_dim (one extent per spatial axis, anisotropic grid). The cached grid then has shape (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim) and each axis spans [-1, 1] at its own resolution.
omega_0 (float) – Frequency scaling factor for the SIREN first-layer weight initialisation.
use_bias (bool) – Whether to use a bias term in the linear layer.
- forward(seq_lens)
Compute the SIREN positional embeddings for a given spatial grid.
- Parameters:
seq_lens (tuple[int, ...]) – Per-axis output sequence lengths. Length must equal self.data_dim. For example, for a 2D signal of height H and width W, pass (H, W).
- Returns:
torch.Tensor: The positional embeddings sin(W x + b), where the linear projection is computed in float32 and the result is cast back to the weight dtype. Shape [1, *spatial_dims, embedding_dim].
torch.Tensor: The coordinate grid of positions normalised to [-1, 1] per axis. Shape [1, *spatial_dims, data_dim].
- Return type:
- Raises:
AssertionError – If len(seq_lens) != self.data_dim.
AssertionError – If self.grid_cache is not float32.