LearnableOmegaSIRENPositionalEmbeddingND
- class LearnableOmegaSIRENPositionalEmbeddingND(data_dim, embedding_dim, L_cache, omega_0, omega_0_scale_init=1.0, omega_0_scale_min=1e-2, omega_0_scale_max=2.0, use_bias=True, apply_lr_scale=False)
Bases: SIRENPositionalEmbeddingND
SIREN positional embedding with a learnable per-row ω₀ multiplier.
Forward computes (in float32, regardless of input dtype):
sin( 2π · ω₀ · ω₀_scale · (W·x + b) )
where:
- W is the first-layer weight, initialized to U(-1/d, +1/d) (the standard SIREN-1 init without the usual 2π·ω₀ bound scaling). 2π·ω₀ is instead applied at every iteration as a single scalar buffer.
- ω₀_scale is a learnable per-row parameter of shape [embedding_dim], initialized to omega_0_scale_init and clamped in-place (forward pre-hook, "direct" parametrization) to [omega_0_scale_min, omega_0_scale_max]. The default lower bound is a small positive floor (1e-2) rather than 0 so that a row’s effective ω₀ never collapses to zero: at zero the row’s sine becomes a constant sin(bias) and the gradient signal through that row’s ω₀_scale largely vanishes, making recovery hard. With the defaults (init=1, max=2) the per-row effective ω₀ ranges from roughly 0.01·ω₀ to 2·ω₀, and the total multiplier inside the sine reaches 4π·ω₀.
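The formula above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the class's actual code: the function name `siren_first_layer_sketch`, the value `omega_0=30.0`, the zero bias, and the choice to cast back to the dtype of `x` are all assumptions made for the example.

```python
import math
import torch

def siren_first_layer_sketch(x, weight, bias, omega_0, omega_0_scale):
    """Hypothetical helper: sin(2π·ω₀·ω₀_scale·(W·x + b)) in float32, cast back to x.dtype."""
    out_dtype = x.dtype
    # Linear projection W·x + b in float32; output shape is [..., embedding_dim].
    b = None if bias is None else bias.float()
    z = torch.nn.functional.linear(x.float(), weight.float(), b)
    # Constant scalar 2π·ω₀ times the per-row scale (broadcast over the last axis).
    multiplier = (2.0 * math.pi * omega_0) * omega_0_scale.float()
    # Sine is still evaluated in float32; only the final result is cast back.
    return torch.sin(multiplier * z).to(out_dtype)

torch.manual_seed(0)
data_dim, embedding_dim = 2, 8
# U(-1/d, +1/d) init for W, taking d = data_dim (an assumption for this sketch).
weight = torch.empty(embedding_dim, data_dim).uniform_(-1.0 / data_dim, 1.0 / data_dim)
bias = torch.zeros(embedding_dim)
omega_0_scale = torch.ones(embedding_dim)  # corresponds to omega_0_scale_init = 1.0
x = torch.rand(1, 16, 16, data_dim).to(torch.bfloat16)  # low-precision coordinates
emb = siren_first_layer_sketch(x, weight, bias, omega_0=30.0, omega_0_scale=omega_0_scale)
```

With `omega_0_scale` at its initial value of 1, every row uses the same multiplier 2π·ω₀; training then lets each row move within the clamp range independently.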
The float32 path covers the linear projection, the multiplier, and the sine; the result is cast back to the original input dtype after the sine. This matches the precision discipline already used for the grid cache in the parent SIRENPositionalEmbeddingND.
- Parameters:
data_dim (int) – Number of spatial/temporal input dimensions.
embedding_dim (int) – Dimensionality of the positional embedding.
L_cache (int | Sequence[int]) – Cache extent (controls the initial grid cache size).
omega_0 (float) – Constant scalar absorbed into the runtime 2π·ω₀ factor.
omega_0_scale_init (float | Sequence[float] | Tensor) – Initial value of the learnable per-row scale. Either a single float (broadcast to embedding_dim) or a 1-D sequence/tensor of length embedding_dim. Defaults to 1.0, so the effective per-row ω₀ at init equals omega_0.
omega_0_scale_min (float) – Lower clamp on ω₀_scale. Must be strictly positive: at 0 the row’s first-layer sine collapses to a constant and the gradient signal through its scale largely vanishes, making recovery hard. Default 1e-2.
omega_0_scale_max (float) – Upper clamp on ω₀_scale. Default 2.0, giving a total multiplier inside the sine of up to 4π·ω₀.
use_bias (bool) – Whether to include a bias term.
apply_lr_scale (bool) – When True, attach _lr_scale = 1/(2*pi*omega_0) to self.linear.weight. The optimizer utility _build_param_groups (in experiments/) reads this attribute and multiplies the layer’s effective learning rate by _lr_scale, compensating for the missing 2*pi*omega_0 factor in the SIREN-1 init bound so that the per-step update size matches a standard SIREN. Default False (opt-in, so existing runs are unaffected and the new classes can be A/B-tested).
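To illustrate how an `_lr_scale` attribute could feed into per-parameter learning rates, here is a rough sketch. The real `_build_param_groups` in `experiments/` is not shown in this page, so `build_param_groups_sketch` is a hypothetical stand-in, and `omega_0 = 30.0` and `base_lr = 1e-3` are example values only.

```python
import math
import torch

def build_param_groups_sketch(module, base_lr):
    """Hypothetical stand-in for `_build_param_groups`: one group per distinct `_lr_scale`."""
    groups = {}
    for param in module.parameters():
        if not param.requires_grad:
            continue
        # Parameters without the attribute keep the base learning rate.
        scale = getattr(param, "_lr_scale", 1.0)
        groups.setdefault(scale, []).append(param)
    return [{"params": params, "lr": base_lr * scale} for scale, params in groups.items()]

omega_0 = 30.0  # example value, not taken from the documentation
linear = torch.nn.Linear(2, 8)
# What apply_lr_scale=True is described as attaching to the first-layer weight.
linear.weight._lr_scale = 1.0 / (2.0 * math.pi * omega_0)

param_groups = build_param_groups_sketch(linear, base_lr=1e-3)
optimizer = torch.optim.Adam(param_groups)  # the weight group runs at base_lr / (2π·ω₀)
```

In this sketch the bias stays at the base learning rate because only the weight carries `_lr_scale`, which matches the description of the attribute being attached to `self.linear.weight`.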
- omega_0
Constant part of the runtime multiplier (same as the omega_0 constructor argument), stored for diagnostics.
- Type:
- omega_0_const
Non-persistent float32 scalar buffer holding 2*pi*omega_0; applied to the linear output at every forward pass.
- Type:
- omega_0_scale
Learnable per-row scale of shape [embedding_dim]. Clamped to [omega_0_scale_min, omega_0_scale_max] by a forward pre-hook before each forward call.
- Type: torch.nn.Parameter
- linear
First-layer weight W with unscaled SIREN-1 init U(-1/d, +1/d) (no 2*pi*omega_0 factor in the bound). Shape [embedding_dim, data_dim].
- Type:
- grid_cache, step_sizes, L_cache_per_axis, L_cache
Inherited from SIRENPositionalEmbeddingND; see that class.
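The clamping behaviour described for omega_0_scale can be reproduced with PyTorch's `register_forward_pre_hook`. The module below is a hypothetical minimal sketch, not the class's actual implementation: `ClampedScaleSketch`, its argument names, and its trivial `forward` exist only to show the hook clamping the parameter in place before each call.

```python
import torch

class ClampedScaleSketch(torch.nn.Module):
    """Hypothetical module: holds a per-row scale and clamps it before every forward."""

    def __init__(self, embedding_dim, scale_init=1.0, scale_min=1e-2, scale_max=2.0):
        super().__init__()
        self.scale_min, self.scale_max = scale_min, scale_max
        self.omega_0_scale = torch.nn.Parameter(
            torch.full((embedding_dim,), float(scale_init))
        )
        # The hook runs inside __call__, before forward() is executed.
        self.register_forward_pre_hook(self._clamp_scale)

    @staticmethod
    def _clamp_scale(module, args):
        # In-place clamp on the raw parameter, kept outside autograd.
        with torch.no_grad():
            module.omega_0_scale.clamp_(module.scale_min, module.scale_max)

    def forward(self, z):
        return z * self.omega_0_scale

layer = ClampedScaleSketch(embedding_dim=4)
with torch.no_grad():
    # Simulate an optimizer step that pushed some rows outside the bounds.
    layer.omega_0_scale.copy_(torch.tensor([-1.0, 0.5, 1.5, 5.0]))
out = layer(torch.ones(4))
# After the call, omega_0_scale holds 0.01, 0.5, 1.5 and 2.0.
```

Because the clamp modifies the stored parameter itself rather than a derived copy, the optimizer sees the clamped values on its next step.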
- __init__(data_dim, embedding_dim, L_cache, omega_0, omega_0_scale_init=1.0, omega_0_scale_min=1e-2, omega_0_scale_max=2.0, use_bias=True, apply_lr_scale=False)
Initialize the learnable-ω₀ SIREN positional embedding; see the class docstring.
- forward(seq_lens)
Compute the positional embedding with an fp32-internal learnable-ω₀ first layer.
- Parameters:
seq_lens (tuple[int, ...]) – Lengths of the input grid for which to compute the positional embeddings.
- Returns:
torch.Tensor: The positional embedding sin(2π·ω₀·ω₀_scale·(W·x + b)) cast back to the original linear-weight dtype, of shape [1, *spatial_dims, embedding_dim].
torch.Tensor: The grid coordinates, shape [1, *spatial_dims, data_dim] (fp32, as in the parent).
- Return type: