RandomFourierPositionalEmbeddingND

class RandomFourierPositionalEmbeddingND(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)

Bases: Module

N-dimensional positional embedding using Random Fourier Features (RFF).

Mathematical form

Given a coordinate grid x of shape [1, *spatial_dims, data_dim] with values normalised to [-1, 1] per axis, the embedding is:

phi(x) = [cos(W x + b), sin(W x + b)]    (shape [..., embedding_dim])

where:

  • W is the weight matrix of the linear projection, of shape [embedding_dim//2, data_dim]. Its entries are drawn once at construction from N(0, (2*pi*omega_0)^2), i.e. with standard deviation 2*pi*omega_0, and then frozen (not trained).

  • b is a bias vector of shape [embedding_dim//2], present only when use_bias=True. It is initialised to zero and also frozen.

  • Concatenating the cosine and sine features turns the embedding_dim//2 projected values into embedding_dim outputs.

The resulting features approximate the feature map of a stationary RBF (Gaussian) kernel whose width is controlled by omega_0: the larger omega_0, the narrower the kernel and the higher the dominant spatial frequency encoded in the embedding.
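As a quick numerical check of the kernel claim, the NumPy sketch below builds phi(x) directly, assuming (as stated above) that the entries of W are i.i.d. N(0, (2*pi*omega_0)^2) and that b = 0. For Gaussian W, the average of cos(w_i · (x - y)) over the embedding_dim//2 rows converges to exp(-(2*pi*omega_0)^2 * ||x - y||^2 / 2), so <phi(x), phi(y)> / (embedding_dim//2) should land close to that value. The helper rff_embed is hypothetical and is not part of the module.

```python
import numpy as np

def rff_embed(x, W, b):
    """Return [cos(W x + b), sin(W x + b)] for x of shape [..., data_dim]."""
    proj = x @ W.T + b                          # [..., embedding_dim // 2]
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)

rng = np.random.default_rng(0)
data_dim, embedding_dim, omega_0 = 2, 8192, 1.0
m = embedding_dim // 2                          # number of random frequencies
sigma = 2 * np.pi * omega_0                     # std of each entry of W, as documented
W = rng.normal(0.0, sigma, size=(m, data_dim))  # frozen random projection
b = np.zeros(m)                                 # zero-initialised bias

x = np.array([0.10, -0.20])
y = np.array([0.15, -0.25])

# Monte Carlo estimate of the kernel versus the closed-form Gaussian kernel.
approx = rff_embed(x, W, b) @ rff_embed(y, W, b) / m
exact = np.exp(-0.5 * sigma**2 * np.sum((x - y) ** 2))
print(f"approx={approx:.4f}  exact={exact:.4f}")
```

Since cos(u)cos(v) + sin(u)sin(v) = cos(u - v), the bias cancels in the inner product: the estimate is identical for any b, and it depends on x and y only through x - y, which is what makes the kernel stationary. The exponent also shows why a larger omega_0 gives a narrower kernel.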

Grid caching

To avoid rebuilding the meshgrid on every forward pass, the module maintains a grid_cache buffer: a precomputed float32 coordinate tensor of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim]. On each forward call, the central 2*seq_len_i - 1 points are sliced from axis i. When a seq_len_i larger than the cached extent on that axis is requested, the cache grows automatically via _maybe_extend_grid_cache, preserving the original step size on each axis. Because the step is preserved, the coordinates of an extended cache reach beyond [-1, 1].
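The caching scheme can be sketched in NumPy as follows. The helper names build_grid_cache, slice_central, and maybe_extend are hypothetical (only _maybe_extend_grid_cache is named above), and the grid is assumed to be centred on zero with step 1/(L_i - 1), which is what makes 2*L_i - 1 points span exactly [-1, 1].

```python
import numpy as np

def build_grid_cache(L_per_axis, step_sizes):
    """Centred grid with 2*L_i - 1 points per axis, shape [1, 2L_0-1, ..., data_dim]."""
    axes = [step * np.arange(-(L - 1), L) for L, step in zip(L_per_axis, step_sizes)]
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack(mesh, axis=-1)[None].astype(np.float32)

def slice_central(grid, seq_lens):
    """Keep the central 2*s_i - 1 points on every spatial axis."""
    index = [slice(None)]                          # leading batch axis
    for axis_len, s in zip(grid.shape[1:-1], seq_lens):
        centre = axis_len // 2
        index.append(slice(centre - (s - 1), centre + s))
    index.append(slice(None))                      # trailing coordinate axis
    return grid[tuple(index)]

def maybe_extend(grid, L_per_axis, step_sizes, seq_lens):
    """Rebuild the cache if any seq_len exceeds its cached extent, keeping the step."""
    new_L = tuple(max(L, s) for L, s in zip(L_per_axis, seq_lens))
    if new_L != tuple(L_per_axis):
        grid = build_grid_cache(new_L, step_sizes)
    return grid, new_L

L_cache = (3, 5)
step_sizes = tuple(1.0 / (L - 1) for L in L_cache)   # (0.5, 0.25): each axis spans [-1, 1]
grid = build_grid_cache(L_cache, step_sizes)          # shape (1, 5, 9, 2)
patch = slice_central(grid, (2, 3))                   # shape (1, 3, 5, 2)
grid, L_now = maybe_extend(grid, L_cache, step_sizes, (4, 5))  # axis 0 grows to 7 points
```

With L_cache = (3, 5), the original grid spans exactly [-1, 1] on both axes; after axis 0 is extended to seq_len 4, the same 0.5 step places its outermost points at ±1.5.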

Note: the W and b parameters carry the attribute _no_weight_decay = True, so that optimizer-setup code which honours this flag excludes them from weight decay and never shrinks the random projection.
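The flag is a plain attribute on the parameter tensors; PyTorch optimizers do not inspect it themselves, so it only takes effect where parameter groups are built. A minimal sketch of such a grouping step is below; split_param_groups is a hypothetical helper, and the 0.01 default decay is an arbitrary placeholder.

```python
def split_param_groups(named_params, weight_decay=0.01):
    """Put parameters flagged with _no_weight_decay into a zero-decay group."""
    decay, no_decay = [], []
    for _name, param in named_params:
        # Parameters without the attribute default to receiving weight decay.
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Usage with PyTorch (not executed here):
#   optimizer = torch.optim.AdamW(split_param_groups(model.named_parameters()), lr=1e-3)
```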

data_dim

Number of spatial / temporal input dimensions.

Type:

int

embedding_dim

Output embedding size (must be even; split equally between cos and sin features).

Type:

int

L_cache_per_axis

Current per-axis cache extents. May grow at runtime; the original value at construction is stored in self.L_cache.

Type:

tuple[int, ...]

L_cache

Original L_cache argument (for diagnostics and external read-back).

Type:

int | Sequence[int]

omega_0

Bandwidth / frequency scaling factor used for weight init and for diagnostics.

Type:

float

use_bias

Whether a bias is present in the linear projection.

Type:

bool

linear

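The frozen random frequency projection, a torch.nn.Linear(data_dim, embedding_dim//2). Its weight is W, of shape [embedding_dim//2, data_dim]; its bias, present only when use_bias=True, is b, of shape [embedding_dim//2].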

Type:

torch.nn.Linear

grid_cache

Non-persistent float32 buffer of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim].

Type:

torch.Tensor

step_sizes

Per-axis grid step 1/(L_i - 1) at construction; used by cache extensions to keep spacing constant.

Type:

tuple[float, ...]

__init__(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)

Initialize the RandomFourierPositionalEmbeddingND.

Parameters:
  • data_dim (int) – Number of spatial / temporal coordinate axes of the input.

  • embedding_dim (int) – Dimensionality of the positional embedding. Must be even.

  • L_cache (int | Sequence[int]) – Number of cached positions per axis. Either a scalar int (broadcast to every axis, giving an isotropic grid) or a sequence of length data_dim (one extent per spatial axis, giving an anisotropic grid). The cached grid then has shape (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim).

  • omega_0 (float) – Frequency scaling factor for the Fourier features; the entries of W are drawn from N(0, (2*pi*omega_0)^2).

  • use_bias (bool) – Whether to use a bias term in the linear layer.

Raises:

ValueError – If embedding_dim is not an even number.
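The argument handling described for __init__ (the even embedding_dim check, scalar versus sequence L_cache, and the step_sizes and grid_cache shapes listed under the attributes) can be sketched as a standalone function. normalise_init_args is hypothetical, and the length check on a sequence L_cache is an assumption: the documented Raises section only mentions the even-dimension check.

```python
from typing import Sequence, Union

def normalise_init_args(data_dim: int, embedding_dim: int, L_cache: Union[int, Sequence[int]]):
    """Return (L_cache_per_axis, step_sizes, grid_cache_shape) for the given arguments."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim must be even, got {embedding_dim}")
    if isinstance(L_cache, int):
        L_per_axis = (L_cache,) * data_dim             # isotropic: same extent on every axis
    else:
        L_per_axis = tuple(int(L) for L in L_cache)    # anisotropic: one extent per axis
        if len(L_per_axis) != data_dim:               # assumption: not documented above
            raise ValueError(f"expected {data_dim} extents, got {len(L_per_axis)}")
    step_sizes = tuple(1.0 / (L - 1) for L in L_per_axis)   # 2*L - 1 points span [-1, 1]
    grid_shape = (1, *(2 * L - 1 for L in L_per_axis), data_dim)
    return L_per_axis, step_sizes, grid_shape
```

For example, normalise_init_args(2, 64, 5) describes an isotropic 9 x 9 grid with step 0.25 on both axes.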

forward(seq_lens)

Compute the RFF positional embeddings for a given spatial grid.
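A framework-free sketch of what forward computes, assuming (per the grid-caching description above) that each output axis has 2*seq_len_i - 1 points centred on zero with the construction-time step 1/(L_i - 1). Building the centred coordinates directly is used here in place of slicing grid_cache; the two should agree because the cache is centred with the same step. rff_forward is a hypothetical name, and NumPy arrays stand in for torch tensors.

```python
import numpy as np

def rff_forward(seq_lens, W, b, step_sizes):
    """Return (embedding, grid) for the given per-axis sequence lengths."""
    # 2*s_i - 1 centred coordinates per axis, spaced by the construction-time step.
    axes = [step * np.arange(-(s - 1), s) for s, step in zip(seq_lens, step_sizes)]
    grid = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1)[None]  # [1, *spatial, data_dim]
    proj = grid @ W.T + b                                              # [1, *spatial, embedding_dim//2]
    emb = np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)        # [1, *spatial, embedding_dim]
    return emb, grid

rng = np.random.default_rng(0)
embedding_dim, omega_0 = 16, 1.0
W = rng.normal(0.0, 2 * np.pi * omega_0, size=(embedding_dim // 2, 2))  # frozen projection
b = np.zeros(embedding_dim // 2)                                         # zero bias

height, width = 3, 4
# Steps (0.5, 0.25) correspond to a cache built with L_cache = (3, 5).
emb, grid = rff_forward((height, width), W, b, step_sizes=(0.5, 0.25))
print(emb.shape, grid.shape)
```

At the centre of the grid the coordinates are zero, so with b = 0 the embedding there is all ones in the cosine half and all zeros in the sine half.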

Parameters:

seq_lens (tuple[int, ...]) – Per-axis output sequence lengths. Length must equal self.data_dim. For example, for a 2D signal of height H and width W, pass (H, W).

Returns:

  • torch.Tensor: The positional embeddings, [cos(Wx+b), sin(Wx+b)] concatenated along the last axis. Shape [1, *spatial_dims, embedding_dim].

  • torch.Tensor: The coordinate grid of positions normalised to [-1, 1] per axis. Shape [1, *spatial_dims, data_dim].

Return type:

tuple[torch.Tensor, torch.Tensor]

Raises:
Parameters: