PositionEmbeddingND
- class PositionEmbeddingND(embedding_dim, data_dim, max_dim_lengths)
Bases: Module
Axis-factorised learned positional encoding for ND spatial token grids.
Each spatial axis has its own nn.Embedding table of shape (max_dim_lengths[d], embedding_dim // data_dim). For a grid position (i_0, i_1, ..., i_{D-1}) the encoding is formed by looking up the per-axis embeddings and concatenating them channel-wise:
PE(i_0, ..., i_{D-1}) = concat(E_0(i_0), E_1(i_1), ..., E_{D-1}(i_{D-1}))
where concat denotes concatenation along the channel dimension and each E_d ∈ R^{max_dim_lengths[d] × (embedding_dim // data_dim)} is a learned embedding matrix. The result has length embedding_dim (requires embedding_dim % data_dim == 0).
This factorised form is a separable positional encoding: the position information for axis d is captured entirely in the channel slice [d * per_dim, (d+1) * per_dim), where per_dim = embedding_dim // data_dim. It is not a joint encoding of the full ND coordinate (unlike 2D sinusoidal encodings that mix axes), so cross-axis interactions must be learned by the mixer layers themselves.
Output shape:
forward(x) -> Tensor of shape [B, *spatial_dims, embedding_dim]
The returned tensor is a broadcast-expanded encoding grid with the same shape as the input x, and is typically added to x before the first mixer block:
x = x + position_embedding(x)
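The channel-wise concatenation and the per-axis channel slices described above can be illustrated with a small sketch. This is not the class's own code: the sizes, the `tables` list, and the `encode_position` helper are hypothetical names chosen for illustration, and only standard `torch.nn.Embedding` behaviour is assumed.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
embedding_dim, data_dim = 8, 2
max_dim_lengths = (5, 7)
per_dim = embedding_dim // data_dim  # channels allotted to each axis

# One learned table per axis: E_0 for axis 0, E_1 for axis 1.
tables = [nn.Embedding(n, per_dim) for n in max_dim_lengths]


def encode_position(index):
    """Return concat(E_0(i_0), ..., E_{D-1}(i_{D-1})) as a 1-D tensor."""
    parts = [table(torch.tensor(i)) for table, i in zip(tables, index)]
    return torch.cat(parts, dim=-1)


pe_a = encode_position((2, 3))
pe_b = encode_position((2, 6))  # same i_0, different i_1

# Channels [0, per_dim) depend only on i_0, so they agree between pe_a and pe_b;
# channels [per_dim, 2 * per_dim) depend only on i_1.
same_axis0 = torch.equal(pe_a[:per_dim], pe_b[:per_dim])
```

Because each axis owns a disjoint channel slice, changing one coordinate leaves the other slices untouched, which is the separability property noted above.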
Parameter count: data_dim embedding tables, each of size max_dim_lengths[d] × (embedding_dim // data_dim):
total = sum(max_dim_lengths[d] * (embedding_dim // data_dim) for d in range(data_dim))
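As a worked instance of the formula above, the following sketch evaluates it for one hypothetical configuration and cross-checks the result against freshly built `nn.Embedding` tables. The helper name `position_embedding_param_count` and the chosen sizes are assumptions for illustration, not part of the documented API.

```python
import torch.nn as nn


def position_embedding_param_count(embedding_dim, data_dim, max_dim_lengths):
    """Evaluate the documented parameter-count formula."""
    per_dim = embedding_dim // data_dim
    return sum(max_dim_lengths[d] * per_dim for d in range(data_dim))


# Hypothetical 3-D configuration.
embedding_dim, data_dim, max_dim_lengths = 96, 3, (8, 16, 16)
total = position_embedding_param_count(embedding_dim, data_dim, max_dim_lengths)
# per_dim = 32, so total = 32 * (8 + 16 + 16) = 1280

# Cross-check against stand-in tables of the documented shapes.
tables = [nn.Embedding(n, embedding_dim // data_dim) for n in max_dim_lengths]
actual = sum(p.numel() for t in tables for p in t.parameters())
```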
No weight decay: all embedding parameters are tagged param._no_weight_decay = True so that optimiser builders (e.g. those that use param._no_weight_decay to separate param groups) can exclude them from L2 regularisation, following standard ViT practice.
- max_dim_lengths
Maximum supported grid size for each spatial axis. Length equals data_dim.
- data_embeddings
Dictionary mapping axis keys {"x"} / {"x", "y"} / {"x", "y", "z"} (for data_dim = 1, 2, 3 respectively) to the corresponding nn.Embedding modules.
- Type: nn.ModuleDict
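The `_no_weight_decay` tag mentioned in the class description could be consumed by an optimiser builder along the following lines. This is a sketch under assumptions: `split_param_groups` is a hypothetical helper, and the stand-in model tags its embedding by hand to mimic what the class is documented to do.

```python
import torch
import torch.nn as nn


def split_param_groups(model, weight_decay):
    """Hypothetical builder: route tagged parameters to a zero-decay group."""
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if getattr(param, "_no_weight_decay", False):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Stand-in model: one embedding (tagged manually here) plus a linear layer.
model = nn.ModuleDict({"pos": nn.Embedding(16, 8), "proj": nn.Linear(8, 8)})
for p in model["pos"].parameters():
    p._no_weight_decay = True

groups = split_param_groups(model, weight_decay=0.05)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```

Using `getattr` with a default of `False` keeps the builder usable with modules that never set the attribute.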
- __init__(embedding_dim, data_dim, max_dim_lengths)
Initialise per-axis embedding tables.
- Parameters:
  - embedding_dim (int) – Total number of channels in the output position encoding. Must be divisible by data_dim so that the channel budget can be split evenly across axes.
  - data_dim (int) – Number of spatial axes of the input token grid. Must satisfy 1 <= data_dim <= 3.
  - max_dim_lengths (Sequence[int]) – Maximum grid size for each spatial axis. Must have exactly data_dim entries. An nn.Embedding table of length max_dim_lengths[d] is allocated for axis d; inputs whose spatial size along axis d exceeds this value will raise a ValueError in forward.
- Raises:
  - ValueError – If data_dim < 1.
  - ValueError – If data_dim > 3.
  - ValueError – If len(max_dim_lengths) != data_dim.
  - ValueError – If embedding_dim % data_dim != 0.
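The four constructor checks listed above can be mirrored in a standalone helper, for example to validate a configuration before building a model. The function name `check_constructor_args`, the error messages, and the order of the checks are assumptions; only the conditions themselves come from the list above.

```python
def check_constructor_args(embedding_dim, data_dim, max_dim_lengths):
    """Hypothetical mirror of the documented checks; returns channels per axis."""
    if data_dim < 1:
        raise ValueError(f"data_dim must be >= 1, got {data_dim}")
    if data_dim > 3:
        raise ValueError(f"data_dim must be <= 3, got {data_dim}")
    if len(max_dim_lengths) != data_dim:
        raise ValueError(
            f"expected {data_dim} entries in max_dim_lengths, "
            f"got {len(max_dim_lengths)}"
        )
    if embedding_dim % data_dim != 0:
        raise ValueError(
            f"embedding_dim={embedding_dim} is not divisible by data_dim={data_dim}"
        )
    return embedding_dim // data_dim


per_dim = check_constructor_args(96, 3, (8, 16, 16))  # 96 // 3 = 32

# An embedding_dim that does not split evenly across 3 axes is rejected.
try:
    check_constructor_args(100, 3, (8, 16, 16))
    rejected = False
except ValueError:
    rejected = True
```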
- forward(x)
Compute the position encoding grid for the given input token grid.
For each spatial axis d, position indices 0, 1, ..., L_d - 1 are looked up in the corresponding nn.Embedding table to produce a 1-D embedding slice of shape [L_d, per_dim_embedding_dim]. That slice is broadcast-expanded to the full spatial shape [B, *spatial_dims, per_dim_embedding_dim], and the per-axis tensors are concatenated along the last dimension to give the final encoding of shape [B, *spatial_dims, embedding_dim].
The returned tensor has the same shape as x and should be added to x:
x = x + position_embedding(x)
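The lookup, broadcast, and concatenate steps described above can be sketched as a free function. This is an illustrative reconstruction, not the class's actual forward: `position_grid` and the plain list of tables are hypothetical, and input validation is omitted.

```python
import torch
import torch.nn as nn


def position_grid(tables, batch_size, spatial_dims):
    """Build a [B, *spatial_dims, C] grid from one nn.Embedding per axis."""
    data_dim = len(spatial_dims)
    per_axis = []
    for d, (table, length) in enumerate(zip(tables, spatial_dims)):
        idx = torch.arange(length, device=table.weight.device)
        emb = table(idx)  # [L_d, per_dim]
        # Singleton everywhere except spatial axis d and the channel axis.
        shape = [1] * (data_dim + 2)
        shape[d + 1] = length
        shape[-1] = emb.shape[-1]
        emb = emb.reshape(shape)
        # Broadcast to the full grid without copying data.
        per_axis.append(emb.expand(batch_size, *spatial_dims, emb.shape[-1]))
    return torch.cat(per_axis, dim=-1)


# Hypothetical 2-D case: tables sized for up to 5 x 7, input grid of 3 x 6.
tables = [nn.Embedding(n, 4) for n in (5, 7)]
x = torch.zeros(2, 3, 6, 8)
pos = position_grid(tables, x.shape[0], tuple(x.shape[1:-1]))
x = x + pos
```

Only the first L_d rows of each table are read, so a grid smaller than max_dim_lengths uses a prefix of each table.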
- Parameters:
  - x (Tensor) – Input token-grid tensor in channels-last layout. Shape: [B, *spatial_dims, embedding_dim], where the number of spatial axes must equal data_dim and the last dimension must equal embedding_dim. For example:
    - 1-D (sequences): [B, L, C]
    - 2-D (images): [B, H, W, C]
    - 3-D (volumes): [B, D, H, W, C]
- Returns:
Position encoding tensor of shape [B, *spatial_dims, embedding_dim], the same shape as x. Each spatial location (i_0, ..., i_{D-1}) holds the concatenation of the per-axis embedding lookups:
out[b, i_0, ..., i_{D-1}, :] = concat(E_0(i_0), E_1(i_1), ..., E_{D-1}(i_{D-1}))
Internally, for each axis d the 1-D embedding of shape [L_d, per_dim_embedding_dim] is reshaped to [1, ..., L_d, ..., 1, per_dim_embedding_dim] (singleton in all axes except axis d) and then broadcast-expanded to [B, *spatial_dims, per_dim_embedding_dim] before the per-axis tensors are concatenated along the last dimension.
- Return type: Tensor
Note
The returned tensor has the same dtype as the embedding weight parameters (typically torch.float32, regardless of the input dtype). In mixed-precision training, cast the result before adding it to the token grid:
x = x + position_embedding(x).to(x.dtype)
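The dtype behaviour in the note can be reproduced with a plain `nn.Embedding` standing in for the position-embedding module. This is a sketch under assumptions: the 1-D table and the bfloat16 input are hypothetical, and the example relies only on PyTorch's standard type promotion rules.

```python
import torch
import torch.nn as nn

table = nn.Embedding(10, 8)  # weights are float32 by default
x = torch.zeros(2, 10, 8, dtype=torch.bfloat16)  # reduced-precision token grid

# Stand-in for the returned encoding: float32, same shape as x.
pos = table(torch.arange(10)).unsqueeze(0).expand_as(x)

uncast = x + pos              # type promotion yields float32
out = x + pos.to(x.dtype)     # explicit cast keeps the token grid in bfloat16
```

Without the cast, the sum is promoted to the wider dtype, which changes the dtype of the activations passed to later layers.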
- Raises:
  - ValueError – If x.ndim != data_dim + 2 (wrong number of dimensions; expected batch + data_dim spatial + channel).
  - ValueError – If x.shape[-1] != self.embedding_dim (the channel count of the input does not match the embedding_dim passed at construction time).
  - ValueError – If any spatial dimension of x exceeds the corresponding entry of max_dim_lengths.