Patchify

class Patchify(
in_features,
out_features,
data_dim,
patch_size,
stride=None,
bias=True,
)

Bases: Module

Conv-based patch embedding for ND spatial signals (channels-last I/O).

Splits the spatial axes of the input into a regular grid of patches (non-overlapping by default) and linearly projects each patch into an embedding vector. The operation is equivalent to:

  1. Unfold every patch_size ** data_dim pixel neighbourhood into a vector of length C_in * patch_size ** data_dim.

  2. Apply a learned linear map from that vector to C_out dimensions.
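The two steps above can be checked numerically against a single strided convolution. This is a minimal sketch that calls torch.nn.functional directly rather than Patchify. It uses channels-first layout, a 2D input, and arbitrary small sizes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C_in, C_out, H, W, P = 2, 3, 5, 8, 8, 4

x = torch.randn(B, C_in, H, W)            # channels-first input for F.conv2d / F.unfold
weight = torch.randn(C_out, C_in, P, P)   # one P x P filter per output channel
bias = torch.randn(C_out)

# Path 1: a single strided convolution (kernel_size = stride = P, padding = 0).
conv_out = F.conv2d(x, weight, bias, stride=P)  # [B, C_out, H/P, W/P]

# Path 2, step 1: unfold each P x P patch into a vector of length C_in * P * P.
patches = F.unfold(x, kernel_size=P, stride=P)  # [B, C_in * P * P, num_patches]

# Path 2, step 2: apply the same weights as a linear map C_in * P * P -> C_out.
lin_out = weight.reshape(C_out, -1) @ patches + bias[:, None]  # [B, C_out, num_patches]
lin_out = lin_out.reshape(B, C_out, H // P, W // P)            # restore the patch grid

print(conv_out.shape)  # torch.Size([2, 5, 2, 2])
print(torch.allclose(conv_out, lin_out, atol=1e-4))  # the two paths agree up to float rounding
```

F.unfold flattens each patch in (channel, row, column) order, which matches weight.reshape(C_out, -1), so the same weights serve both paths.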

Because the unfold and linear projection can be fused into a single strided convolution, this class simply wraps torch.nn.Conv{data_dim}d with kernel_size = patch_size, stride = stride, and padding = 0.
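A minimal sketch of how the convolution could be selected and configured from data_dim. The helper name make_patch_conv and the dictionary lookup are illustrative assumptions, not necessarily how Patchify is written internally.

```python
import torch.nn as nn

# Hypothetical lookup: data_dim -> matching torch.nn convolution class.
_CONV_BY_DIM = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}


def make_patch_conv(in_features, out_features, data_dim, patch_size, stride=None, bias=True):
    """Build the strided convolution that a Patchify-style layer would wrap (hypothetical)."""
    if data_dim not in _CONV_BY_DIM:
        raise ValueError(f"data_dim must be 1, 2, or 3, got {data_dim}")
    stride = patch_size if stride is None else stride  # default: non-overlapping patches
    return _CONV_BY_DIM[data_dim](
        in_features,
        out_features,
        kernel_size=patch_size,
        stride=stride,
        padding=0,
        bias=bias,
    )


conv = make_patch_conv(1, 256, data_dim=3, patch_size=8)
print(type(conv).__name__, conv.kernel_size, conv.stride, conv.padding)
# Conv3d (8, 8, 8) (8, 8, 8) (0, 0, 0)
```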

Output shape formula (each spatial axis s independently):

out_s = floor((s - patch_size) / stride) + 1

For the default non-overlapping case (stride == patch_size) this reduces to s // patch_size for any s >= patch_size; when s is not divisible by patch_size, the remainder is dropped (see the warning below).
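The shape formula can be written as a small pure-Python function. patch_grid_size is a hypothetical name used only for illustration. The final assertion checks that, with the default stride, the formula equals s // patch_size.

```python
def patch_grid_size(s, patch_size, stride=None):
    """Number of patches along one spatial axis of length s (padding = 0)."""
    stride = patch_size if stride is None else stride
    if s < patch_size:
        raise ValueError("axis is shorter than one patch")
    return (s - patch_size) // stride + 1


print(patch_grid_size(224, 16))            # 14
print(patch_grid_size(230, 16))            # 14  (the 6 trailing pixels are dropped)
print(patch_grid_size(224, 16, stride=8))  # 27  (overlapping patches)

# With stride == patch_size the formula equals s // patch_size for every s >= patch_size.
assert all(patch_grid_size(s, 16) == s // 16 for s in range(16, 1000))
```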

Warning

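With the default stride, if spatial_dim % patch_size != 0, the trailing spatial_dim % patch_size pixels along that axis are silently discarded (standard floor-division Conv semantics). More generally, pixels are dropped whenever (spatial_dim - patch_size) % stride != 0. Callers are responsible for ensuring that spatial dimensions are divisible by patch_size before calling this layer (e.g. by padding the input).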
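One way to satisfy this requirement is to zero-pad the end of each spatial axis before calling the layer. The sketch below assumes that zero padding is acceptable for the data; other padding modes, or cropping, may be preferable. pad_to_multiple is a hypothetical helper and is not part of the documented API.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x, patch_size, data_dim):
    """Zero-pad the end of each spatial axis of a channels-last tensor
    [B, *spatial, C] so that every spatial size is divisible by patch_size."""
    spatial = x.shape[1 : 1 + data_dim]
    # F.pad takes (left, right) pairs starting from the LAST dimension:
    # first the channel axis (never padded), then the spatial axes in reverse order.
    pad = [0, 0]
    for s in reversed(spatial):
        pad += [0, (-s) % patch_size]
    return F.pad(x, pad)


x = torch.randn(2, 230, 225, 3)  # [B, H, W, C_in]; neither H nor W is divisible by 16
print(pad_to_multiple(x, 16, data_dim=2).shape)  # torch.Size([2, 240, 240, 3])
```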

Layout convention — inputs and outputs use channels-last ordering:

input  : [B, *spatial_dims, C_in]     (e.g. [B, H, W, C_in] for 2D)
output : [B, *patch_grid, C_out]      (e.g. [B, H/P, W/P, C_out])

Internally, the tensor is transposed to channels-first before the Conv and back to channels-last before returning, to match the layout expected by PositionEmbeddingND and the mixer blocks.
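The permutation orders for an arbitrary data_dim can be computed rather than hard-coded. The function below sketches such a forward pass, assuming the layer holds a channels-first conv as described. patchify_forward is a hypothetical name, and the actual method may differ in details (for example, whether the returned tensor is made contiguous).

```python
import torch
import torch.nn as nn


def patchify_forward(conv, x, data_dim):
    """Channels-last in, channels-last out, around a channels-first Conv{1,2,3}d."""
    # [B, *spatial, C_in] -> [B, C_in, *spatial]
    to_first = (0, data_dim + 1, *range(1, data_dim + 1))
    x = x.permute(*to_first).contiguous()
    y = conv(x)  # [B, C_out, *patch_grid]
    # [B, C_out, *patch_grid] -> [B, *patch_grid, C_out]
    to_last = (0, *range(2, data_dim + 2), 1)
    return y.permute(*to_last)


conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(2, 224, 224, 3)  # [B, H, W, C_in]
print(patchify_forward(conv, x, data_dim=2).shape)  # torch.Size([2, 14, 14, 768])
```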

Overlapping patches — setting stride < patch_size produces overlapping patches with the same formula above. This is less common in ViT-style models but is supported.
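To make "overlapping" concrete, the pure-Python sketch below counts how many patches cover each position along one axis. coverage is a hypothetical helper that mirrors the floor formula above; it is not part of the layer.

```python
def coverage(length, patch_size, stride):
    """Count how many patches cover each position of a 1D axis (padding = 0)."""
    n = (length - patch_size) // stride + 1
    counts = [0] * length
    for k in range(n):
        start = k * stride
        for i in range(start, start + patch_size):
            counts[i] += 1
    return n, counts


print(coverage(16, 4, 4))  # (4, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
print(coverage(16, 4, 2))  # (7, [1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1])
```

By the same formula, a stride larger than patch_size leaves positions with a count of 0, i.e. pixels that no patch sees.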

Examples

1D sequence (data_dim=1):

layer = Patchify(in_features=64, out_features=128, data_dim=1, patch_size=4)
x = torch.randn(2, 256, 64)   # [B, L, C_in]
y = layer(x)                  # [B, L/4, 128] == [2, 64, 128]

2D image (data_dim=2) — see the __main__ block for a runnable demo:

layer = Patchify(in_features=3, out_features=768, data_dim=2, patch_size=16)
x = torch.randn(8, 224, 224, 3)    # [B, H, W, C_in]
y = layer(x)                        # [B, 14, 14, 768]

3D volume (data_dim=3):

layer = Patchify(in_features=1, out_features=256, data_dim=3, patch_size=8)
x = torch.randn(2, 64, 64, 64, 1)  # [B, D, H, W, C_in]
y = layer(x)                         # [B, 8, 8, 8, 256]

Parameters:
  • in_features (int)

  • out_features (int)

  • data_dim (int)

  • patch_size (int)

  • stride (int | None)

  • bias (bool)

data_dim

Spatial dimensionality (1, 2, or 3).

Type:

int

patch_size

Receptive field size of each patch along every axis.

Type:

int

stride

Step between successive patch origins along every axis.

Type:

int

conv

The underlying strided convolution.

Type:

torch.nn.Conv{data_dim}d (torch.nn.Conv1d, torch.nn.Conv2d, or torch.nn.Conv3d)
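Because conv is an ordinary torch.nn convolution, its weight has shape [C_out, C_in, P, ..., P] with data_dim trailing P axes, which amounts to one linear map of length C_in * P ** data_dim per output channel. The sketch below builds the convolution that the 2D example would correspond to, assuming Patchify configures it exactly as described above, and counts its parameters.

```python
import torch.nn as nn

# Assumed equivalent of `conv` for
# Patchify(in_features=3, out_features=768, data_dim=2, patch_size=16).
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16, padding=0, bias=True)

print(conv.weight.shape)  # torch.Size([768, 3, 16, 16]) -> [C_out, C_in, P, P]
print(conv.bias.shape)    # torch.Size([768])
n_params = sum(p.numel() for p in conv.parameters())
print(n_params)           # 590592 = 768 * (3 * 16 * 16) + 768
```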

__init__(
in_features,
out_features,
data_dim,
patch_size,
stride=None,
bias=True,
)

Initialise the Patchify layer.

Parameters:
  • in_features (int) – Number of input channels C_in (e.g. 3 for RGB).

  • out_features (int) – Embedding dimension C_out of each output token.

  • data_dim (int) – Spatial dimensionality of the input signal. Must be 1 (sequences), 2 (images), or 3 (volumes).

  • patch_size (int) – Side length P of each patch. The convolution uses kernel_size = patch_size along every spatial axis, so each patch covers P ** data_dim input pixels (P × P for 2D images, P × P × P voxels for 3D volumes).

  • stride (int | None) – Step size between consecutive patch origins along every spatial axis. Defaults to patch_size, giving non-overlapping ViT-style patches. Set it to a smaller value for overlapping patches, which yields a denser token grid at the cost of more tokens (and therefore more downstream computation).

  • bias (bool) – If True (default), the projection conv includes a learnable bias. Set to False for bias-free architectures (e.g. when a subsequent normalisation layer makes bias redundant).
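To put a number on the extra tokens from overlapping patches, the sketch below multiplies the per-axis patch counts from the output shape formula for a 224 × 224 image with patch_size=16. num_tokens is a hypothetical helper and is not part of the layer.

```python
import math


def num_tokens(spatial_dims, patch_size, stride=None):
    """Total number of output tokens: product of per-axis patch counts (padding = 0)."""
    stride = patch_size if stride is None else stride
    return math.prod((s - patch_size) // stride + 1 for s in spatial_dims)


print(num_tokens((224, 224), 16))            # 196 = 14 * 14 (default, non-overlapping)
print(num_tokens((224, 224), 16, stride=8))  # 729 = 27 * 27 (overlapping, about 3.7x more)
```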

Raises:

ValueError – If data_dim is not 1, 2, or 3.

forward(x)

Embed the input tensor into a grid of patch tokens.

Parameters:

x (Tensor) – Input tensor in channels-last layout. Shape: [B, *spatial_dims, C_in], e.g. [B, H, W, C_in] for 2D images.

Returns:

Patch-embedded tensor in channels-last layout. Shape [B, *patch_grid, C_out], where each spatial axis s is reduced to floor((s - patch_size) / stride) + 1 (equal to s // patch_size when stride == patch_size).

Return type:

Tensor

Note

.contiguous() is called after the channels-last → channels-first rearrangement to avoid a stride-mismatch error in torch.compile’s convolution_backward.
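The permute that moves channels first changes only the tensor's strides; it does not reorder memory. The sketch below uses only standard tensor methods to show the non-contiguous view that .contiguous() replaces with a densely packed copy. It does not reproduce the torch.compile error itself.

```python
import torch

x = torch.randn(2, 32, 32, 3)  # channels-last [B, H, W, C_in]
x_cf = x.permute(0, 3, 1, 2)   # channels-first view, no data copied

print(x_cf.shape)              # torch.Size([2, 3, 32, 32])
print(x_cf.is_contiguous())    # False: the view keeps the original strides
print(x_cf.stride())           # (3072, 1, 96, 3)

x_cf = x_cf.contiguous()       # materialise a copy in standard channels-first order
print(x_cf.is_contiguous())    # True
print(x_cf.stride())           # (3072, 1024, 32, 1)
```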