UNetConvNextV2

class UNetConvNextV2(
dim_in,
dim_out,
n_spatial_dims,
spatial_resolution=None,
stages=4,
blocks_per_stage=1,
blocks_at_neck=1,
init_features=32,
gradient_checkpointing=False,
)

Bases: Module

UNet-ConvNeXt with corrected skip connections.

Fixes the upstream bug in which skips[0] (the finest-resolution encoder features) is never consumed. The decoder loop is unchanged from the original: the first stage upsamples without a skip, and stages 1..N-1 consume skips[-1], ..., skips[-(N-1)]. After all decoder stages, skips[0] (full resolution) is concatenated with the decoder output, and the result is projected before out_proj.
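The indexing described above can be checked with a small pure-Python sketch. It is not the class's implementation: `skip_schedule` and `unused_skips` are hypothetical helpers, and the sketch assumes the encoder stores exactly N = `stages` skip tensors, which the description implies but does not state.

```python
def skip_schedule(stages, use_final_skip):
    """List (decoder_step, skip_index) pairs; skip_index is None when no skip is used.

    Assumption: the encoder keeps one skip per stage, so skips has `stages` entries.
    """
    schedule = []
    for j in range(stages):
        if j == 0:
            # First decoder stage upsamples without a skip.
            schedule.append((j, None))
        else:
            # skips[-j] is the same element as skips[stages - j].
            schedule.append((j, stages - j))
    if use_final_skip:
        # V2 behaviour: concatenate skips[0] after the decoder loop.
        schedule.append(("final", 0))
    return schedule


def unused_skips(stages, use_final_skip):
    """Return the skip indices that no decoder step consumes."""
    used = {idx for _, idx in skip_schedule(stages, use_final_skip) if idx is not None}
    return sorted(set(range(stages)) - used)


print(skip_schedule(4, use_final_skip=True))
# [(0, None), (1, 3), (2, 2), (3, 1), ('final', 0)]
print(unused_skips(4, use_final_skip=False))  # [0]  (original wiring)
print(unused_skips(4, use_final_skip=True))   # []   (V2 wiring)
```

With the default `stages=4`, the loop alone reaches skips[3], skips[2], and skips[1]; only the extra final step reaches skips[0].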

Parameters:
  • dim_in (int) – Number of input channels.

  • dim_out (int) – Number of output channels.

  • n_spatial_dims (int) – 2 for images, 3 for volumes.

  • spatial_resolution (tuple[int, ...] | None) – Unused; kept for API compatibility (default None).

  • stages (int) – Number of encoder/decoder stages (default 4).

  • blocks_per_stage (int) – ConvNeXt blocks per stage (default 1).

  • blocks_at_neck (int) – ConvNeXt blocks at bottleneck (default 1).

  • init_features (int) – Feature map width at the first stage (default 32).

  • gradient_checkpointing (bool) – Use activation checkpointing to reduce memory use (default False).

__init__(
dim_in,
dim_out,
n_spatial_dims,
spatial_resolution=None,
stages=4,
blocks_per_stage=1,
blocks_at_neck=1,
init_features=32,
gradient_checkpointing=False,
)

Build encoder-decoder with corrected skip wiring.

Parameters:
  • dim_in (int)

  • dim_out (int)

  • n_spatial_dims (int)

  • spatial_resolution (tuple[int, ...] | None)

  • stages (int)

  • blocks_per_stage (int)

  • blocks_at_neck (int)

  • init_features (int)

  • gradient_checkpointing (bool)

forward(x)

Forward pass with all skip connections used.

The decoder loop is identical to the original: step j=0 upsamples without a skip, and steps j=1..N-1 consume skips[-1], ..., skips[-(N-1)]. The V2 fix adds a final concatenation with skips[0] at the original resolution.

Parameters:

x (Tensor) – Channels-first input tensor [B, C_in, *spatial].

Returns:

Channels-first output tensor [B, C_out, *spatial].

Return type:

Tensor
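The shape contract above (channels-first input, with the batch size and spatial extent preserved in the output) can be written as a small standalone check. This is a sketch only: `expected_output_shape` is a hypothetical helper, not part of the class, and whether `forward` itself validates its input is not stated here.

```python
def expected_output_shape(input_shape, dim_in, dim_out, n_spatial_dims):
    """Map an input shape [B, C_in, *spatial] to the output shape [B, C_out, *spatial]."""
    if len(input_shape) != 2 + n_spatial_dims:
        raise ValueError(
            f"expected {2 + n_spatial_dims} dimensions, got {len(input_shape)}"
        )
    batch, channels, *spatial = input_shape
    if channels != dim_in:
        raise ValueError(f"expected {dim_in} input channels, got {channels}")
    # Only the channel dimension changes; batch and spatial dims pass through.
    return (batch, dim_out, *spatial)


# 2D example: a batch of 8 three-channel 64x64 images mapped to one channel.
print(expected_output_shape((8, 3, 64, 64), dim_in=3, dim_out=1, n_spatial_dims=2))
# (8, 1, 64, 64)

# 3D example: a batch of 2 single-channel 32x32x32 volumes mapped to four channels.
print(expected_output_shape((2, 1, 32, 32, 32), dim_in=1, dim_out=4, n_spatial_dims=3))
# (2, 4, 32, 32, 32)
```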