UNetConvNext

class UNetConvNext(
dim_in,
dim_out,
n_spatial_dims,
spatial_resolution=None,
stages=4,
blocks_per_stage=1,
blocks_at_neck=1,
init_features=32,
gradient_checkpointing=False,
)

Bases: Module

UNet with ConvNeXt blocks — channels-first (NCHW / NCDHW) interface.

This is a faithful port of the_well.benchmark.models.unet_convnext.UNetConvNext. Inputs and outputs are channels-first tensors. For a channels-last dict interface, see WellUNetConvNext.

Parameters:
  • dim_in (int) – Number of input channels.

  • dim_out (int) – Number of output channels.

  • n_spatial_dims (int) – 2 for images, 3 for volumes.

  • spatial_resolution (tuple[int, ...] | None) – Tuple of spatial sizes, used only for compatibility with the BaseModel API (default None).

  • stages (int) – Number of encoder/decoder stages (default 4).

  • blocks_per_stage (int) – ConvNeXt blocks per stage (default 1).

  • blocks_at_neck (int) – ConvNeXt blocks at bottleneck (default 1).

  • init_features (int) – Feature map width at the first stage (default 32).

  • gradient_checkpointing (bool) – Use activation checkpointing to reduce memory use at the cost of recomputation during the backward pass (default False).

__init__(
dim_in,
dim_out,
n_spatial_dims,
spatial_resolution=None,
stages=4,
blocks_per_stage=1,
blocks_at_neck=1,
init_features=32,
gradient_checkpointing=False,
)

Build the encoder-decoder with the given depth and width.

Parameters:
  • dim_in (int)

  • dim_out (int)

  • n_spatial_dims (int)

  • spatial_resolution (tuple[int, ...] | None)

  • stages (int)

  • blocks_per_stage (int)

  • blocks_at_neck (int)

  • init_features (int)

  • gradient_checkpointing (bool)

forward(x)

Forward pass.

Parameters:

x (Tensor) – Channels-first input tensor [B, C_in, *spatial].

Returns:

Channels-first output tensor [B, C_out, *spatial].

Return type:

Tensor
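The shape contract above can be illustrated without any tensor library. The following is a minimal sketch: channels_last_to_first and expected_output_shape are hypothetical helpers written for this example and are not part of UNetConvNext. It shows the axis order that would move a channels-last layout [B, *spatial, C] to the channels-first layout forward expects (for example as the argument to a permute call), and the output shape the documentation promises. It does not cover any constraints the model may place on the spatial sizes themselves.

```python
def channels_last_to_first(n_spatial_dims: int) -> tuple[int, ...]:
    """Axis order that turns [B, *spatial, C] into [B, C, *spatial]."""
    channel_axis = n_spatial_dims + 1  # last axis of a channels-last tensor
    return (0, channel_axis, *range(1, channel_axis))


def expected_output_shape(input_shape, dim_in, dim_out, n_spatial_dims):
    """Check a channels-first input shape and return the documented output shape.

    Hypothetical helper: it only encodes [B, C_in, *spatial] -> [B, C_out, *spatial].
    """
    if len(input_shape) != n_spatial_dims + 2:
        raise ValueError(
            f"expected {n_spatial_dims + 2} axes [B, C_in, *spatial], "
            f"got {len(input_shape)}"
        )
    batch, channels, *spatial = input_shape
    if channels != dim_in:
        raise ValueError(f"expected {dim_in} input channels, got {channels}")
    return (batch, dim_out, *spatial)


print(channels_last_to_first(2))                       # (0, 3, 1, 2)
print(channels_last_to_first(3))                       # (0, 4, 1, 2, 3)
print(expected_output_shape((8, 3, 64, 64), 3, 5, 2))  # (8, 5, 64, 64)
```

Only the channel axis changes between input and output; the batch and spatial axes are passed through unchanged.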

Note

Known bug (upstream): skips[0] (the finest-resolution encoder features) is never used. With N encoder stages, the decoder loop reads skips[-j] for j = 1, 2, ..., N-1, that is, skips[-1] down to skips[-(N-1)]. The last of these is skips[1], so skips[0] is never reached. In a standard UNet the finest skip connects to the last decoder stage. This behavior matches the reference implementation in the_well.benchmark.models.unet_convnext.UNetConvNext (v1.0.1) line for line, so it is preserved here for reproducibility. See UNetConvNextV2 for a corrected version.
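The indexing described in the note can be checked with a few lines of plain Python. This is a sketch of the index arithmetic only, not the model code: decoder_skip_indices is a hypothetical helper, and integers stand in for the encoder feature maps, under the assumption that the skips list holds one entry per encoder stage.

```python
def decoder_skip_indices(n_stages: int) -> list[int]:
    """Return which encoder skips the decoder loop reads, in access order."""
    # Stand-ins for feature maps: skips[i] is the output of encoder stage i,
    # with skips[0] being the finest resolution.
    skips = list(range(n_stages))
    # The loop described in the note: j = 1, 2, ..., N-1 reads skips[-j].
    return [skips[-j] for j in range(1, n_stages)]


n_stages = 4
used = decoder_skip_indices(n_stages)
unused = sorted(set(range(n_stages)) - set(used))
print(used)    # [3, 2, 1]
print(unused)  # [0]
```

With the default stages=4, three skips are consumed and the finest-resolution one is left unread, which is the behavior the note describes.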