PositionEmbeddingND

class PositionEmbeddingND(embedding_dim, data_dim, max_dim_lengths)

Bases: Module

Axis-factorised learned positional encoding for ND spatial token grids.

Each spatial axis has its own nn.Embedding table of shape (max_dim_lengths[d], embedding_dim // data_dim). For a grid position (i_0, i_1, ..., i_{D-1}) the encoding is formed by looking up the per-axis embeddings and concatenating them channel-wise:

PE(i_0, ..., i_{D-1}) = concat(E_0(i_0), E_1(i_1), ..., E_{D-1}(i_{D-1}))

where D = data_dim, concat denotes concatenation along the channel dimension, and each E_d ∈ R^{max_dim_lengths[d] × (embedding_dim // data_dim)} is a learned embedding matrix. The result has length embedding_dim, which requires embedding_dim % data_dim == 0.
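A minimal NumPy sketch of the formula at a single grid position follows. NumPy arrays stand in for the nn.Embedding weights, and the sizes are illustrative assumptions; each axis contributes one row of its own table, and the rows are concatenated:

```python
import numpy as np

# Hypothetical sizes for illustration: a 2-D grid (data_dim = 2).
embedding_dim, data_dim = 8, 2
max_dim_lengths = (4, 5)
per_dim = embedding_dim // data_dim  # 4 channels per axis

rng = np.random.default_rng(0)
# One table per axis; random values stand in for learned weights.
tables = [rng.standard_normal((n, per_dim)) for n in max_dim_lengths]

def pe_at(position):
    """PE(i_0, ..., i_{D-1}) = concat(E_0(i_0), ..., E_{D-1}(i_{D-1}))."""
    return np.concatenate([tables[d][i] for d, i in enumerate(position)])

v = pe_at((2, 3))
print(v.shape)  # (8,)
```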

This factorised form is a separable positional encoding: with per_dim = embedding_dim // data_dim, all position information for axis d lives in the channel slice [d * per_dim, (d+1) * per_dim), and that slice depends only on i_d. It is not a joint encoding of the full ND coordinate (unlike, for example, a single learned table with one entry per grid position), so cross-axis interactions must be learned by the mixer layers themselves.
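The separability property can be checked directly on a small 2-D grid. This is a NumPy sketch with random stand-in tables E0 and E1 and hypothetical sizes, not the module's own code:

```python
import numpy as np

# Hypothetical 2-D example: H x W grid, per-axis width per_dim.
H, W, per_dim = 3, 4, 2
rng = np.random.default_rng(1)
E0 = rng.standard_normal((H, per_dim))  # stand-in table for axis 0
E1 = rng.standard_normal((W, per_dim))  # stand-in table for axis 1

# Full encoding grid of shape [H, W, 2 * per_dim], built position by position.
pe = np.empty((H, W, 2 * per_dim))
for i in range(H):
    for j in range(W):
        pe[i, j] = np.concatenate([E0[i], E1[j]])

# Channels [0, per_dim) depend only on i: identical across every column j.
axis0_slice = pe[..., :per_dim]
print(np.allclose(axis0_slice, axis0_slice[:, :1, :]))  # True
# Channels [per_dim, 2 * per_dim) depend only on j: identical across every row i.
axis1_slice = pe[..., per_dim:]
print(np.allclose(axis1_slice, axis1_slice[:1, :, :]))  # True
```

Because no channel depends on (i, j) jointly, any interaction between the two coordinates has to come from later layers.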

Output shape:

forward(x) -> Tensor of shape [B, *spatial_dims, embedding_dim]

The returned tensor is a broadcast-expanded encoding grid with the same shape as the input x, and is typically added to x before the first mixer block:

x = x + position_embedding(x)

Parameter count: data_dim embedding tables, each of size max_dim_lengths[d] × (embedding_dim // data_dim), giving:

total = sum(max_dim_lengths[d] * (embedding_dim // data_dim)
            for d in range(data_dim))
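One way to see what the factorisation saves is to compare the count above with a single joint table holding one full-width row per grid position. Both helpers below are hypothetical, and the sizes (a 32 × 32 grid with embedding_dim = 384) are only an example:

```python
def factorised_param_count(embedding_dim, data_dim, max_dim_lengths):
    # Mirrors the formula above: one table per axis, each per_dim channels wide.
    if embedding_dim % data_dim != 0:
        raise ValueError("embedding_dim must be divisible by data_dim")
    per_dim = embedding_dim // data_dim
    return sum(max_dim_lengths[d] * per_dim for d in range(data_dim))

def joint_param_count(embedding_dim, max_dim_lengths):
    # For comparison: one learned full-width row for every grid position.
    n_positions = 1
    for n in max_dim_lengths:
        n_positions *= n
    return n_positions * embedding_dim

print(factorised_param_count(384, 2, (32, 32)))  # 12288
print(joint_param_count(384, (32, 32)))          # 393216
```

The factorised count grows with the sum of the axis lengths, while a joint table grows with their product.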

No weight decay: all embedding parameters are tagged with param._no_weight_decay = True, so that an optimiser builder that checks this attribute when splitting parameters into groups can exclude them from L2 regularisation (weight decay). This follows the common ViT practice of not decaying position embeddings.
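A sketch of how an optimiser builder might honour the tag follows. The helper split_param_groups is hypothetical; with PyTorch you would pass module.named_parameters() and hand the returned groups to an optimiser such as torch.optim.AdamW, which accepts a list of per-group option dicts. SimpleNamespace objects stand in for parameters so the sketch runs on its own:

```python
from types import SimpleNamespace

def split_param_groups(named_params, weight_decay):
    """Split (name, param) pairs into decay / no-decay optimiser groups.

    Any parameter carrying a truthy `_no_weight_decay` attribute goes into
    the group with weight_decay=0.0; everything else gets `weight_decay`.
    """
    decay, no_decay = [], []
    for _name, p in named_params:
        if getattr(p, "_no_weight_decay", False):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Stand-ins for parameters so the sketch runs without PyTorch.
w = SimpleNamespace()                        # ordinary weight
pos = SimpleNamespace(_no_weight_decay=True)  # tagged position-embedding weight
groups = split_param_groups([("mixer.weight", w), ("pos.x.weight", pos)], 0.05)
print(len(groups[0]["params"]), len(groups[1]["params"]))  # 1 1
```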

Parameters:
  • embedding_dim (int)

  • data_dim (int)

  • max_dim_lengths (Sequence[int])

embedding_dim

Total embedding dimension of the output encoding.

Type:

int

per_dim_embedding_dim

Per-axis slice width, embedding_dim // data_dim.

Type:

int

data_dim

Number of spatial axes (1, 2, or 3).

Type:

int

max_dim_lengths

Maximum supported grid size for each spatial axis. Length equals data_dim.

Type:

tuple[int, …]

data_embeddings

ModuleDict mapping axis keys to the corresponding nn.Embedding modules: {"x"} when data_dim = 1, {"x", "y"} when data_dim = 2, and {"x", "y", "z"} when data_dim = 3.

Type:

nn.ModuleDict

__init__(embedding_dim, data_dim, max_dim_lengths)

Initialise per-axis embedding tables.

Parameters:
  • embedding_dim (int) – Total number of channels in the output position encoding. Must be divisible by data_dim so that the budget can be split evenly across axes.

  • data_dim (int) – Number of spatial axes of the input token grid. Must satisfy 1 <= data_dim <= 3.

  • max_dim_lengths (Sequence[int]) – Maximum grid size for each spatial axis. Must have exactly data_dim entries. An nn.Embedding table of length max_dim_lengths[d] is allocated for axis d; inputs whose spatial size along axis d exceeds this value will raise a ValueError in forward.

Raises:

forward(x)

Compute the position encoding grid for the given input token grid.

For each spatial axis d, position indices 0, 1, ..., L_d - 1 are looked up in the corresponding nn.Embedding table to produce a 1-D embedding slice of shape [L_d, per_dim_embedding_dim]. That slice is broadcast-expanded to the full spatial shape [B, *spatial_dims, per_dim_embedding_dim] and the per-axis tensors are concatenated along the last dimension to give the final encoding of shape [B, *spatial_dims, embedding_dim].
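The lookup, broadcast and concatenation steps can be sketched in NumPy as below. This is an illustrative reimplementation that assumes plain arrays stand in for the nn.Embedding tables (row indexing replaces the embedding lookup); it omits the input checks listed under Raises:

```python
import numpy as np

def position_encoding_nd(x, tables):
    """NumPy sketch of the forward pass (no input validation).

    x:      array of shape [B, *spatial_dims, C], channels-last
    tables: data_dim arrays; tables[d] has shape [max_dim_lengths[d], per_dim]
    """
    batch, *spatial_dims = x.shape[:-1]
    per_axis = []
    for d, length in enumerate(spatial_dims):
        emb = tables[d][np.arange(length)]          # rows 0..L_d-1 -> [L_d, per_dim]
        per_dim = emb.shape[-1]
        view = [1] * (1 + len(spatial_dims)) + [per_dim]
        view[1 + d] = length                        # singleton everywhere except axis d
        full = (batch, *spatial_dims, per_dim)
        per_axis.append(np.broadcast_to(emb.reshape(view), full))
    return np.concatenate(per_axis, axis=-1)        # [B, *spatial_dims, C]

rng = np.random.default_rng(0)
tables = [rng.standard_normal((8, 4)), rng.standard_normal((6, 4))]  # data_dim = 2, C = 8
x = np.zeros((2, 5, 3, 8))                                           # [B, H, W, C]
pe = position_encoding_nd(x, tables)
print(pe.shape)  # (2, 5, 3, 8)
x = x + pe
```

np.broadcast_to returns read-only views, so no per-axis copy is made until np.concatenate builds the output, which loosely parallels the expand-then-concatenate description above.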

The returned tensor has the same shape as x and should be added to x:

x = x + position_embedding(x)
Parameters:

x (Tensor) –

Input token-grid tensor in channels-last layout. Shape: [B, *spatial_dims, embedding_dim], where the number of spatial axes must equal data_dim and the last dimension must equal embedding_dim. For example:

  • 1-D (sequences): [B, L, C]

  • 2-D (images): [B, H, W, C]

  • 3-D (volumes): [B, D, H, W, C]

Returns:

Position encoding tensor of shape [B, *spatial_dims, embedding_dim] — the same shape as x. Each spatial location (i_0, ..., i_{D-1}) holds the concatenation of the per-axis embedding lookups:

out[b, i_0, ..., i_{D-1}, :] =
    concat(E_0(i_0), E_1(i_1), ..., E_{D-1}(i_{D-1}))

Internally, for each axis d the 1-D embedding of shape [L_d, per_dim_embedding_dim] is reshaped to [1, ..., L_d, ..., 1, per_dim_embedding_dim] (singleton in all axes except axis d) and then broadcast-expanded to [B, *spatial_dims, per_dim_embedding_dim] before the per-axis tensors are concatenated along the last dimension.

Return type:

Tensor

Note

The returned tensor has the same dtype as the embedding weight parameters (typically torch.float32 regardless of the input dtype). In mixed-precision training cast the result before adding it to the token grid:

x = x + pos_enc(x).to(x.dtype)
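The upcast the note warns about can be reproduced with NumPy as a stand-in. PyTorch's type promotion likewise returns float32 when a float16 tensor is added to a float32 tensor:

```python
import numpy as np

x = np.zeros((1, 2, 2, 4), dtype=np.float16)  # half-precision token grid
pe = np.ones((1, 2, 2, 4), dtype=np.float32)  # encoding in the weights' dtype

print((x + pe).dtype)                  # float32: the sum is silently upcast
print((x + pe.astype(x.dtype)).dtype)  # float16: cast first, as the note suggests
```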
Raises:
  • ValueError – If x.ndim != data_dim + 2 (wrong number of dimensions — expected batch + data_dim spatial + channel).

  • ValueError – If x.shape[-1] != self.embedding_dim (channel count of the input does not match the embedding_dim passed at construction time).

  • ValueError – If any spatial dimension of x exceeds the corresponding entry of max_dim_lengths.
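The three checks above can be mirrored by a small shape validator. The function name, the order of the checks and the error messages are assumptions for illustration, not the module's own code:

```python
def check_forward_input(shape, data_dim, embedding_dim, max_dim_lengths):
    """Mirror the three documented ValueError checks on an input shape."""
    if len(shape) != data_dim + 2:
        raise ValueError(
            f"expected {data_dim + 2} dims (batch + {data_dim} spatial + channel), got {len(shape)}"
        )
    if shape[-1] != embedding_dim:
        raise ValueError(f"channel dim {shape[-1]} != embedding_dim {embedding_dim}")
    for d, (size, limit) in enumerate(zip(shape[1:-1], max_dim_lengths)):
        if size > limit:
            raise ValueError(f"spatial axis {d} has size {size} > max_dim_lengths[{d}] = {limit}")

# A valid [B, H, W, C] shape passes silently.
check_forward_input((2, 16, 16, 384), data_dim=2, embedding_dim=384, max_dim_lengths=(32, 32))
try:
    check_forward_input((2, 40, 16, 384), data_dim=2, embedding_dim=384, max_dim_lengths=(32, 32))
except ValueError as err:
    print(err)  # spatial axis 0 has size 40 > max_dim_lengths[0] = 32
```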
