RandomFourierPositionalEmbeddingND#
- class RandomFourierPositionalEmbeddingND(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)
Bases: Module
N-dimensional positional embedding using Random Fourier Features (RFF).
Mathematical form#
Given a coordinate grid x of shape [1, *spatial_dims, data_dim] with values normalised to [-1, 1] per axis, the embedding is:
phi(x) = [ cos(W x + b), sin(W x + b) ]    shape […, embedding_dim]
where:
- W is the first-layer weight matrix of shape [embedding_dim//2, data_dim], drawn once at construction from N(0, (2*pi*omega_0)^2) and then frozen (not trained).
- b is a bias vector of shape [embedding_dim//2], initialised to zero. It is also frozen.
The concatenation of cosine and sine doubles the projected dimension, from embedding_dim//2 to embedding_dim.
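The formula above can be illustrated with a minimal PyTorch sketch. It is not the module's implementation: the function name rff_embed, the fixed seed, and the example sizes are hypothetical, and (2*pi*omega_0)^2 is assumed to be the variance of the normal distribution, so the standard deviation is 2*pi*omega_0.

```python
import math
import torch


def rff_embed(x, W, b):
    """Map coordinates x [..., data_dim] to [cos(Wx + b), sin(Wx + b)]."""
    proj = x @ W.T + b  # [..., embedding_dim // 2]
    return torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1)


data_dim, embedding_dim, omega_0 = 2, 8, 1.0
gen = torch.Generator().manual_seed(0)

# Assumption: W ~ N(0, (2*pi*omega_0)^2), i.e. a standard normal scaled by
# the standard deviation 2*pi*omega_0. b starts at zero, as documented.
W = torch.randn(embedding_dim // 2, data_dim, generator=gen) * (2 * math.pi * omega_0)
b = torch.zeros(embedding_dim // 2)

# Example coordinate grid of shape [1, 5, 7, data_dim] with values in [-1, 1).
x = torch.rand(1, 5, 7, data_dim, generator=gen) * 2 - 1
phi = rff_embed(x, W, b)  # shape [1, 5, 7, embedding_dim]
```

Because the cosine and sine halves share the same projection, each pair of matching features satisfies cos^2 + sin^2 = 1, which is a convenient sanity check on the output.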
The resulting features approximate the feature map of a stationary RBF (Gaussian) kernel with bandwidth omega_0: the larger omega_0, the higher the dominant spatial frequency encoded in the embedding.
Grid caching#
To avoid rebuilding the meshgrid on every forward pass, the module maintains a grid_cache buffer (a pre-computed coordinate tensor of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim] in float32). On each forward call the central [2*seq_len_i - 1] points are sliced per axis. When a larger seq_len is seen at runtime the cache grows automatically via _maybe_extend_grid_cache, preserving the original step size on each axis.
Note: the W and b parameters have _no_weight_decay = True set so that any weight-decay optimizer does not shrink the random projection.
- embedding_dim#
Output embedding size (must be even; split equally between cos and sin features).
- Type:
- L_cache_per_axis#
Current per-axis cache extents. May grow at runtime; the original value at construction is stored in
self.L_cache.
- L_cache#
Original L_cache argument (for diagnostics and external read-back).
- linear#
The frozen random frequency projection W (and optionally b), shape [embedding_dim//2, data_dim].
- Type:
- grid_cache#
Non-persistent float32 buffer of shape [1, 2*L_0-1, ..., 2*L_{d-1}-1, data_dim].
- Type:
- step_sizes#
Per-axis grid step 1/(L_i - 1) at construction; used by cache extensions to keep spacing constant.
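The grid_cache shape and the step_sizes value described in these attributes can be reproduced with a short sketch. The helper name build_grid_cache is hypothetical, and the sketch assumes the grid is centred at zero on every axis, which is what the documented step 1/(L_i - 1) and the 2*L_i - 1 points per axis suggest; the module's own construction may differ.

```python
import torch


def build_grid_cache(L_cache_per_axis):
    """Hypothetical helper: symmetric grid with 2*L_i - 1 points per axis."""
    axes = []
    for L in L_cache_per_axis:
        step = 1.0 / (L - 1)  # per-axis step size; requires L >= 2
        # Integer offsets -(L-1), ..., 0, ..., L-1 give 2*L - 1 points.
        idx = torch.arange(-(L - 1), L, dtype=torch.float32)
        axes.append(idx * step)  # coordinates spanning [-1, 1]
    mesh = torch.meshgrid(*axes, indexing="ij")
    # Stack the per-axis coordinates last and add a leading batch axis.
    return torch.stack(mesh, dim=-1).unsqueeze(0)


grid = build_grid_cache((4, 3))  # shape [1, 7, 5, 2], dtype float32
```

Keeping the step size fixed means that extending the cache only appends points outside the current range, so coordinates already handed out do not change.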
- __init__(data_dim, embedding_dim, L_cache, omega_0, use_bias=True)
Initialize the RandomFourierPositionalEmbeddingND.
- Parameters:
data_dim (int) – Dimension of input data.
embedding_dim (int) – Dimensionality of the positional embedding. Must be even.
L_cache (int | Sequence[int]) – Number of cached time steps per axis. Either a scalar int (broadcast to every axis, isotropic grid) or a sequence of length data_dim (one extent per spatial axis, anisotropic grid). The cached grid then has shape (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim).
omega_0 (float) – Frequency scaling factor for the Fourier features.
use_bias (bool) – Whether to use a bias term in the linear layer.
- Raises:
ValueError – If embedding_dim is not an even number.
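The argument handling described above (scalar or per-axis L_cache, even embedding_dim) can be sketched in plain Python. All three helper names are hypothetical, and the error raised for a sequence of the wrong length is an assumption; only the ValueError for an odd embedding_dim is documented.

```python
from collections.abc import Sequence


def normalize_l_cache(L_cache, data_dim):
    """Hypothetical helper: return one cache extent per axis as a tuple."""
    if isinstance(L_cache, int):
        return (L_cache,) * data_dim  # isotropic: broadcast the scalar
    if isinstance(L_cache, Sequence) and len(L_cache) == data_dim:
        return tuple(int(L) for L in L_cache)  # anisotropic: one per axis
    # Assumption: the real module may report this case differently.
    raise ValueError(f"L_cache must be an int or a sequence of length {data_dim}")


def check_embedding_dim(embedding_dim):
    """Mirror the documented ValueError for odd embedding sizes."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim must be even, got {embedding_dim}")


def cached_grid_shape(L_cache, data_dim):
    """Shape (1, 2*L_0 - 1, ..., 2*L_{d-1} - 1, data_dim) of the cached grid."""
    extents = normalize_l_cache(L_cache, data_dim)
    return (1, *(2 * L - 1 for L in extents), data_dim)
```

For instance, cached_grid_shape(4, 2) describes an isotropic 2D grid, while cached_grid_shape([4, 3], 2) describes an anisotropic one.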
- forward(seq_lens)#
Compute the RFF positional embeddings for a given spatial grid.
- Parameters:
seq_lens (tuple[int, ...]) – Per-axis output sequence lengths. Length must equal self.data_dim. For example, for a 2D signal of height H and width W, pass (H, W).
- Returns:
torch.Tensor: The positional embeddings, [cos(Wx+b), sin(Wx+b)] concatenated along the last axis. Shape [1, *spatial_dims, embedding_dim].
torch.Tensor: The coordinate grid of positions normalised to [-1, 1] per axis. Shape [1, *spatial_dims, data_dim].
- Return type:
- Raises:
AssertionError – If len(seq_lens) != self.data_dim.
AssertionError – If self.grid_cache is not float32.
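The per-axis slicing that forward performs on the cached grid can be sketched with index arithmetic alone. The helper central_slices is hypothetical; it assumes the zero coordinate sits at index L_i - 1 on an axis of 2*L_i - 1 cached points, and it raises an error where the documented module would instead extend the cache.

```python
def central_slices(cache_extents, seq_lens):
    """Hypothetical helper: slices selecting the central 2*s - 1 points per axis."""
    assert len(seq_lens) == len(cache_extents), "one sequence length per axis"
    slices = []
    for L, s in zip(cache_extents, seq_lens):
        if s > L:
            # The documented module grows the cache here; this sketch does not.
            raise ValueError("requested length exceeds the cached extent")
        centre = L - 1  # index of coordinate 0 on an axis of 2*L - 1 points
        slices.append(slice(centre - (s - 1), centre + s))
    return tuple(slices)


slices = central_slices((8, 6), (3, 6))
```

Under these assumptions, a cached tensor of shape [1, ..., data_dim] would be indexed with (slice(None), *slices, slice(None)) to obtain the coordinates for the requested lengths.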