Unpatchify#

class Unpatchify(
in_features,
out_features,
data_dim,
patch_size,
stride=None,
bias=True,
weight_init='default',
)#

Bases: Module

Inverse patch-embedding layer: reconstruct spatial signal from token grid.

Unpatchify is the trainable inverse of Patchify. Given a grid of token embeddings at patch resolution, it reconstructs a signal at the original spatial resolution using a transposed convolution (ConvTranspose{data_dim}d).

For non-overlapping patches (stride == patch_size), the transposed convolution tiles the output without overlap: each output pixel is produced by exactly one input token, so the round-trip with Patchify preserves spatial alignment (the values themselves still depend on the learned weights). When stride < patch_size (overlapping patches), the transposed convolution sums the contributions of every patch that covers a pixel. This is the linear adjoint (backward map) of the overlapping-patch forward pass, not a true inverse. Because pixel values are accumulated rather than averaged, Unpatchify(Patchify(x)) does not recover x exactly for overlapping patches; the output is a blurred, scaled version of x.
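The accumulation behaviour can be illustrated without PyTorch. The following is a minimal plain-Python sketch of a 1D transposed convolution with an all-ones kernel, a single channel, and no bias. The helper name transposed_conv1d_ones is hypothetical and not part of this library; it only shows how many patches contribute to each output position.

```python
def transposed_conv1d_ones(tokens, patch_size, stride):
    """Scatter-add each token into a window of `patch_size` outputs.

    Illustrative stand-in for a 1D transposed convolution with an
    all-ones kernel, one input channel, one output channel, no bias.
    """
    out_len = (len(tokens) - 1) * stride + patch_size
    out = [0.0] * out_len
    for i, value in enumerate(tokens):
        start = i * stride
        for k in range(patch_size):
            out[start + k] += value  # overlapping windows are summed
    return out


# stride == patch_size: every output position is written exactly once.
non_overlapping = transposed_conv1d_ones([1.0, 1.0, 1.0], patch_size=4, stride=4)

# stride < patch_size: interior positions receive two contributions.
overlapping = transposed_conv1d_ones([1.0, 1.0, 1.0], patch_size=4, stride=2)

print(non_overlapping)
print(overlapping)
```

With overlapping patches the interior values are doubled while the borders are not, which is the "scaled" part of the blurred, scaled reconstruction described above.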

Output shape formula (each spatial axis s of the patch-grid input):

out_s = (s - 1) * stride - 2 * padding + kernel_size
      = (s - 1) * stride + patch_size          (since padding == 0 and kernel_size == patch_size)

For the non-overlapping case this gives s * patch_size.
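As a quick check of the formula, here is a small sketch that evaluates it for one axis. The function name unpatchify_output_size is an assumption made for illustration; it mirrors the arithmetic above and does not call the layer itself.

```python
def unpatchify_output_size(grid_size, patch_size, stride=None):
    """Output length along one axis for a patch grid of `grid_size` tokens."""
    if stride is None:
        stride = patch_size  # non-overlapping default
    padding = 0
    kernel_size = patch_size
    return (grid_size - 1) * stride - 2 * padding + kernel_size


print(unpatchify_output_size(14, patch_size=16))            # 224 == 14 * 16
print(unpatchify_output_size(14, patch_size=16, stride=8))  # 120
```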

Layout convention — inputs and outputs use channels-last ordering:

input  : [B, *patch_grid, C_embed]    (e.g. [B, H/P, W/P, C_embed])
output : [B, *spatial_dims, C_out]    (e.g. [B, H, W, C_out])
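The layout handling can be sketched with a bare torch.nn.ConvTranspose2d. This is an illustration of the channels-last convention under assumed sizes, not the layer's actual implementation: the tensor is moved to channels-first for the convolution and back to channels-last afterwards.

```python
import torch
from torch import nn

# Assumed sizes, chosen only for illustration.
B, C_embed, C_out, P = 2, 64, 3, 4
tokens = torch.randn(B, 8, 8, C_embed)  # [B, H/P, W/P, C_embed]

deconv = nn.ConvTranspose2d(C_embed, C_out, kernel_size=P, stride=P)

# PyTorch convolutions expect channels-first input: [B, C, H, W].
x = tokens.permute(0, 3, 1, 2).contiguous()
y = deconv(x)                 # [B, C_out, H, W]
out = y.permute(0, 2, 3, 1)   # back to channels-last: [B, H, W, C_out]

print(out.shape)  # torch.Size([2, 32, 32, 3])
```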

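Weight initialisation: PyTorch’s default kaiming_uniform for ConvTranspose computes its fan as out_features * patch_size ** data_dim, because the transposed-convolution weight is stored as [in_features, out_features, *kernel]. This underestimates the fan when in_features is much larger than out_features, so the initial output variance grows with the embedding dimension. weight_init="fan_in" corrects this by using the true fan-in in_features * patch_size ** data_dim.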
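To get a feel for the size of the mismatch, the two fan values can be compared directly. This is a sketch under the assumption that the initialisation bound scales as 1 / sqrt(fan), as Kaiming-style initialisers do; the helper conv_transpose_fans is hypothetical, and the exact gain used by weight_init="fan_in" is not stated in this page.

```python
import math


def conv_transpose_fans(in_features, out_features, patch_size, data_dim):
    """Return (fan used by the default init, corrected fan-in)."""
    receptive_field = patch_size ** data_dim
    default_fan = out_features * receptive_field
    corrected_fan_in = in_features * receptive_field
    return default_fan, corrected_fan_in


# Example sizes (assumed): ViT-like embedding back to RGB, 16x16 patches.
default_fan, corrected_fan_in = conv_transpose_fans(
    in_features=768, out_features=3, patch_size=16, data_dim=2
)

# If the bound scales as 1 / sqrt(fan), the default bound is larger by:
bound_ratio = math.sqrt(corrected_fan_in / default_fan)

print(default_fan, corrected_fan_in, bound_ratio)
```

The ratio of the two fans is in_features / out_features, so the gap widens as the embedding dimension grows relative to the number of output channels.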

Parameters:
  • in_features (int)

  • out_features (int)

  • data_dim (int)

  • patch_size (int)

  • stride (int | None)

  • bias (bool)

  • weight_init (Literal['default', 'zeros', 'fan_in'])

data_dim#

Spatial dimensionality (1, 2, or 3).

Type:

int

patch_size#

Kernel size of the transposed convolution.

Type:

int

stride#

Stride of the transposed convolution.

Type:

int

deconv#

The underlying deconvolution.

Type:

torch.nn.ConvTranspose{data_dim}d

__init__(
in_features,
out_features,
data_dim,
patch_size,
stride=None,
bias=True,
weight_init='default',
)#

Initialise the Unpatchify layer.

Parameters:
  • in_features (int) – Embedding dimension C_embed of each input token.

  • out_features (int) – Number of output channels C_out of the reconstructed signal (e.g. 3 for RGB images).

  • data_dim (int) – Spatial dimensionality. Must be 1, 2, or 3.

  • patch_size (int) – Side length P of each patch. The transposed convolution uses kernel_size = patch_size along every spatial axis.

  • stride (int | None) – Step between consecutive output patch origins. Defaults to patch_size (non-overlapping, exact inverse of Patchify with default stride). Must match the stride used in the paired Patchify layer to recover the original spatial resolution.

  • bias (bool) – If True (default), the deconvolution includes a learnable bias term.

  • weight_init (Literal['default', 'zeros', 'fan_in']) – Weight initialisation strategy for the deconv kernel. "default" uses PyTorch’s built-in kaiming_uniform, whose fan is computed from out_features and can cause output-variance blow-up for large in_features; it is retained primarily for loading pre-trained checkpoints whose weights were saved under PyTorch’s default init, so prefer "fan_in" for new architectures. "zeros" zero-initialises weights and bias (DiT-style), so the output is exactly zero at initialisation, which is safe for residual-stream entry. "fan_in" applies Kaiming-uniform with the corrected fan-in in_features * patch_size ** data_dim, giving output variance O(1) regardless of embedding dimension.

Raises:

ValueError – If data_dim is not 1, 2, or 3.

forward(x, output_spatial_shape=None)#

Reconstruct the spatial signal from a grid of patch-token embeddings.

Parameters:
  • x (Tensor) – Token-grid tensor in channels-last layout. Shape: [B, *patch_grid, C_embed], e.g. [B, H/P, W/P, C_embed] for 2D. The number of spatial dimensions must equal data_dim.

  • output_spatial_shape (Tuple[int, ...] | None) – Target spatial shape of the output. When stride > 1, several original spatial sizes map to the same patch-grid size, because the floor in the forward (Patchify) direction discards remainders. Pass output_spatial_shape to resolve this ambiguity and guarantee recovery of the exact original spatial dimensions. Must have length data_dim. When None, PyTorch infers the output size, which may not match the original spatial size if (spatial_dim - patch_size) % stride != 0 (for the default stride, if spatial_dim % patch_size != 0).
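The ambiguity that output_spatial_shape resolves can be seen from the size arithmetic alone. The sketch below assumes the paired Patchify behaves like an unpadded strided convolution, so its grid size follows the usual floor formula; both helper names are hypothetical.

```python
def patch_grid_size(spatial_size, patch_size, stride=None):
    """Tokens along one axis, assuming an unpadded strided convolution."""
    stride = patch_size if stride is None else stride
    return (spatial_size - patch_size) // stride + 1  # floor drops the remainder


def inferred_output_size(grid_size, patch_size, stride=None):
    """Size inferred along one axis when output_spatial_shape is None."""
    stride = patch_size if stride is None else stride
    return (grid_size - 1) * stride + patch_size


# Sixteen different input sizes collapse onto the same 14-token grid.
grids = {patch_grid_size(h, patch_size=16) for h in range(224, 240)}
print(grids)

# Without a hint, only the smallest of those sizes is reconstructed.
print(inferred_output_size(14, patch_size=16))
```

An input of size 230, for example, would come back as 224 unless the caller passes the original size explicitly.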

Returns:

Reconstructed signal tensor in channels-last layout. Shape: [B, *spatial_dims, C_out]. Without output_spatial_shape, each axis s of the patch grid expands to (s - 1) * stride + patch_size.

Return type:

Tensor

Raises:

AssertionError – If the rank of x does not equal data_dim + 2 (batch + spatial + channel dims).

Note

.contiguous() is called after the rearrangement to channels-first to avoid a stride-mismatch error in torch.compile’s convolution_backward.