LearnableOmegaSIRENPositionalEmbeddingND

class LearnableOmegaSIRENPositionalEmbeddingND(
data_dim,
embedding_dim,
L_cache,
omega_0,
omega_0_scale_init=1.0,
omega_0_scale_min=1e-2,
omega_0_scale_max=2.0,
use_bias=True,
apply_lr_scale=False,
)

Bases: SIRENPositionalEmbeddingND

SIREN positional embedding with a learnable per-row ω₀ multiplier.

Forward computes (in float32, regardless of the module's parameter dtype):

sin( 2π · ω₀ · ω₀_scale · (W·x + b) )

where:

  • W is the first-layer weight, initialized to U(-1/d, +1/d) with d = data_dim, the layer's fan-in. This is the standard SIREN-1 init without the usual 2π·ω₀ bound scaling; 2π·ω₀ is instead applied at every forward pass as a single scalar buffer.

  • ω₀_scale is a learnable per-row parameter of shape [embedding_dim], initialized to omega_0_scale_init and clamped in place to [omega_0_scale_min, omega_0_scale_max] by a forward pre-hook ("direct" parametrization). The default lower bound is a small positive floor (1e-2) rather than 0, so a row's effective ω₀ never collapses to zero. At zero the row's sine reduces to the constant sin(0) = 0, and the gradients reaching that row's W and b, which are multiplied by ω₀_scale, vanish as well, making recovery hard. With the defaults (init=1, min=1e-2, max=2) the per-row effective ω₀ ranges from 0.01·ω₀ to 2·ω₀, and the total multiplier inside the sine reaches 4π·ω₀.
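The per-row computation above can be sketched in pure Python for a single coordinate vector x and one row of W. This is a minimal illustration, not the module's code: the function names are hypothetical, the real forward works on whole tensors in float32, and here the clamp is applied to a local copy rather than in place by a pre-hook.

```python
import math


def learnable_omega_siren_row(x, w_row, b, omega_0, scale,
                              scale_min=1e-2, scale_max=2.0):
    """Hypothetical helper: one output channel sin(2*pi*omega_0*s*(w.x + b))."""
    # Clamp the learnable scale into [scale_min, scale_max]; the real module
    # does this in place with a forward pre-hook.
    s = min(max(scale, scale_min), scale_max)
    # Linear projection W.x + b for this row (the bias sits inside the scaling).
    z = sum(wi * xi for wi, xi in zip(w_row, x)) + b
    # One runtime constant 2*pi*omega_0 times the per-row scale.
    return math.sin(2.0 * math.pi * omega_0 * s * z)


def effective_omega_range(omega_0, scale_min=1e-2, scale_max=2.0):
    """Hypothetical helper: (min effective omega_0, max effective omega_0,
    peak total multiplier inside the sine)."""
    return (scale_min * omega_0,
            scale_max * omega_0,
            2.0 * math.pi * omega_0 * scale_max)
```

For example, effective_omega_range(30.0) gives a per-row range of 0.3 to 60.0 and a peak multiplier of 4π·30 ≈ 377, matching the defaults described above.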

The float32 path covers the linear projection, the multiplier, and the sine; after the sine, the result is cast back to the dtype of the linear weight. This matches the precision discipline already used for the grid cache in the parent SIRENPositionalEmbeddingND.

Parameters:
  • data_dim (int) – Number of spatial/temporal input dimensions.

  • embedding_dim (int) – Dimensionality of the positional embedding.

  • L_cache (int | Sequence[int]) – Cache extent (controls the initial grid cache size).

  • omega_0 (float) – Base frequency ω₀. It is not folded into the weight init; it enters only through the runtime 2π·ω₀ factor.

  • omega_0_scale_init (float | Sequence[float] | Tensor) – Initial value of the learnable per-row scale. Either a single float (broadcast to embedding_dim) or a 1-D sequence/tensor of length embedding_dim. Defaults to 1.0, so the effective per-row ω₀ at init equals omega_0.

  • omega_0_scale_min (float) – Lower clamp on ω₀_scale. Must be strictly positive: at 0 the row's first-layer sine collapses to a constant and the gradients reaching that row's W and b vanish, making recovery hard. Default 1e-2.

  • omega_0_scale_max (float) – Upper clamp on ω₀_scale. Default 2.0, giving a total multiplier inside the sine of up to 4π·ω₀.

  • use_bias (bool) – Whether to include a bias term.

  • apply_lr_scale (bool) – When True, attach _lr_scale = 1/(2*pi*omega_0) to self.linear.weight. The optimizer utility _build_param_groups (in experiments/) reads this attribute and multiplies the layer’s effective learning rate by _lr_scale, compensating for the missing 2*pi*omega_0 factor in the SIREN-1 init bound so that the per-step update size matches a standard SIREN. Default False (opt-in, so existing runs are unaffected and the new classes can be A/B-tested).
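The apply_lr_scale mechanism can be sketched as follows. The real _build_param_groups lives in experiments/ and its signature is not documented here, so build_param_groups_sketch and attach_lr_scale are hypothetical stand-ins. The sketch only assumes what the text states: the weight carries a _lr_scale attribute, and the builder multiplies the base learning rate by it. The output is a list of dicts with "params" and "lr" keys, the per-parameter-group format torch.optim optimizers accept.

```python
import math
from types import SimpleNamespace


def attach_lr_scale(weight, omega_0):
    """Hypothetical helper: set _lr_scale = 1/(2*pi*omega_0) on a parameter,
    as apply_lr_scale=True does for self.linear.weight."""
    weight._lr_scale = 1.0 / (2.0 * math.pi * omega_0)
    return weight


def build_param_groups_sketch(named_params, base_lr):
    """Hypothetical stand-in for _build_param_groups.

    Gives each parameter its own group with lr = base_lr * _lr_scale,
    falling back to a factor of 1.0 when the attribute is absent.
    """
    groups = []
    for _name, p in named_params:
        groups.append({"params": [p],
                       "lr": base_lr * getattr(p, "_lr_scale", 1.0)})
    return groups


# Usage with stand-in objects; real torch Parameters accept attributes too.
weight = attach_lr_scale(SimpleNamespace(), omega_0=30.0)
bias = SimpleNamespace()
groups = build_param_groups_sketch(
    [("linear.weight", weight), ("linear.bias", bias)], base_lr=1e-3)
```

With omega_0=30 the weight's group gets lr = 1e-3/(60π) ≈ 5.3e-6, while the bias keeps the base 1e-3.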

omega_0

Constant part of the runtime multiplier (same as the omega_0 constructor argument), stored for diagnostics.

Type:

float

omega_0_scale_min

Lower clamp bound on omega_0_scale.

Type:

float

omega_0_scale_max

Upper clamp bound on omega_0_scale.

Type:

float

omega_0_const

Non-persistent float32 scalar buffer holding 2*pi*omega_0; applied to the linear output at every forward pass.

Type:

torch.Tensor

omega_0_scale

Learnable per-row scale of shape [embedding_dim]. Clamped to [omega_0_scale_min, omega_0_scale_max] by a forward pre-hook before each forward call.

Type:

torch.nn.Parameter
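Building the initial scale vector and the in-place clamp can be sketched with plain lists. init_omega_0_scale and clamp_scale_ are hypothetical helpers, and the torch form mentioned in the comment is an assumption about how the pre-hook might clamp, not the module's confirmed code.

```python
def init_omega_0_scale(init, embedding_dim):
    """Hypothetical helper: build the initial per-row scale.

    A single number is broadcast to embedding_dim entries; a sequence
    must have exactly embedding_dim entries.
    """
    if isinstance(init, (int, float)):
        return [float(init)] * embedding_dim
    values = [float(v) for v in init]
    if len(values) != embedding_dim:
        raise ValueError(f"expected {embedding_dim} values, got {len(values)}")
    return values


def clamp_scale_(scale, lo, hi):
    """Hypothetical helper: clamp the list in place, mimicking the pre-hook.

    With torch, an in-place clamp outside autograd would look like:
        with torch.no_grad():
            scale.clamp_(lo, hi)
    """
    for i, v in enumerate(scale):
        scale[i] = min(max(v, lo), hi)
    return scale
```

Assuming the hook clamps outside autograd like this, gradients are taken at the already-clamped value, and any optimizer step that pushes a row past a bound is undone at the next forward call.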

linear

First-layer linear map. Its weight W uses the unscaled SIREN-1 init U(-1/d, +1/d) with d = data_dim (no 2*pi*omega_0 factor in the bound) and has shape [embedding_dim, data_dim].

Type:

torch.nn.Linear
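The unscaled SIREN-1 init can be sketched with the standard library. siren_first_layer_init is a hypothetical helper that returns nested lists in the [embedding_dim, data_dim] layout of nn.Linear's weight; in torch the same draw could be written as torch.nn.init.uniform_(linear.weight, -1/d, 1/d).

```python
import random


def siren_first_layer_init(embedding_dim, data_dim, rng=None):
    """Hypothetical helper: sample W ~ U(-1/d, +1/d) with d = data_dim.

    No 2*pi*omega_0 factor is applied to the bound; that factor is
    supplied at runtime by the omega_0_const buffer instead.
    """
    rng = rng or random.Random()
    bound = 1.0 / data_dim  # fan-in of the first layer
    return [[rng.uniform(-bound, bound) for _ in range(data_dim)]
            for _ in range(embedding_dim)]
```

Keeping the bound free of 2π·ω₀ means the weights stay small in magnitude; the frequency scale lives entirely in the runtime multiplier and the learnable ω₀_scale.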

grid_cache, step_sizes, L_cache_per_axis, L_cache

Inherited from SIRENPositionalEmbeddingND; see that class.

__init__(
data_dim,
embedding_dim,
L_cache,
omega_0,
omega_0_scale_init=1.0,
omega_0_scale_min=1e-2,
omega_0_scale_max=2.0,
use_bias=True,
apply_lr_scale=False,
)

Initialize the learnable-ω₀ SIREN positional embedding; see the class docstring.

forward(seq_lens)

Compute the positional embedding with a fp32-internal learnable-ω₀ first layer.

Parameters:

seq_lens (tuple[int, ...]) – Lengths of the input grid for which to compute the positional embeddings.

Returns:

  • torch.Tensor: The positional embedding sin(2π·ω₀·ω₀_scale·(W·x+b)), cast back to the dtype of the linear weight, of shape [1, *spatial_dims, embedding_dim].

  • torch.Tensor: The grid coordinates, shape [1, *spatial_dims, data_dim] (fp32, as in the parent).

Return type:

tuple

extra_repr()

Diagnostic string showing ω₀, the scale range, and the current scale stats.

Return type:

str