Unpatchify
- class Unpatchify(in_features, out_features, data_dim, patch_size, stride=None, bias=True, weight_init='default')
Bases: Module

Inverse patch-embedding layer: reconstruct a spatial signal from a token grid.
Unpatchify is the trainable inverse of Patchify. Given a grid of token embeddings at patch resolution, it reconstructs a signal at the original spatial resolution using a transposed convolution (ConvTranspose{data_dim}d).

For non-overlapping patches (stride == patch_size), the default transposed convolution is an exact spatial inverse: each output pixel is produced by exactly one input token. When stride < patch_size (overlapping), the transposed convolution sums the contributions of every patch that covers a pixel. This is the linear adjoint (backward map) of the overlapping-patch forward pass, not a true inverse: pixel values are accumulated rather than averaged, so Unpatchify(Patchify(x)) does not recover x exactly for overlapping patches, and the output is in general a blurred, scaled version of x. Only for non-overlapping patches (stride == patch_size) does the round trip preserve spatial alignment (up to the learned weights).

Output shape formula (for each spatial axis s of the patch-grid input):

    out_s = (s - 1) * stride - 2 * padding + kernel_size
          = (s - 1) * stride + patch_size    (since padding == 0 and kernel_size == patch_size)

For the non-overlapping case (stride == patch_size) this gives s * patch_size.

Layout convention: inputs and outputs use channels-last ordering:

    input  : [B, *patch_grid, C_embed]   (e.g. [B, H/P, W/P, C_embed])
    output : [B, *spatial_dims, C_out]   (e.g. [B, H, W, C_out])
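To make the overlap-add behaviour and the shape formula concrete, here is a minimal pure-Python sketch of a single-channel 1D transposed convolution (padding 0, dilation 1) next to its matching strided forward convolution. The helpers conv_transpose_1d, conv_1d and unpatchify_output_shape are hypothetical illustrations, not part of this library; a real Unpatchify additionally mixes C_embed input channels into C_out output channels with learned weights.

```python
from typing import List, Optional, Sequence, Tuple


def conv_transpose_1d(tokens: Sequence[float], kernel: Sequence[float], stride: int) -> List[float]:
    """Single-channel 1D transposed convolution with padding 0 and dilation 1.

    Token i writes tokens[i] * kernel into out[i*stride : i*stride + P].
    Where patches overlap (stride < P) the writes are summed, never averaged.
    """
    p = len(kernel)
    out = [0.0] * ((len(tokens) - 1) * stride + p)  # out_s = (s - 1) * stride + P
    for i, t in enumerate(tokens):
        for k, w in enumerate(kernel):
            out[i * stride + k] += t * w
    return out


def conv_1d(signal: Sequence[float], kernel: Sequence[float], stride: int) -> List[float]:
    """Matching Patchify-style strided convolution (cross-correlation, no padding)."""
    p = len(kernel)
    n = (len(signal) - p) // stride + 1  # floor discards any remainder
    return [sum(signal[i * stride + k] * kernel[k] for k in range(p)) for i in range(n)]


def unpatchify_output_shape(
    input_shape: Tuple[int, ...], out_features: int, patch_size: int, stride: Optional[int] = None
) -> Tuple[int, ...]:
    """Channels-last shape bookkeeping: [B, *patch_grid, C_embed] -> [B, *spatial_dims, C_out]."""
    stride = patch_size if stride is None else stride
    batch, *grid, _c_embed = input_shape
    spatial = tuple((s - 1) * stride + patch_size for s in grid)
    return (batch, *spatial, out_features)


# Non-overlapping (stride == P): every output pixel comes from exactly one token.
print(conv_transpose_1d([1, 2, 3], [1, 1], stride=2))  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
# Overlapping (stride < P): interior pixels accumulate two contributions.
print(conv_transpose_1d([1, 2, 3], [1, 1], stride=1))  # [1.0, 3.0, 5.0, 3.0]
# Round trip of a constant signal with overlap: accumulated, not recovered.
print(conv_transpose_1d(conv_1d([1, 1, 1, 1], [1, 1], 1), [1, 1], 1))  # [2.0, 4.0, 4.0, 2.0]
# Adjoint identity <conv(x), y> == <x, conv_transpose(y)>.
x, y, w = [1, 2, 3, 4], [1, 0, 2], [1, 1]
lhs = sum(a * b for a, b in zip(conv_1d(x, w, 1), y))
rhs = sum(a * b for a, b in zip(x, conv_transpose_1d(y, w, 1)))
print(lhs, rhs)  # 17 17.0
print(unpatchify_output_shape((8, 14, 14, 768), out_features=3, patch_size=16))  # (8, 224, 224, 3)
```

The constant-signal round trip shows why the overlapping case is only an adjoint: edge pixels are covered once and interior pixels twice, so the result is scaled non-uniformly rather than restored.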
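Weight initialisation: PyTorch's default kaiming_uniform_ (with a = sqrt(5)) computes its fan from dimension 1 of the weight tensor, which for ConvTranspose is out_features, giving fan = out_features * patch_size ** data_dim. This is a poor choice for large embedding dimensions, because the output variance then grows with in_features. weight_init="fan_in" corrects this by using the fan-in in_features * patch_size ** data_dim instead.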
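The variance argument can be checked with a few lines of arithmetic. The sketch below reproduces the bound used by torch.nn.init.kaiming_uniform_ and predicts the variance of one output pixel for non-overlapping patches, assuming zero bias and i.i.d. zero-mean, unit-variance inputs. It also assumes, since the text does not say, that "fan_in" keeps PyTorch's a = sqrt(5) gain; the helper names are hypothetical.

```python
import math


def kaiming_uniform_bound(fan: int, a: float = math.sqrt(5)) -> float:
    """Bound b of U(-b, b) as in torch.nn.init.kaiming_uniform_ (leaky_relu gain):
    gain = sqrt(2 / (1 + a^2)), b = gain * sqrt(3 / fan).
    a = sqrt(5) is what PyTorch's conv layers pass in reset_parameters()."""
    gain = math.sqrt(2.0 / (1.0 + a * a))
    return gain * math.sqrt(3.0 / fan)


def predicted_output_variance(in_features: int, out_features: int, patch_size: int,
                              data_dim: int, mode: str) -> float:
    """Variance of one output pixel for non-overlapping patches: with stride == P,
    each pixel is a sum of in_features products x * w (one kernel tap per channel)."""
    taps = patch_size ** data_dim
    if mode == "default":
        fan = out_features * taps  # PyTorch reads weight.size(1), i.e. out_features
    elif mode == "fan_in":
        fan = in_features * taps   # corrected fan; the a = sqrt(5) gain is an assumption
    else:
        raise ValueError(f"unknown mode {mode!r}")
    var_w = kaiming_uniform_bound(fan) ** 2 / 3.0  # Var of U(-b, b) is b^2 / 3
    return in_features * var_w


for c_embed in (64, 1024):
    d = predicted_output_variance(c_embed, 3, patch_size=4, data_dim=2, mode="default")
    f = predicted_output_variance(c_embed, 3, patch_size=4, data_dim=2, mode="fan_in")
    print(c_embed, round(d, 4), round(f, 4))
```

Under these assumptions the "default" variance scales with in_features / out_features (16x larger when C_embed goes from 64 to 1024), while the "fan_in" variance does not depend on C_embed at all.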
- deconv
The underlying deconvolution.
- Type:
torch.nn.ConvTranspose{data_dim}d
- __init__(in_features, out_features, data_dim, patch_size, stride=None, bias=True, weight_init='default')
Initialise the Unpatchify layer.
- Parameters:
in_features (int) – Embedding dimension C_embed of each input token.
out_features (int) – Number of output channels C_out of the reconstructed signal (e.g. 3 for RGB images).
data_dim (int) – Spatial dimensionality. Must be 1, 2, or 3.
patch_size (int) – Side length P of each patch. The transposed convolution uses kernel_size = patch_size along every spatial axis.
stride (int | None) – Step between consecutive output patch origins. Defaults to patch_size (non-overlapping; the exact spatial inverse of a Patchify layer with its default stride). Must match the stride used in the paired Patchify layer to recover the original spatial resolution.
bias (bool) – If True (default), the deconvolution includes a learnable bias term.
weight_init (Literal['default', 'zeros', 'fan_in']) – Weight initialisation strategy for the deconv kernel:
"default" uses PyTorch's built-in kaiming_uniform (fan computed from out_features). This can cause output-variance blow-up for large in_features; it is retained primarily for loading pre-trained checkpoints whose weights were saved under PyTorch's default init. Prefer "fan_in" for new architectures.
"zeros" zero-initialises weights and bias (DiT-style; the output is exactly zero at initialisation, which is safe for entry into a residual stream).
"fan_in" applies Kaiming-uniform with the corrected fan-in in_features * patch_size ** data_dim, giving an output variance that is O(1) regardless of the embedding dimension.
- Raises:
ValueError – If data_dim is not 1, 2, or 3.
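The constructor arguments map naturally onto transposed-convolution hyperparameters. The sketch below is a hypothetical helper (resolve_unpatchify_config is not part of this library) that mirrors the documented defaults and the data_dim check, and reports the weight shape that weight_init acts on. The (in_channels, out_channels, *kernel_size) weight layout is PyTorch's convention for ConvTranspose with groups=1.

```python
from typing import Dict, Optional


def resolve_unpatchify_config(in_features: int, out_features: int, data_dim: int,
                              patch_size: int, stride: Optional[int] = None) -> Dict[str, object]:
    """Hypothetical helper: map Unpatchify arguments onto ConvTranspose{data_dim}d settings."""
    if data_dim not in (1, 2, 3):
        raise ValueError(f"data_dim must be 1, 2, or 3, got {data_dim}")
    stride = patch_size if stride is None else stride  # default: non-overlapping patches
    kernel = (patch_size,) * data_dim                   # kernel_size = patch_size on every axis
    return {
        "layer": f"ConvTranspose{data_dim}d",
        "kernel_size": kernel,
        "stride": (stride,) * data_dim,
        # ConvTranspose weights are laid out as (in_channels, out_channels, *kernel_size).
        "weight_shape": (in_features, out_features, *kernel),
        "overlapping": stride < patch_size,
    }


print(resolve_unpatchify_config(768, 3, data_dim=2, patch_size=16))
```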
- forward(x, output_spatial_shape=None)
Reconstruct the spatial signal from a grid of patch-token embeddings.
- Parameters:
x (Tensor) – Token-grid tensor in channels-last layout. Shape: [B, *patch_grid, C_embed], e.g. [B, H/P, W/P, C_embed] for 2D. The number of spatial dimensions must equal data_dim.
output_spatial_shape (Tuple[int, ...] | None) – When stride > 1, several original spatial sizes map to the same patch-grid size (the floor in the forward direction discards remainders), so the output size is ambiguous. Pass output_spatial_shape to resolve this ambiguity and recover the exact original spatial dimensions. Must have length data_dim. When None, PyTorch infers the smallest compatible size, (s - 1) * stride + patch_size, which does not match the original spatial size if (spatial_dim - patch_size) % stride != 0 (for the default stride, if spatial_dim % patch_size != 0).
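The ambiguity is easy to see by enumerating which original sizes collapse onto the same grid length. This stdlib-only sketch uses hypothetical helper names; it assumes dilation 1 and no padding, in which case PyTorch accepts only an extra output padding in [0, stride). PyTorch's ConvTranspose modules accept an output_size argument in forward() and derive that padding from it; that output_spatial_shape is forwarded there is an assumption.

```python
from typing import List


def patch_grid_len(spatial: int, patch_size: int, stride: int) -> int:
    """Patchify-style strided conv without padding: the floor drops any remainder."""
    return (spatial - patch_size) // stride + 1


def default_output_len(grid: int, patch_size: int, stride: int) -> int:
    """Size inferred when output_spatial_shape is None."""
    return (grid - 1) * stride + patch_size


def sizes_with_same_grid(grid: int, patch_size: int, stride: int) -> List[int]:
    """All original sizes that the forward pass collapses onto the same grid length."""
    lo = default_output_len(grid, patch_size, stride)
    return list(range(lo, lo + stride))


def output_padding_for(target: int, grid: int, patch_size: int, stride: int) -> int:
    """Extra trailing pixels needed to reach `target` (must satisfy 0 <= extra < stride)."""
    extra = target - default_output_len(grid, patch_size, stride)
    if not 0 <= extra < stride:
        raise ValueError(f"size {target} is not reachable from a grid of {grid}")
    return extra


H, P = 35, 8
g = patch_grid_len(H, P, stride=P)
print(g, default_output_len(g, P, P), sizes_with_same_grid(g, P, P))
print(output_padding_for(H, g, P, P))
```

With H = 35 and P = 8, the grid has 4 patches, the default reconstruction is 32 pixels long, and every H from 32 to 39 yields the same grid, so recovering 35 requires 3 extra pixels of output padding.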
- Returns:
Reconstructed signal tensor in channels-last layout. Shape [B, *spatial_dims, C_out]. Without output_spatial_shape, each axis s of the patch grid expands to (s - 1) * stride + patch_size.
- Return type:
Tensor
- Raises:
AssertionError – If the rank of x does not equal data_dim + 2 (batch + spatial + channel dims).
Note
.contiguous() is called after the rearrangement to channels-first to avoid a stride-mismatch error in torch.compile's convolution_backward.
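The channels-last to channels-first rearrangement and the rank check can be expressed as axis permutations. This is a stdlib-only sketch with hypothetical helper names that tracks shapes only; the library may implement the rearrangement differently. In PyTorch the equivalent call would be x.permute(*perm), which returns a non-contiguous view, and .contiguous() copies that view into a dense channels-first buffer as the note describes.

```python
from typing import Tuple


def check_rank(shape: Tuple[int, ...], data_dim: int) -> None:
    # Mirrors the documented AssertionError: batch + data_dim spatial axes + channel.
    assert len(shape) == data_dim + 2, f"expected rank {data_dim + 2}, got {len(shape)}"


def to_channels_first_perm(data_dim: int) -> Tuple[int, ...]:
    """[B, *spatial, C] -> [B, C, *spatial], the layout ConvTranspose{data_dim}d expects."""
    return (0, data_dim + 1, *range(1, data_dim + 1))


def to_channels_last_perm(data_dim: int) -> Tuple[int, ...]:
    """[B, C, *spatial] -> [B, *spatial, C], the documented output layout."""
    return (0, *range(2, data_dim + 2), 1)


def permute_shape(shape: Tuple[int, ...], perm: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape of x.permute(*perm) for a tensor x of the given shape."""
    return tuple(shape[i] for i in perm)


tokens = (8, 14, 14, 768)  # [B, H/P, W/P, C_embed]
check_rank(tokens, data_dim=2)
print(permute_shape(tokens, to_channels_first_perm(2)))                # (8, 768, 14, 14)
print(permute_shape((8, 3, 224, 224), to_channels_last_perm(2)))       # (8, 224, 224, 3)
```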