SIRENPositionalEmbeddingND

class SIRENPositionalEmbeddingND(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)

Bases: torch.nn.Module

N-dimensional positional embedding using a SIREN first layer.

Mathematical form

Given a coordinate grid x of shape [1, *spatial_dims, data_dim] with values normalised to [-1, 1] per axis, the embedding is:

$$\phi(x) = \sin(Wx + b)$$

with the projection applied along the trailing coordinate axis, so the output has shape [..., embedding_dim].

where:

  • W is a learned weight matrix of shape [embedding_dim, data_dim], initialised from U(-2*pi*omega_0/d, +2*pi*omega_0/d) with d = data_dim, the fan-in of W (first-layer SIREN bound; see _init_siren_weights). Unlike the random Fourier feature (RFF) counterpart, RandomFourierPositionalEmbeddingND, this weight is trainable.

  • b is an optional bias vector of shape [embedding_dim], present only when use_bias=True and zero-initialised.

The omega_0 parameter controls the frequency content at initialisation. It enters only through the bound on W (the forward pass is plain sin(Wx + b)), so higher values bias the embedding toward higher spatial frequencies and give the downstream MLP a head start in representing rapid kernel variations. Because W is trainable, its distribution can drift away from the initial one during training.
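The formula and initialisation above can be illustrated with a minimal NumPy sketch. It is not the module's PyTorch code: the helper names siren_first_layer_init and siren_embed are hypothetical, d is assumed to equal data_dim, and plain arrays stand in for the trainable torch.nn.Linear.

```python
import numpy as np


def siren_first_layer_init(embedding_dim, data_dim, omega_0, rng):
    """Hypothetical helper: W ~ U(-2*pi*omega_0/d, +2*pi*omega_0/d) with d = data_dim (assumed), b = 0."""
    bound = 2.0 * np.pi * omega_0 / data_dim
    W = rng.uniform(-bound, bound, size=(embedding_dim, data_dim))
    b = np.zeros(embedding_dim)  # zero-initialised bias (the use_bias=True case)
    return W, b


def siren_embed(x, W, b):
    """Hypothetical helper: phi(x) = sin(W x + b), applied over the trailing coordinate axis of x."""
    return np.sin(x @ W.T + b)


rng = np.random.default_rng(0)
W, b = siren_first_layer_init(embedding_dim=16, data_dim=2, omega_0=1.0, rng=rng)
x = rng.uniform(-1.0, 1.0, size=(1, 5, 7, 2))  # [1, *spatial_dims, data_dim]
phi = siren_embed(x, W, b)
print(phi.shape)  # (1, 5, 7, 16)
```

Raising omega_0 widens the uniform bound, so the sine oscillates faster across the [-1, 1] coordinate range at initialisation.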

Grid caching

This mirrors RandomFourierPositionalEmbeddingND: a coordinate tensor of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim] is precomputed in float32 and cached as a non-persistent buffer, so it is not saved in the state_dict. If any requested seq_len_i exceeds the current cache extent on its axis, the forward pass first calls _maybe_extend_grid_cache; it then slices the central 2*seq_len_i - 1 entries along each axis.
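The cache layout and central slicing described above can be sketched in NumPy as follows. build_grid_cache and slice_central are hypothetical helpers, and the sketch raises an error where the real module would instead call _maybe_extend_grid_cache.

```python
import numpy as np


def build_grid_cache(L_per_axis):
    """Hypothetical helper: grid of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, d]; axis i spans [-1, 1]."""
    axes = [np.linspace(-1.0, 1.0, 2 * L - 1, dtype=np.float32) for L in L_per_axis]
    mesh = np.meshgrid(*axes, indexing="ij")  # one coordinate array per axis
    return np.stack(mesh, axis=-1)[None, ...]  # add the leading dimension of size 1


def slice_central(grid, L_per_axis, seq_lens):
    """Hypothetical helper: keep the central 2*s_i - 1 entries per axis (centre index is L_i - 1)."""
    index = [slice(None)]  # leading dimension
    for L, s in zip(L_per_axis, seq_lens):
        if s > L:
            # Simplification: the real module extends the cache here instead of failing.
            raise ValueError(f"seq_len {s} exceeds cache extent {L}")
        index.append(slice(L - s, L + s - 1))  # indices (L-1)-(s-1) .. (L-1)+(s-1)
    index.append(slice(None))  # coordinate dimension
    return grid[tuple(index)]


cache = build_grid_cache((4, 6))
grid = slice_central(cache, (4, 6), (2, 3))
print(cache.shape)  # (1, 7, 11, 2)
print(grid.shape)   # (1, 3, 5, 2)
```

Because the cache is centred on zero, slicing symmetrically around index L_i - 1 always yields a grid whose middle entry is the origin.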

Note: the linear projection is always computed in float32 (even under autocast), because rounding the argument of the high-frequency sine to a 16-bit type would introduce large phase errors. The output is cast back to the weight's dtype before being returned.
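The effect motivating this note can be reproduced with a small NumPy experiment. This is only an analogy for autocast, not the module's code: sine_projection is a hypothetical helper, float16 stands in for the low-precision weight dtype, and the weight value (2*pi*30, the init bound for omega_0=30 and d=1) is an illustrative choice.

```python
import numpy as np


def sine_projection(x, w, compute_dtype, out_dtype):
    """Hypothetical helper: sin(w * x) with the product in compute_dtype, result cast to out_dtype."""
    arg = x.astype(compute_dtype) * np.asarray(w, dtype=compute_dtype)
    return np.sin(arg).astype(out_dtype)


x = np.linspace(-1.0, 1.0, 101)
w = 2.0 * np.pi * 30.0     # illustrative large first-layer weight
reference = np.sin(w * x)  # float64 reference

# Low-precision output in both cases; only the precision of the projection differs.
err_fp16 = np.max(np.abs(sine_projection(x, w, np.float16, np.float16) - reference))
err_fp32 = np.max(np.abs(sine_projection(x, w, np.float32, np.float16) - reference))
print(f"max error, float16 projection: {err_fp16:.4f}")
print(f"max error, float32 projection: {err_fp32:.4f}")
```

With arguments near 188, float16 can only represent steps of 0.125, so the phase error dominates; computing the projection in float32 leaves only the much smaller error of rounding the final sine value.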

data_dim

Number of spatial / temporal input dimensions.

Type:

int

embedding_dim

Output embedding size.

Type:

int

L_cache_per_axis

Current per-axis cache extents.

Type:

tuple[int, ...]

L_cache

The original L_cache argument, kept for diagnostics.

Type:

int | Sequence[int]

omega_0

Frequency scaling factor used in the SIREN weight initialisation.

Type:

float

use_bias

Whether a bias is present in the linear projection.

Type:

bool

linear

Trainable SIREN first-layer projection; its weight has shape [embedding_dim, data_dim].

Type:

torch.nn.Linear

grid_cache

Non-persistent float32 buffer of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim].

Type:

torch.Tensor

step_sizes

Per-axis grid step 1/(L_i - 1) at construction; kept frozen for consistent cache extension.

Type:

tuple[float, ...]
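The step_sizes entry suggests that extending the cache reuses the construction-time spacing instead of re-normalising the grid. A minimal NumPy sketch of that idea, assuming the extended grid simply grows outward from the centre (extended_axis is a hypothetical helper, not part of the module):

```python
import numpy as np


def extended_axis(L_new, step):
    """Hypothetical helper: centred 1-D grid of 2*L_new - 1 points with a fixed spacing `step`."""
    k = np.arange(-(L_new - 1), L_new, dtype=np.float32)
    return k * np.float32(step)


L0 = 5
step = 1.0 / (L0 - 1)               # frozen at construction: 0.25
original = extended_axis(L0, step)  # 9 points spanning exactly [-1, 1]
extended = extended_axis(8, step)   # 15 points with unchanged spacing (assumed extension rule)

print(original.min(), original.max())  # -1.0 1.0
print(extended.min(), extended.max())  # -1.75 1.75
```

Under this assumption, positions that were already cached keep their coordinates after extension, which is what makes the embeddings of shorter inputs consistent before and after the cache grows.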

__init__(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)

Initialize a SIRENPositionalEmbeddingND module.

Parameters:
  • data_dim (int) – Number of spatial/temporal input dimensions (the number of coordinates per position).

  • embedding_dim (int) – Dimensionality of the positional embedding.

  • L_cache (int | Sequence[int]) – Per-axis cache extents. Either a scalar int (broadcast to every axis, isotropic grid) or a sequence of length data_dim (one extent per spatial axis, anisotropic grid). The cached grid then has shape (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim) and each axis spans [-1, 1] at its own resolution.

  • omega_0 (float) – Frequency scaling factor for the SIREN weight initialisation; larger values give higher-frequency embeddings at init.

  • use_bias (bool) – Whether to use a bias term in the linear layer.
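The scalar-versus-sequence handling of L_cache and the resulting cache shape can be sketched as below. normalise_L_cache and cache_shape are hypothetical helpers, and raising ValueError on a length mismatch is an assumption about how the module validates its input.

```python
from collections.abc import Sequence


def normalise_L_cache(L_cache, data_dim):
    """Hypothetical helper: broadcast a scalar L_cache to every axis, or validate a per-axis sequence."""
    if isinstance(L_cache, int):
        return (L_cache,) * data_dim  # isotropic grid
    if isinstance(L_cache, Sequence) and len(L_cache) == data_dim:
        return tuple(int(L) for L in L_cache)  # anisotropic grid
    # Assumed validation behaviour; the real module may report this differently.
    raise ValueError(f"L_cache must be an int or a sequence of length {data_dim}")


def cache_shape(L_cache, data_dim):
    """Shape of the cached grid: (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim)."""
    per_axis = normalise_L_cache(L_cache, data_dim)
    return (1, *(2 * L - 1 for L in per_axis), data_dim)


print(cache_shape(8, 2))           # (1, 15, 15, 2)
print(cache_shape((4, 16, 9), 3))  # (1, 7, 31, 17, 3)
```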

forward(seq_lens)

Compute the SIREN positional embeddings for a given spatial grid.

Parameters:

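seq_lens (tuple[int, ...]) – Per-axis output sequence lengths. Its length must equal self.data_dim. For example, for a 2D signal of height H and width W, pass (H, W). Along axis i, the returned tensors contain 2*seq_lens[i] - 1 positions (the central slice of the cached grid), so spatial_dims = (2*seq_lens[0] - 1, ..., 2*seq_lens[d-1] - 1).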

Returns:

  • torch.Tensor: The positional embeddings sin(W x + b), where the linear projection is computed in float32 and the result is cast back to the weight dtype. Shape [1, *spatial_dims, embedding_dim].

  • torch.Tensor: The coordinate grid of positions normalised to [-1, 1] per axis. Shape [1, *spatial_dims, data_dim].

Return type:

tuple[torch.Tensor, torch.Tensor]

Raises:
Parameters: