Patchify
- class Patchify(in_features, out_features, data_dim, patch_size, stride=None, bias=True)
Bases: Module
Conv-based patch embedding for ND spatial signals (channels-last I/O).
Splits the spatial axes of the input into a regular grid of patches (non-overlapping by default) and linearly projects each patch into an embedding vector. The operation is equivalent to:
1. Unfold every patch_size ** data_dim pixel neighbourhood into a vector of length C_in * patch_size ** data_dim.
2. Apply a learned linear map from that vector to C_out dimensions.
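The two-step description above can be checked numerically. The following is a minimal sketch for the 2D case, assuming channels-first tensors (the layout that `torch.nn.functional.unfold` and `conv2d` expect) and arbitrary illustrative sizes; it does not use the `Patchify` class itself.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the library).
B, C_in, C_out, P = 2, 3, 8, 4
H = W = 16

# Channels-first input, because F.unfold and F.conv2d expect [B, C, H, W].
x = torch.randn(B, C_in, H, W)

# One shared set of projection weights in Conv2d layout [C_out, C_in, P, P].
weight = torch.randn(C_out, C_in, P, P)
bias = torch.randn(C_out)

# Path A: a single strided convolution.
y_conv = F.conv2d(x, weight, bias, stride=P)  # [B, C_out, H/P, W/P]

# Path B, step 1: unfold each P x P neighbourhood into a vector of length C_in * P * P.
patches = F.unfold(x, kernel_size=P, stride=P)  # [B, C_in*P*P, (H/P)*(W/P)]
# Path B, step 2: apply the same weights as a plain linear map to every patch vector.
y_lin = F.linear(patches.transpose(1, 2), weight.reshape(C_out, -1), bias)
# Fold the token axis back into a patch grid so both results share one shape.
y_lin = y_lin.transpose(1, 2).reshape(B, C_out, H // P, W // P)

max_abs_diff = (y_conv - y_lin).abs().max().item()
```

The two paths should agree up to floating-point rounding, which is why a single strided convolution can stand in for the explicit unfold and projection.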
Because the unfold and linear projection can be fused into a single strided convolution, this class simply wraps torch.nn.Conv{data_dim}d with kernel_size = patch_size, stride = stride, and padding = 0.
Output shape formula (each spatial axis s independently):
    out_s = floor((s - patch_size) / stride) + 1
For the default non-overlapping case (stride == patch_size) this reduces to s // patch_size (assuming s is divisible by patch_size).
Warning
If spatial_dim % patch_size != 0, the last pixels in that axis are silently discarded (standard floor-division Conv semantics). Callers are responsible for ensuring spatial dimensions are divisible by patch_size before calling this layer (e.g. by padding the input).
Layout convention. Inputs and outputs use channels-last ordering:
    input  : [B, *spatial_dims, C_in]   (e.g. [B, H, W, C_in] for 2D)
    output : [B, *patch_grid, C_out]    (e.g. [B, H/P, W/P, C_out])
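The warning and the channels-last layout interact when padding the input: `torch.nn.functional.pad` reads its padding pairs starting from the last axis, which here is the channel axis. Below is a minimal sketch of one way a caller might pad before the layer. The helper name `pad_channels_last` is hypothetical, and zero-padding the trailing edge is an assumption; other padding modes or symmetric padding may suit a given model better.

```python
import torch
import torch.nn.functional as F


def pad_channels_last(x, patch_size):
    """Zero-pad the trailing edge of every spatial axis up to a multiple of patch_size.

    Hypothetical helper: expects channels-last input [B, *spatial_dims, C_in].
    """
    # F.pad reads (before, after) pairs starting from the LAST axis, so the
    # channel axis comes first and receives no padding.
    pad = [0, 0]
    for s in reversed(x.shape[1:-1]):
        pad += [0, (-s) % patch_size]
    return F.pad(x, pad)


x = torch.randn(2, 230, 100, 3)                 # H and W are not multiples of 16
x_padded = pad_channels_last(x, patch_size=16)  # [2, 240, 112, 3]
```

With this padding the final patch along each axis contains zeros in addition to real pixels, so no input is dropped, at the cost of a slightly different input distribution for edge patches.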
Internally, the Conv requires channels-first input, so the tensor is transposed to channels-first before the Conv and back to channels-last before returning. The channels-last output matches the layout expected by PositionEmbeddingND and the mixer blocks.
Overlapping patches. Setting stride < patch_size produces overlapping patches, with the output shape given by the same formula above. This is less common in ViT-style models but is supported.
Examples
1D sequence (data_dim=1):
    layer = Patchify(in_features=64, out_features=128, data_dim=1, patch_size=4)
    x = torch.randn(2, 256, 64)  # [B, L, C_in]
    y = layer(x)                 # [B, L/4, 128] == [2, 64, 128]
2D image (data_dim=2); see the __main__ block for a runnable demo:
    layer = Patchify(in_features=3, out_features=768, data_dim=2, patch_size=16)
    x = torch.randn(8, 224, 224, 3)  # [B, H, W, C_in]
    y = layer(x)                     # [B, 14, 14, 768]
3D volume (data_dim=3):
    layer = Patchify(in_features=1, out_features=256, data_dim=3, patch_size=8)
    x = torch.randn(2, 64, 64, 64, 1)  # [B, D, H, W, C_in]
    y = layer(x)                       # [B, 8, 8, 8, 256]
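The shapes quoted in the examples follow directly from the output shape formula, so they can be reproduced without PyTorch. The sketch below uses a hypothetical helper, `patch_grid_shape`, that is not part of the library; it only evaluates the documented formula per axis.

```python
def patch_grid_shape(spatial_dims, patch_size, stride=None):
    """Apply out_s = floor((s - patch_size) / stride) + 1 to every spatial axis."""
    stride = patch_size if stride is None else stride
    return tuple((s - patch_size) // stride + 1 for s in spatial_dims)


# Spatial sizes from the three examples above.
grid_1d = patch_grid_shape((256,), patch_size=4)        # (64,)
grid_2d = patch_grid_shape((224, 224), patch_size=16)   # (14, 14)
grid_3d = patch_grid_shape((64, 64, 64), patch_size=8)  # (8, 8, 8)

# Overlapping patches: stride < patch_size gives a denser grid.
grid_overlap = patch_grid_shape((224, 224), patch_size=16, stride=8)  # (27, 27)

# A 230-pixel axis is not divisible by 16: only 14 * 16 = 224 pixels are covered,
# so the last 6 pixels never reach any patch.
grid_ragged = patch_grid_shape((230,), patch_size=16)   # (14,)
dropped = 230 - grid_ragged[0] * 16                     # 6
```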
- Parameters:
- conv
The underlying strided convolution.
- Type: torch.nn.Conv{data_dim}d
- __init__(in_features, out_features, data_dim, patch_size, stride=None, bias=True)
Initialise the Patchify layer.
- Parameters:
  - in_features (int) – Number of input channels C_in (e.g. 3 for RGB).
  - out_features (int) – Embedding dimension C_out of each output token.
  - data_dim (int) – Spatial dimensionality of the input signal. Must be 1 (sequences), 2 (images), or 3 (volumes).
  - patch_size (int) – Side length P of each patch. The convolution uses kernel_size = patch_size along every spatial axis, so each patch covers P ** data_dim input pixels (P × P for 2D images, P × P × P voxels for 3D volumes).
  - stride (int | None) – Step size between consecutive patch origins along every spatial axis. Defaults to patch_size, giving non-overlapping ViT-style patches. Set to a smaller value for overlapping patches (denser token grids at the cost of more tokens).
  - bias (bool) – If True (default), the projection conv includes a learnable bias. Set to False for bias-free architectures (e.g. when a subsequent normalisation layer makes bias redundant).
- Raises:
  ValueError – If data_dim is not 1, 2, or 3.
- forward(x)
Embed the input tensor into a grid of patch tokens.
- Parameters:
  x (Tensor) – Input tensor in channels-last layout. Shape: [B, *spatial_dims, C_in], e.g. [B, H, W, C_in] for 2D images.
- Returns:
  Patch-embedded tensor in channels-last layout. Shape: [B, *patch_grid, C_out], where each spatial axis s is reduced to floor((s - patch_size) / stride) + 1 (equal to s // patch_size when stride == patch_size and s is divisible by patch_size).
- Return type:
  Tensor
Note
.contiguous() is called after the channels-last → channels-first rearrangement to avoid a stride-mismatch error in torch.compile's convolution_backward.
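Taken together, the documented behaviour can be illustrated with a short sketch. This is not the library's source: the class name `PatchifySketch`, the use of `movedim` for the layout change, and the error message text are all assumptions made for illustration. Only the constructor arguments, the Conv settings, the `ValueError`, and the `.contiguous()` call come from the description above.

```python
import torch
from torch import nn


class PatchifySketch(nn.Module):
    """Hypothetical re-implementation of the documented behaviour (not the library source)."""

    def __init__(self, in_features, out_features, data_dim, patch_size, stride=None, bias=True):
        super().__init__()
        conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
        if data_dim not in conv_classes:
            raise ValueError(f"data_dim must be 1, 2, or 3, got {data_dim}")
        self.conv = conv_classes[data_dim](
            in_features,
            out_features,
            kernel_size=patch_size,
            stride=patch_size if stride is None else stride,  # default: non-overlapping
            padding=0,
            bias=bias,
        )

    def forward(self, x):
        # [B, *spatial_dims, C_in] -> [B, C_in, *spatial_dims]; contiguous() as in the note.
        x = x.movedim(-1, 1).contiguous()
        x = self.conv(x)
        # [B, C_out, *patch_grid] -> [B, *patch_grid, C_out]
        return x.movedim(1, -1)


layer = PatchifySketch(in_features=3, out_features=768, data_dim=2, patch_size=16)
y = layer(torch.randn(2, 224, 224, 3))  # [2, 14, 14, 768]
```

Selecting the Conv class from a dictionary keyed by data_dim keeps the 1D, 2D, and 3D cases on one code path, and makes the unsupported-dimension check a single membership test.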