SIRENKernelND

class SIRENKernelND(
out_dim,
data_dim,
mlp_hidden_dim,
num_layers,
embedding_dim,
omega_0,
L_cache,
use_bias,
hidden_omega_0=1.0,
film_cfg=None,
film_after_pos_embed=False,
)

Bases: Module

Convolutional kernel parametrised by a SIREN (sinusoidal representation network) MLP.

Mathematical form

The kernel at coordinate x is:

k(x) = Linear_out( SIREN_MLP( phi(x) ) )

where:

  • phi(x) = sin(W_0 x + b_0) is the SIREN positional embedding (SIRENPositionalEmbeddingND) with first-layer frequency omega_0.

  • SIREN_MLP is a stack of num_layers - 1 layers, each computing sin(W_i h + b_i) with weights initialised at frequency hidden_omega_0.

  • Linear_out is a linear readout to out_dim channels whose weights are scaled by the Wang factor sqrt(1 / prod(L_cache_per_axis)) to normalise the initial kernel energy.

The full pipeline (without FiLM conditioning) is therefore:

h_0 = sin(W_0 x + b_0)                     (positional embedding)
h_i = sin(W_i h_{i-1} + b_i), i = 1..N-1   (hidden layers)
k   = W_out h_{N-1} + b_out                (output layer)

where N = num_layers, so there are N - 1 = num_layers - 1 hidden layers.
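The pipeline above can be traced with a minimal NumPy sketch. It is not the module's implementation: it assumes omega_0 and hidden_omega_0 are already folded into the weight matrices, uses a row-vector convention (h @ W in place of W h), draws random weights instead of the SIREN initialisation, and omits the Wang scaling and FiLM. The function name siren_kernel_sketch and all sizes are hypothetical.

```python
import numpy as np

def siren_kernel_sketch(coords, weights, biases, w_out, b_out):
    """Evaluate k = W_out h_{N-1} + b_out at every row of `coords`.

    Row-vector convention: h has shape [P, features] and each W_i has shape
    [in_features, out_features], so `h @ W_i` plays the role of W_i h.
    Frequencies are assumed to be folded into the weights.
    """
    h = coords
    for w, b in zip(weights, biases):
        # First pass: h_0 = sin(W_0 x + b_0); later passes: h_i = sin(W_i h_{i-1} + b_i)
        h = np.sin(h @ w + b)
    return h @ w_out + b_out  # plain linear readout, no sine

# Hypothetical sizes, chosen only for illustration.
data_dim, embedding_dim, mlp_hidden_dim, out_dim, num_layers = 2, 16, 32, 4, 3
rng = np.random.default_rng(0)

# Widths: positional embedding first, then num_layers - 1 hidden layers.
dims = [data_dim, embedding_dim] + [mlp_hidden_dim] * (num_layers - 1)
weights = [rng.uniform(-1.0, 1.0, size=(a, b)) for a, b in zip(dims[:-1], dims[1:])]
biases = [rng.uniform(-1.0, 1.0, size=b) for b in dims[1:]]
w_out = rng.uniform(-1.0, 1.0, size=(mlp_hidden_dim, out_dim))
b_out = np.zeros(out_dim)

coords = rng.uniform(-1.0, 1.0, size=(10, data_dim))  # 10 coordinate points
kernel = siren_kernel_sketch(coords, weights, biases, w_out, b_out)
print(kernel.shape)  # (10, 4)
```

Note that the number of sine layers equals num_layers: one positional-embedding layer plus num_layers - 1 hidden layers, followed by a readout without activation.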

Hyperparameters controlling bandwidth / smoothness

  • omega_0: Frequency of the first SIREN layer. Higher values produce higher-frequency positional features at init. Typical range: 1.0–30.0.

  • hidden_omega_0: Frequency of the hidden SIREN layers. Usually set to 1.0 (default) following the recommendation in the SIREN paper.

  • mlp_hidden_dim: Width of all hidden layers; wider networks can express more complex kernel shapes.
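To make "higher-frequency positional features" concrete, consider a single positional feature with unit weight and zero bias. Assuming omega_0 multiplies the pre-activation as in the SIREN paper, that feature is sin(omega_0 * x), and its number of zero crossings on [-1, 1] grows roughly linearly with omega_0. The helper count_sign_changes below is hypothetical and only counts those crossings on a fine grid:

```python
import numpy as np

def count_sign_changes(omega_0, num_points=1000):
    """Count zero crossings of sin(omega_0 * x) sampled on [-1, 1]."""
    x = np.linspace(-1.0, 1.0, num_points)  # even count, so x = 0 is never sampled exactly
    s = np.sign(np.sin(omega_0 * x))
    return int(np.count_nonzero(s[1:] != s[:-1]))

for omega in (1.0, 10.0, 30.0):
    print(omega, count_sign_changes(omega))
```

With these settings, omega_0 = 1 gives a single crossing (at x = 0), while omega_0 = 30 gives 19 (at x = k * pi / 30 for k = -9..9).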

FiLM conditioning

When film_cfg is provided, a KernelFiLMGenerator is instantiated and called on the conditioning tensor (shape [B, C]) to produce per-layer (gamma, beta) pairs (each of shape [B, mlp_hidden_dim]). The hidden activations are then modulated as:

h_i <- gamma_i * h_i + beta_i

When film_after_pos_embed=True, an extra FiLM layer is applied to the output of the positional embedding (before the first hidden layer), making the positional features themselves input-dependent. This requires embedding_dim == mlp_hidden_dim, and the generator must produce one extra (gamma, beta) pair: num_layers pairs in total instead of num_layers - 1.

When conditioning is present, the output kernel has shape [B, *spatial, out_dim]; otherwise it is [1, *spatial, out_dim].
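The modulation rule, and the shape change from [1, *spatial, D] to [B, *spatial, D], come down to broadcasting. The NumPy sketch below assumes gamma and beta arrive as [B, D] arrays and reshapes them so they broadcast over the spatial axes. apply_film and required_film_pairs are hypothetical helpers; the pair count restates the rule for film_after_pos_embed given above.

```python
import numpy as np

def apply_film(h, gamma, beta):
    """h <- gamma * h + beta, broadcasting [B, D] pairs over the spatial axes.

    h:     [1 or B, *spatial, D] hidden activations
    gamma: [B, D], beta: [B, D]
    """
    num_spatial = h.ndim - 2
    shape = (gamma.shape[0],) + (1,) * num_spatial + (gamma.shape[1],)
    return gamma.reshape(shape) * h + beta.reshape(shape)

def required_film_pairs(num_layers, film_after_pos_embed, embedding_dim, mlp_hidden_dim):
    """Number of (gamma, beta) pairs the generator has to produce."""
    if film_after_pos_embed:
        if embedding_dim != mlp_hidden_dim:
            raise ValueError("film_after_pos_embed requires embedding_dim == mlp_hidden_dim")
        return num_layers      # one per hidden layer plus one for the embedding
    return num_layers - 1      # one per hidden layer

B, D = 3, 8
h = np.ones((1, 5, 7, D))      # unconditioned activations on a 5 x 7 grid
gamma = np.full((B, D), 2.0)
beta = np.full((B, D), 0.5)
out = apply_film(h, gamma, beta)
print(out.shape)  # (3, 5, 7, 8)
```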

Initialisation

  • All hidden_linears are SIREN-initialised with hidden_omega_0.

  • out_linear is SIREN-initialised with hidden_omega_0, then additionally Wang-scaled by sqrt(1 / prod(L_cache_per_axis)). This “Wang init” (from the CKConv paper, Romero et al. 2021) divides the output layer’s weights by the square root of the total grid volume (L_cache**data_dim for isotropic grids), so the initial filter’s L2 energy is independent of the grid resolution.

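  • Hidden linear weights and the output bias are tagged with _no_weight_decay = True, so that optimizer set-up code which honours this flag excludes them from weight decay; decaying them would pull the SIREN weights toward zero and degrade the frequency spectrum. Standard PyTorch optimizers do not read this attribute themselves.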
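A sketch of the initialisation arithmetic, under stated assumptions: the uniform bound sqrt(6 / fan_in) / omega is the hidden-layer scheme from the SIREN paper (Sitzmann et al., 2020) and is assumed, not confirmed, to be what "SIREN-initialised" means here; weights follow torch.nn.Linear's [out_features, in_features] layout; and the parameters are SimpleNamespace stand-ins for tensors. The energy check uses a constant unit-valued filter to show that the 1/sqrt(V) factor makes the L2 energy independent of the volume V. All function names are hypothetical.

```python
import math
from types import SimpleNamespace

import numpy as np

def siren_hidden_bound(fan_in, omega):
    """Uniform-init bound for a hidden SIREN layer: sqrt(6 / fan_in) / omega."""
    return math.sqrt(6.0 / fan_in) / omega

def wang_scale(l_cache_per_axis):
    """sqrt(1 / prod(L_cache_per_axis)), i.e. 1 / sqrt(total cache volume)."""
    return math.sqrt(1.0 / math.prod(l_cache_per_axis))

def init_out_linear(mlp_hidden_dim, out_dim, hidden_omega_0, l_cache_per_axis, rng):
    """SIREN-style uniform init followed by the Wang rescaling."""
    bound = siren_hidden_bound(mlp_hidden_dim, hidden_omega_0)
    weight = rng.uniform(-bound, bound, size=(out_dim, mlp_hidden_dim))
    return weight * wang_scale(l_cache_per_axis)

def split_weight_decay(named_params):
    """Split parameter names into (decay, no_decay) using the _no_weight_decay flag."""
    decay, no_decay = [], []
    for name, p in named_params:
        (no_decay if getattr(p, "_no_weight_decay", False) else decay).append(name)
    return decay, no_decay

# A constant unit filter over V points has energy V; after the factor
# 1/sqrt(V) its energy is 1 for every cache shape.
energies = [float(np.sum((np.ones(math.prod(shape)) * wang_scale(shape)) ** 2))
            for shape in [(33,), (33, 33), (9, 17, 5)]]

# Stand-in parameters tagged as described above.
params = [
    ("hidden_linears.0.weight", SimpleNamespace(_no_weight_decay=True)),
    ("out_linear.weight", SimpleNamespace()),
    ("out_linear.bias", SimpleNamespace(_no_weight_decay=True)),
]
decay, no_decay = split_weight_decay(params)
print(decay, no_decay)
```

In PyTorch, the two groups (holding parameter tensors rather than names) could be passed to torch.optim.AdamW as [{"params": decay, "weight_decay": 0.01}, {"params": no_decay, "weight_decay": 0.0}], where 0.01 is an arbitrary example value.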

out_dim

Number of output channels (kernel depth).

Type:

int

data_dim

Number of spatial / temporal input dimensions.

Type:

int

mlp_hidden_dim

Hidden width of the SIREN MLP.

Type:

int

num_layers

Total number of SIREN layers (>= 2).

Type:

int

embedding_dim

SIREN positional-embedding dimensionality.

Type:

int

omega_0

First-layer frequency scaling.

Type:

float

hidden_omega_0

Hidden-layer frequency scaling.

Type:

float

L_cache_per_axis

Per-axis cache extents: L_cache converted to a tuple of length data_dim (canonical form).

Type:

tuple[int, …]
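The canonical form can be sketched as a small normalisation step. This is an assumed implementation in plain Python (canonical_l_cache is a hypothetical name) that follows the L_cache description: a scalar is repeated data_dim times, and a sequence must have exactly data_dim entries.

```python
from collections.abc import Sequence

def canonical_l_cache(L_cache, data_dim):
    """Return per-axis cache extents as a tuple of length data_dim."""
    if isinstance(L_cache, int):
        return (L_cache,) * data_dim            # isotropic: same extent on every axis
    if isinstance(L_cache, Sequence) and len(L_cache) == data_dim:
        return tuple(int(n) for n in L_cache)   # anisotropic: one extent per axis
    raise ValueError(f"L_cache must be an int or a sequence of length {data_dim}")

print(canonical_l_cache(33, 2))        # (33, 33)
print(canonical_l_cache([16, 64], 2))  # (16, 64)
```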

L_cache

The L_cache argument exactly as passed to the constructor, kept for diagnostics.

Type:

int | Sequence[int]

positional_embedding

First SIREN layer.

Type:

SIRENPositionalEmbeddingND

hidden_linears

Hidden linear layers (length num_layers - 1). Interleaved with self.sine in the forward pass; stored separately so FiLM can be inserted between them.

Type:

torch.nn.ModuleList

sine

Shared sine activation applied after every hidden linear.

Type:

Sine

out_linear

Final readout to out_dim channels.

Type:

torch.nn.Linear

num_film_layers

Number of hidden layers eligible for FiLM modulation (equal to len(hidden_linears), i.e. num_layers - 1).

Type:

int

film_generator

KernelFiLMGenerator built from film_cfg, or None when film_cfg is None.

Type:

KernelFiLMGenerator | None

film_after_pos_embed

Whether the first FiLM pair modulates the positional embedding output.

Type:

bool

Parameters:
  • out_dim (int) – Number of output channels for the generated kernel.

  • data_dim (int) – Number of spatial/temporal input dimensions (size of the coordinate vector).

  • mlp_hidden_dim (int) – Hidden width of the SIREN network.

  • num_layers (int) – Total number of layers, including the first (positional-embedding) layer and the hidden layers (>= 2).

  • embedding_dim (int) – Dimensionality of the SIREN positional embedding.

  • omega_0 (float) – Frequency scaling for the first SIREN layer.

  • L_cache (int | Sequence[int]) – Cache extent controlling the maximum supported grid size before the cache must grow. Either a scalar int (isotropic: the same extent on all axes) or a sequence of length data_dim (anisotropic: per-axis extents).

  • use_bias (bool) – Whether to include biases in the linear layers.

  • hidden_omega_0 (float) – Frequency scaling for the subsequent (hidden) SIREN layers (default 1.0).

  • film_cfg (LazyConfig | None) – Optional LazyConfig for KernelFiLMGenerator. When provided, enables input-dependent FiLM conditioning of all hidden SIREN layers.

  • film_after_pos_embed (bool) – If True, the first FiLM (gamma, beta) pair modulates the positional embedding after its sine activation. Requires embedding_dim == mlp_hidden_dim and one extra FiLM layer in film_cfg (num_film_layers = num_layers instead of num_layers - 1).

__init__(
out_dim,
data_dim,
mlp_hidden_dim,
num_layers,
embedding_dim,
omega_0,
L_cache,
use_bias,
hidden_omega_0=1.0,
film_cfg=None,
film_after_pos_embed=False,
)

Build the SIREN MLP and the optional FiLM conditioner.

flop_count(grid_lens, inference=False)

Return an integer FLOP estimate for one kernel generation forward pass.

With inference=True and no FiLM generator, this returns 0, because the kernel is input-independent and can be precomputed once and cached. When a film_generator exists, the kernel is input-dependent (via register-conditioned FiLM modulation) and must be recomputed on every forward pass, regardless of the inference flag.

Let G = prod(2 * L_i - 1 for L_i in grid_lens) be the total number of grid points. The FLOPs break down as follows:

  • Positional embedding (SIRENPositionalEmbeddingND): 2 * G * data_dim * embedding_dim for the linear, plus G * embedding_dim for the sin activation.

  • Hidden SIREN layers (len(self.hidden_linears) = num_layers - 1). First layer: 2 * G * embedding_dim * mlp_hidden_dim plus G * mlp_hidden_dim for the sin. Each subsequent layer: 2 * G * mlp_hidden_dim * mlp_hidden_dim plus G * mlp_hidden_dim for the sin.

  • Output linear: 2 * G * mlp_hidden_dim * out_dim.

  • FiLM conditioning (only when self.film_generator is set): the FiLM generator MLP costs self.film_generator.flop_count(), and applying gamma * h + beta costs 2 * G * mlp_hidden_dim per modulated layer. Every hidden layer is modulated, plus the positional embedding when film_after_pos_embed is True (which requires embedding_dim == mlp_hidden_dim).

Parameters:
  • grid_lens (tuple[int, ...]) – Per-axis output sequence lengths, i.e. the same tuple you would pass to forward as seq_lens. The total number of coordinate points the MLP processes is G = prod(2*L - 1 for L in grid_lens).

  • inference (bool) – If True and no FiLM generator, return 0 (cacheable kernel).

Returns:

Total FLOPs as an integer.

Return type:

int
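The breakdown translates directly into plain Python. In this sketch, siren_kernel_flops is a hypothetical standalone function, and film_generator_flops stands in for self.film_generator.flop_count() (None meaning no FiLM generator). As in the formulas above, the batch size does not appear.

```python
import math

def siren_kernel_flops(grid_lens, data_dim, embedding_dim, mlp_hidden_dim,
                       num_layers, out_dim, inference=False,
                       film_generator_flops=None, film_after_pos_embed=False):
    """FLOP estimate following the flop_count breakdown."""
    assert num_layers >= 2
    has_film = film_generator_flops is not None
    if inference and not has_film:
        return 0  # input-independent kernel: computed once and cached
    G = math.prod(2 * L - 1 for L in grid_lens)
    num_hidden = num_layers - 1  # len(hidden_linears)
    # Positional embedding: linear + sin.
    flops = 2 * G * data_dim * embedding_dim + G * embedding_dim
    # First hidden layer maps embedding_dim -> mlp_hidden_dim.
    flops += 2 * G * embedding_dim * mlp_hidden_dim + G * mlp_hidden_dim
    # Remaining hidden layers map mlp_hidden_dim -> mlp_hidden_dim.
    flops += (num_hidden - 1) * (2 * G * mlp_hidden_dim ** 2 + G * mlp_hidden_dim)
    # Output readout (no activation).
    flops += 2 * G * mlp_hidden_dim * out_dim
    if has_film:
        modulated = num_hidden + (1 if film_after_pos_embed else 0)
        flops += film_generator_flops + modulated * 2 * G * mlp_hidden_dim
    return flops

print(siren_kernel_flops((8,), 1, 32, 32, 3, 16))                  # 79200
print(siren_kernel_flops((8,), 1, 32, 32, 3, 16, inference=True))  # 0
```

For grid_lens=(8,), G = 15 and the four terms are 1440 (embedding), 31200 (first hidden layer), 31200 (second hidden layer) and 15360 (readout).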

forward(seq_lens, conditioning=None)

Compute the SIREN kernel on the coordinate grid defined by seq_lens.

Parameters:
  • seq_lens (tuple[int, ...]) – Per-axis lengths of the input grid. The MLP is evaluated at G = prod(2 * L - 1 for L in seq_lens) coordinate points, as in flop_count.

  • conditioning (Tensor | None) – Optional [B, C] conditioning vector for FiLM modulation. When provided and a film_generator exists, the SIREN hidden layers are modulated, making the output kernel batch-dependent: [B, *spatial, out_dim]. When None, no modulation is applied and the output is the plain SIREN kernel: [1, *spatial, out_dim].

Returns:

(kernel, grid) where kernel has shape [1|B, *spatial, out_dim] and grid has shape [1, *spatial, data_dim].

Return type:

tuple[Tensor, Tensor]
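The layout of the returned grid can be illustrated with a NumPy sketch. Two assumptions are made here that the text above does not state: coordinates are normalised to [-1, 1], and the 2 * L - 1 points per axis (from the G used in flop_count) form one spatial axis each. make_grid and kernel_shape are hypothetical helpers.

```python
import numpy as np

def make_grid(seq_lens):
    """Coordinate grid of shape [1, *spatial, data_dim], with 2 * L - 1 points per axis."""
    axes = [np.linspace(-1.0, 1.0, 2 * L - 1) for L in seq_lens]
    mesh = np.meshgrid(*axes, indexing="ij")  # one array per axis, each [*spatial]
    return np.stack(mesh, axis=-1)[None]      # add the leading singleton batch axis

def kernel_shape(seq_lens, out_dim, batch=1):
    """Expected kernel shape: [1 or B, *spatial, out_dim]."""
    return (batch,) + tuple(2 * L - 1 for L in seq_lens) + (out_dim,)

grid = make_grid((4, 6))
print(grid.shape)                            # (1, 7, 11, 2)
print(kernel_shape((4, 6), 16, batch=3))     # (3, 7, 11, 16)
```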

Parameters: