SIRENKernelND
- class SIRENKernelND(
    out_dim,
    data_dim,
    mlp_hidden_dim,
    num_layers,
    embedding_dim,
    omega_0,
    L_cache,
    use_bias,
    hidden_omega_0=1.0,
    film_cfg=None,
    film_after_pos_embed=False,
)
Bases: Module
Convolutional kernel parametrised by a SIREN (sinusoidal representation network) MLP.
Mathematical form
The kernel at coordinate x is:
k(x) = Linear_out(SIREN_MLP(phi(x)))
where:
- phi(x) = sin(W_0 x + b_0) is the SIREN positional embedding (SIRENPositionalEmbeddingND) with first-layer frequency omega_0.
- SIREN_MLP is a stack of num_layers - 1 layers, each computing sin(W_i h + b_i), with weights initialised at frequency hidden_omega_0.
- Linear_out is a linear readout to out_dim channels, Wang-scaled by sqrt(1 / kernel_volume), where kernel_volume = prod(L_cache_per_axis), to normalise the initial kernel energy.
The full pipeline (without FiLM conditioning), with N = num_layers, is therefore:
h_0 = sin(W_0 x + b_0)                      (positional embedding)
h_i = sin(W_i h_{i-1} + b_i), i = 1..N-1    (hidden layers)
k   = W_out h_{N-1} + b_out                 (output layer)
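The pipeline above can be sketched in NumPy as follows. This is a minimal illustration, not the module's implementation: `siren_kernel_sketch` is a hypothetical helper, the SIREN frequencies are assumed to be folded into uniformly initialised weights (first-layer bound omega_0 / data_dim, hidden and output bound sqrt(6 / fan_in), which corresponds to hidden_omega_0 = 1), hidden biases follow the PyTorch nn.Linear default range, and the output bias is taken as zero. When `kernel_volume` is not given, the sketch uses the number of evaluated points, whereas the module uses prod(L_cache_per_axis).

```python
import numpy as np


def siren_kernel_sketch(coords, out_dim, embedding_dim=32, mlp_hidden_dim=64,
                        num_layers=3, omega_0=10.0, kernel_volume=None, seed=0):
    """Evaluate k(x) = Linear_out(SIREN_MLP(phi(x))) at coords of shape [..., data_dim].

    Hypothetical helper: frequencies are folded into the weights, so every
    SIREN layer is sin(W h + b); hidden_omega_0 = 1 is assumed.
    """
    rng = np.random.default_rng(seed)
    data_dim = coords.shape[-1]
    if kernel_volume is None:
        # Assumption: use the number of evaluated points as the kernel volume.
        kernel_volume = int(np.prod(coords.shape[:-1]))

    def bias(fan_in, size):
        # nn.Linear-style bias range U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        bound = 1.0 / np.sqrt(fan_in)
        return rng.uniform(-bound, bound, size)

    # h_0 = sin(W_0 x + b_0): SIREN first-layer bound omega_0 / fan_in.
    w0 = rng.uniform(-omega_0 / data_dim, omega_0 / data_dim, (data_dim, embedding_dim))
    h = np.sin(coords @ w0 + bias(data_dim, embedding_dim))

    # h_i = sin(W_i h_{i-1} + b_i) for i = 1..N-1, SIREN hidden bound sqrt(6 / fan_in).
    fan_in = embedding_dim
    for _ in range(num_layers - 1):
        bound = np.sqrt(6.0 / fan_in)
        w = rng.uniform(-bound, bound, (fan_in, mlp_hidden_dim))
        h = np.sin(h @ w + bias(fan_in, mlp_hidden_dim))
        fan_in = mlp_hidden_dim

    # k = W_out h_{N-1} + b_out, W_out Wang-scaled by sqrt(1 / kernel_volume);
    # b_out is assumed zero-initialised, so it is omitted.
    bound = np.sqrt(6.0 / mlp_hidden_dim)
    w_out = rng.uniform(-bound, bound, (mlp_hidden_dim, out_dim)) * np.sqrt(1.0 / kernel_volume)
    return h @ w_out
```

For an N-D grid, coords can be an array of shape [*spatial, data_dim] (for example built with np.meshgrid); the matrix products act on the last axis, and adding a leading axis gives the [1, *spatial, out_dim] layout described below.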
Hyperparameters controlling bandwidth / smoothness
- omega_0: Frequency of the first SIREN layer. Higher values produce higher-frequency positional features at init. Typical range: 1.0–30.0.
- hidden_omega_0: Frequency of the hidden SIREN layers. Usually set to 1.0 (the default), following the recommendation in the SIREN paper.
- mlp_hidden_dim: Width of all hidden layers; wider networks can express more complex kernel shapes.
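As a toy illustration of the omega_0 bandwidth knob, the sketch below counts sign changes of a single first-layer feature sin(omega_0 * x) sampled on [-1, 1]. It assumes a unit weight and zero bias, which a randomly initialised layer would not have; `sign_changes` is a hypothetical helper.

```python
import math


def sign_changes(omega_0, num_points=400):
    """Count sign changes of sin(omega_0 * x) sampled on [-1, 1].

    Toy 1D stand-in for one first-layer SIREN feature (unit weight, zero bias
    assumed); more sign changes means a higher-frequency feature.
    """
    xs = [-1.0 + 2.0 * i / (num_points - 1) for i in range(num_points)]
    values = [math.sin(omega_0 * x) for x in xs]
    # A strictly negative product marks a zero crossing between two samples.
    return sum(1 for a, b in zip(values, values[1:]) if a * b < 0)
```

With the default 400 samples, omega_0 = 1, 10 and 30 give 1, 7 and 19 sign changes respectively, since sin(omega_0 * x) has 2 * floor(omega_0 / pi) + 1 zeros in (-1, 1) for these values.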
FiLM conditioning
When film_cfg is provided, a KernelFiLMGenerator is instantiated and called on the conditioning tensor (shape [B, C]) to produce per-layer (gamma, beta) pairs, each of shape [B, mlp_hidden_dim]. The hidden activations are then modulated as:
h_i <- gamma_i * h_i + beta_i
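The modulation rule can be sketched with NumPy broadcasting. `toy_film_params` is a hypothetical stand-in for KernelFiLMGenerator (one random linear map, with gamma centred at 1); the real generator's architecture is not documented here. The sketch also does not decide whether modulation sits before or after a layer's sine: it simply applies h <- gamma * h + beta to whatever activations it is given.

```python
import numpy as np


def toy_film_params(conditioning, hidden_dim, seed=0):
    """Map a [B, C] conditioning tensor to one (gamma, beta) pair of shape [B, hidden_dim].

    Hypothetical stand-in for KernelFiLMGenerator: a single random linear map.
    """
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, 0.1, (conditioning.shape[1], 2 * hidden_dim))
    out = conditioning @ w
    gamma = 1.0 + out[:, :hidden_dim]  # centred at the identity modulation
    beta = out[:, hidden_dim:]
    return gamma, beta


def film_modulate(h, gamma, beta):
    """Apply h <- gamma * h + beta at every grid point.

    h: [1, G, H] activations on a flattened grid (shared across the batch)
    gamma, beta: [B, H]; returns [B, G, H].
    """
    # Add a grid axis so each sample's (gamma, beta) broadcasts over all G points.
    return gamma[:, None, :] * h + beta[:, None, :]
```

The batch axis that appears after modulation is what turns the shared [1, *spatial, out_dim] kernel into a per-sample [B, *spatial, out_dim] kernel once conditioning is supplied.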
When film_after_pos_embed=True, an extra FiLM layer is applied to the output of the positional embedding (before the first hidden layer), making the positional features themselves input-dependent. This requires embedding_dim == mlp_hidden_dim and one additional FiLM layer in the generator (num_film_layers = num_layers rather than num_layers - 1).
When conditioning is provided and a FiLM generator exists, the output kernel has shape [B, *spatial, out_dim]; otherwise it is [1, *spatial, out_dim].
Initialisation
All hidden_linears are SIREN-initialised with hidden_omega_0. out_linear is also SIREN-initialised with hidden_omega_0 and then additionally Wang-scaled by sqrt(1 / prod(L_cache_per_axis)). This “Wang init” (from the CKConv paper, Romero et al. 2021) divides the output layer’s weights by the square root of the total grid volume (L_cache**data_dim for isotropic grids). Because the filter’s L2 energy sums the squared kernel values over all grid points, this scaling keeps the initial energy independent of the grid resolution.
Hidden linear weights and the output bias get _no_weight_decay = True so that weight-decay optimizers do not destroy the SIREN spectrum.
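A toy check of the Wang-scaling argument: below, a fixed profile sin(3x) stands in for the unscaled readout (an assumption; the real readout depends on the initialised weights) and is sampled on grids of different volume. Multiplying by sqrt(1 / volume) keeps the summed squared values roughly constant, while the unscaled energy grows with the number of points. `kernel_energy` is a hypothetical helper.

```python
import math


def kernel_energy(volume, wang_scale=True):
    """L2 energy (sum of squares) of a toy kernel evaluated on `volume` points.

    Assumption: the unscaled readout is the fixed profile sin(3x) sampled on
    [-1, 1]; Wang init multiplies every value by sqrt(1 / volume).
    """
    scale = math.sqrt(1.0 / volume) if wang_scale else 1.0
    xs = [-1.0 + 2.0 * i / (volume - 1) for i in range(volume)]
    return sum((scale * math.sin(3.0 * x)) ** 2 for x in xs)
```

kernel_energy(64) and kernel_energy(4096) agree to within a few percent, whereas the unscaled energies differ by roughly a factor of 64, the ratio of the two volumes.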
Attributes:
Hidden width of the SIREN MLP.
Hidden-layer frequency scaling.
- positional_embedding
First SIREN layer.
Hidden linear layers (length num_layers - 1). Interleaved with self.sine in the forward pass; stored separately so FiLM can be inserted between them.
- out_linear
Final readout to out_dim channels.
- num_film_layers
Number of hidden layers eligible for FiLM modulation (equal to len(hidden_linears)).
- film_generator
KernelFiLMGenerator instance or None.
- film_after_pos_embed
Whether the first FiLM pair modulates the positional embedding output.
- param out_dim:
Number of output channels for the generated kernel.
- param data_dim:
Number of spatial/temporal input dimensions (size of coordinate vector).
- param mlp_hidden_dim:
Hidden width of the SIREN network.
- param num_layers:
Total number of SIREN layers, counting the first (positional-embedding) layer and the hidden layers but not the output readout (>= 2).
- param embedding_dim:
Dimensionality of the SIREN positional embedding.
- param omega_0:
Frequency scaling for the first SIREN layer.
- param L_cache:
Cache extent controlling the maximum supported grid size before cache growth. Either a scalar int (isotropic, same extent on all axes) or a sequence of length data_dim (anisotropic, per-axis extents).
- param use_bias:
Whether to include biases in linear layers.
- param hidden_omega_0:
Frequency scaling for subsequent SIREN layers (default 1.0).
- param film_cfg:
Optional LazyConfig for KernelFiLMGenerator. When provided, enables input-dependent FiLM conditioning of all hidden SIREN layers.
- param film_after_pos_embed:
If True, the first FiLM (gamma, beta) pair modulates the positional embedding after the sine activation. Requires embedding_dim == mlp_hidden_dim and one extra FiLM layer in film_cfg (i.e. num_film_layers = num_layers - 1 + 1 = num_layers).
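The constraints stated in the parameter list can be collected into one validation step. `resolve_kernel_config` is a hypothetical helper, not part of the documented API: it normalises L_cache to per-axis extents, computes the Wang scale sqrt(1 / prod(L_cache_per_axis)), and returns how many (gamma, beta) pairs a FiLM generator would need to produce.

```python
import math
from collections.abc import Sequence


def resolve_kernel_config(data_dim, num_layers, embedding_dim, mlp_hidden_dim,
                          L_cache, film_after_pos_embed=False):
    """Check the documented argument constraints and derive a few quantities.

    Hypothetical helper; returns (per-axis L_cache, Wang scale, FiLM pairs needed).
    """
    if num_layers < 2:
        raise ValueError("num_layers must be >= 2")
    # L_cache: scalar int (isotropic) or a length-data_dim sequence (anisotropic).
    if isinstance(L_cache, int):
        per_axis = (L_cache,) * data_dim
    elif isinstance(L_cache, Sequence) and len(L_cache) == data_dim:
        per_axis = tuple(int(n) for n in L_cache)
    else:
        raise ValueError("L_cache must be an int or a sequence of length data_dim")
    # Wang scale applied to out_linear's weights: sqrt(1 / prod(L_cache_per_axis)).
    wang_scale = math.sqrt(1.0 / math.prod(per_axis))
    # One FiLM pair per hidden layer, plus one for the positional embedding if requested.
    film_pairs = num_layers - 1
    if film_after_pos_embed:
        if embedding_dim != mlp_hidden_dim:
            raise ValueError("film_after_pos_embed requires embedding_dim == mlp_hidden_dim")
        film_pairs += 1
    return per_axis, wang_scale, film_pairs
```

For example, data_dim=2, num_layers=4 and L_cache=16 give per-axis extents (16, 16), a Wang scale of 1/16, and 3 FiLM pairs, or 4 when film_after_pos_embed is True.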
- __init__(
    out_dim,
    data_dim,
    mlp_hidden_dim,
    num_layers,
    embedding_dim,
    omega_0,
    L_cache,
    use_bias,
    hidden_omega_0=1.0,
    film_cfg=None,
    film_after_pos_embed=False,
)
Build SIREN MLP and optional FiLM conditioner.
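The layer sizes implied by this page (positional embedding data_dim -> embedding_dim, then num_layers - 1 hidden layers, then the out_dim readout, matching the flop_count breakdown) can be tabulated with a small hypothetical helper. The parameter count assumes every linear map, including the positional embedding, carries a bias when use_bias is True, and it excludes the FiLM generator; both are assumptions.

```python
def siren_layer_shapes(data_dim, embedding_dim, mlp_hidden_dim, num_layers, out_dim):
    """List (in_features, out_features) for each linear map of the kernel MLP.

    Hypothetical helper following the layer sizes given in flop_count:
    positional embedding, num_layers - 1 hidden layers, then the readout.
    """
    shapes = [(data_dim, embedding_dim)]          # positional_embedding
    in_dim = embedding_dim
    for _ in range(num_layers - 1):               # hidden_linears
        shapes.append((in_dim, mlp_hidden_dim))
        in_dim = mlp_hidden_dim
    shapes.append((mlp_hidden_dim, out_dim))      # out_linear
    return shapes


def count_parameters(shapes, use_bias=True):
    """Weights plus (optionally) one bias per output feature, summed over all maps."""
    return sum(i * o + (o if use_bias else 0) for i, o in shapes)
```

With data_dim=2, embedding_dim=32, mlp_hidden_dim=64, num_layers=3 and out_dim=8 this lists four linear maps, (2, 32), (32, 64), (64, 64) and (64, 8), for 6888 parameters with biases.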
- flop_count(grid_lens, inference=False)
Return an integer FLOP estimate for one kernel generation forward pass.
At inference=True with no FiLM generator, this returns 0 because the kernel is input-independent and can be precomputed once and cached. When a film_generator exists, the kernel is input-dependent (via register-conditioned FiLM modulation) and must be recomputed on every forward pass, regardless of the inference flag.
Let G = prod(2 * L_i - 1 for L_i in grid_lens) be the total number of grid points. FLOPs breakdown:
- Positional embedding (SIRENPositionalEmbeddingND): 2 * G * data_dim * embedding_dim for the linear, plus G * embedding_dim for the sin activation.
- Hidden SIREN layers (len(self.hidden_linears) = num_layers - 1). First layer: 2 * G * embedding_dim * mlp_hidden_dim plus G * mlp_hidden_dim for the sin. Each subsequent layer: 2 * G * mlp_hidden_dim * mlp_hidden_dim plus G * mlp_hidden_dim for the sin.
- Output linear: 2 * G * mlp_hidden_dim * out_dim.
- FiLM conditioning (only when self.film_generator is set): the FiLM generator MLP costs self.film_generator.flop_count(); applying gamma * h + beta costs 2 * G * mlp_hidden_dim per modulated layer and is applied to each hidden layer, plus the positional embedding when film_after_pos_embed is True (which requires embedding_dim == mlp_hidden_dim).
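- Parameters:
grid_lens – Per-axis grid lengths L_i; the kernel is evaluated on G = prod(2 * L_i - 1) points.
inference (bool) – If True and no FiLM generator is set, return 0 because the kernel can be cached.
- Returns:
Total FLOPs as an integer.
- Return type:
int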
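The FLOP breakdown can be restated as a standalone function, which is handy for checking the arithmetic by hand. `siren_kernel_flops` is hypothetical: it takes the FiLM generator's cost as a plain number instead of calling self.film_generator.flop_count(), and, like the breakdown above, it counts 2 FLOPs per multiply-accumulate with no separate bias term.

```python
import math


def siren_kernel_flops(grid_lens, data_dim, embedding_dim, mlp_hidden_dim,
                       num_layers, out_dim, inference=False, has_film=False,
                       film_after_pos_embed=False, film_generator_flops=0):
    """Standalone restatement of the flop_count breakdown (hypothetical helper)."""
    if inference and not has_film:
        return 0  # input-independent kernel: precompute once and cache
    G = math.prod(2 * L - 1 for L in grid_lens)
    # Positional embedding: linear plus sin.
    flops = 2 * G * data_dim * embedding_dim + G * embedding_dim
    # Hidden SIREN layers: the first maps embedding_dim -> mlp_hidden_dim.
    in_dim = embedding_dim
    for _ in range(num_layers - 1):
        flops += 2 * G * in_dim * mlp_hidden_dim + G * mlp_hidden_dim
        in_dim = mlp_hidden_dim
    # Output readout.
    flops += 2 * G * mlp_hidden_dim * out_dim
    if has_film:
        # gamma * h + beta on every hidden layer (and the embedding if requested).
        modulated = (num_layers - 1) + (1 if film_after_pos_embed else 0)
        flops += film_generator_flops + modulated * 2 * G * mlp_hidden_dim
    return flops
```

For grid_lens=(4,), data_dim=1, embedding_dim=8, mlp_hidden_dim=16, num_layers=3 and out_dim=2, G = 7 and the total is 168 + 1904 + 3696 + 448 = 6216 FLOPs; with inference=True and no FiLM it is 0.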
- forward(seq_lens, conditioning=None)
Compute the SIREN kernel for a given grid of spatial dimensions.
- Parameters:
seq_lens (tuple[int, ...]) – Lengths of the input grid for which to compute the positional embeddings.
conditioning (Tensor | None) – Optional [B, C] conditioning vector for FiLM modulation. When provided and a film_generator exists, the SIREN hidden layers are modulated, making the output kernel batch-dependent: [B, *spatial, out_dim]. When None, behaves identically to the original SIREN: [1, *spatial, out_dim].
- Returns:
(kernel, grid), where kernel has shape [1|B, *spatial, out_dim] and grid has shape [1, *spatial, data_dim].
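A sketch of the coordinate grid and kernel shapes returned by forward. The sizing is an assumption inferred from the G formula in flop_count, not something the forward documentation states: each axis of length L is taken to contribute 2 * L - 1 relative offsets, linearly spaced in [-1, 1]. Both `relative_grid` and `kernel_output_shape` are hypothetical helpers.

```python
import numpy as np


def relative_grid(seq_lens):
    """Coordinate grid of shape [1, *spatial, data_dim].

    Assumption (inferred from flop_count's G formula, not confirmed): an axis of
    length L is covered by 2 * L - 1 relative offsets, linearly spaced in [-1, 1].
    """
    axes = [np.linspace(-1.0, 1.0, 2 * L - 1) for L in seq_lens]
    mesh = np.meshgrid(*axes, indexing="ij")   # data_dim arrays of shape [*spatial]
    return np.stack(mesh, axis=-1)[None]       # [1, *spatial, data_dim]


def kernel_output_shape(seq_lens, out_dim, batch_size=None):
    """Shape of `kernel`: [1, *spatial, out_dim], or [B, ...] with FiLM conditioning."""
    spatial = tuple(2 * L - 1 for L in seq_lens)
    lead = 1 if batch_size is None else batch_size
    return (lead, *spatial, out_dim)
```

Using indexing="ij" keeps the grid's axis order equal to the order of seq_lens, so relative_grid((4, 3)) has shape (1, 7, 5, 2) with the centre offset (0, 0) at index [0, 3, 2].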