UNetConvNext
- class UNetConvNext(dim_in, dim_out, n_spatial_dims, spatial_resolution=None, stages=4, blocks_per_stage=1, blocks_at_neck=1, init_features=32, gradient_checkpointing=False)
Bases: Module
UNet with ConvNeXt blocks and a channels-first (NCHW / NCDHW) interface.
This is a faithful port of the_well.benchmark.models.unet_convnext.UNetConvNext. Inputs and outputs are channels-first tensors. For the channels-last, dict-based interface, see WellUNetConvNext.
- Parameters:
dim_in (int) – Number of input channels.
dim_out (int) – Number of output channels.
n_spatial_dims (int) – 2 for images, 3 for volumes.
spatial_resolution (tuple[int, ...] | None) – Tuple of spatial sizes, used only for compatibility with the BaseModel API (default None).
stages (int) – Number of encoder/decoder stages (default 4).
blocks_per_stage (int) – ConvNeXt blocks per stage (default 1).
blocks_at_neck (int) – ConvNeXt blocks at bottleneck (default 1).
init_features (int) – Feature map width at the first stage (default 32).
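gradient_checkpointing (bool) – Use activation checkpointing, which recomputes activations during the backward pass to reduce memory use at the cost of extra compute (default False).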
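Because this class expects channels-first tensors while WellUNetConvNext works channels-last, callers converting between the two need to move the channel axis. The sketch below computes the axis permutation in each direction; the helper names (`channels_last_to_first_perm`, `channels_first_to_last_perm`, `permute_shape`) are hypothetical and not part of this module. The resulting tuple is the kind of argument `torch.Tensor.permute` accepts, but the code itself is plain Python and only manipulates shapes.

```python
def channels_last_to_first_perm(n_spatial_dims: int) -> tuple[int, ...]:
    """Axis order that maps [B, *spatial, C] to [B, C, *spatial]."""
    if n_spatial_dims not in (2, 3):
        raise ValueError("n_spatial_dims must be 2 (images) or 3 (volumes)")
    # Batch stays at axis 0, the trailing channel axis moves to axis 1,
    # and the spatial axes keep their relative order.
    return (0, n_spatial_dims + 1) + tuple(range(1, n_spatial_dims + 1))


def channels_first_to_last_perm(n_spatial_dims: int) -> tuple[int, ...]:
    """Inverse axis order that maps [B, C, *spatial] back to [B, *spatial, C]."""
    forward = channels_last_to_first_perm(n_spatial_dims)
    inverse = [0] * len(forward)
    # If output axis `new_axis` came from input axis `old_axis`,
    # the inverse must send `old_axis` back to `new_axis`.
    for new_axis, old_axis in enumerate(forward):
        inverse[old_axis] = new_axis
    return tuple(inverse)


def permute_shape(shape: tuple[int, ...], perm: tuple[int, ...]) -> tuple[int, ...]:
    """Shape a tensor of `shape` would have after permuting its axes by `perm`."""
    return tuple(shape[axis] for axis in perm)


if __name__ == "__main__":
    perm = channels_last_to_first_perm(2)
    print(perm)                                 # (0, 3, 1, 2)
    print(permute_shape((8, 64, 64, 4), perm))  # (8, 4, 64, 64)
```

For example, a channels-last batch of shape (8, 64, 64, 4) becomes (8, 4, 64, 64), the layout `forward` expects; applying the inverse permutation to the output restores channels-last order.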
- __init__(dim_in, dim_out, n_spatial_dims, spatial_resolution=None, stages=4, blocks_per_stage=1, blocks_at_neck=1, init_features=32, gradient_checkpointing=False)
Build the encoder-decoder with the given depth (stages, blocks_per_stage, blocks_at_neck) and width (init_features).
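To make the depth and width arguments concrete, the hypothetical helpers below tabulate per-stage feature widths and a total ConvNeXt block count from the constructor arguments. Both rest on assumptions this page does not confirm: that the width doubles at each stage starting from init_features (a common UNet convention), and that blocks_per_stage applies to every encoder stage and every decoder stage.

```python
def stage_widths(init_features: int, stages: int) -> list[int]:
    """Hypothetical feature widths: entry i is the width after i downsampling steps."""
    # ASSUMPTION: width doubles per stage, starting at init_features.
    # Under this assumption the last entry is the bottleneck (neck) width.
    return [init_features * 2**i for i in range(stages + 1)]


def total_convnext_blocks(stages: int, blocks_per_stage: int, blocks_at_neck: int) -> int:
    """Hypothetical total number of ConvNeXt blocks in the network."""
    # ASSUMPTION: blocks_per_stage blocks in each of the `stages` encoder
    # stages and in each of the `stages` decoder stages, plus the neck blocks.
    return 2 * stages * blocks_per_stage + blocks_at_neck


if __name__ == "__main__":
    # Defaults from the signature: init_features=32, stages=4,
    # blocks_per_stage=1, blocks_at_neck=1.
    print(stage_widths(32, 4))             # [32, 64, 128, 256, 512]
    print(total_convnext_blocks(4, 1, 1))  # 9
```

Under these assumptions, raising stages grows the bottleneck width geometrically, while blocks_per_stage and blocks_at_neck only add depth at a fixed width.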
- forward(x)
Forward pass.
- Parameters:
x (Tensor) – Channels-first input tensor [B, C_in, *spatial].
- Returns:
Channels-first output tensor [B, C_out, *spatial].
- Return type:
Tensor
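The shape contract of forward can be written as a small check. The helper below is hypothetical and not part of the class: it validates the input rank and channel count against dim_in and returns the output shape documented above, with batch and spatial sizes unchanged. It does not check any additional constraints the network may impose on spatial sizes.

```python
def expected_output_shape(
    input_shape: tuple[int, ...], dim_in: int, dim_out: int, n_spatial_dims: int
) -> tuple[int, ...]:
    """Check a channels-first input shape and return the documented output shape."""
    # Expect [B, C_in, *spatial] with one axis per spatial dimension.
    if len(input_shape) != 2 + n_spatial_dims:
        raise ValueError(
            f"expected {2 + n_spatial_dims} axes [B, C_in, *spatial], got {len(input_shape)}"
        )
    batch, channels, *spatial = input_shape
    if channels != dim_in:
        raise ValueError(f"expected {dim_in} input channels at axis 1, got {channels}")
    # Output keeps the batch and spatial sizes; only the channel count changes.
    return (batch, dim_out, *spatial)


if __name__ == "__main__":
    print(expected_output_shape((8, 4, 64, 64), dim_in=4, dim_out=2, n_spatial_dims=2))
    # (8, 2, 64, 64)
```

A channels-last shape such as (8, 64, 64, 4) with dim_in=4 fails the channel check here, which is a quick way to catch a missing axis permutation.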
Note
Known bug (upstream): skips[0] (the finest-resolution encoder features) is never used. With N encoder stages, the decoder loop accesses skips[-1], skips[-2], ..., skips[-(N-1)] for j = 1, 2, ..., N-1, skipping skips[0] entirely. In a standard UNet, the finest skip connects to the last decoder stage. This behavior matches the reference implementation in the_well.benchmark.models.unet_convnext.UNetConvNext (v1.0.1) line for line, so it is preserved here for reproducibility. See UNetConvNextV2 for a corrected version.
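The indexing described in this note can be reproduced with plain lists. The sketch below builds placeholder skip labels for N encoder stages, applies the upstream access pattern skips[-j] for j = 1, ..., N-1, and contrasts it with a standard UNet pairing in which decoder stage j reads skips[-(j + 1)] for j = 0, ..., N-1. The helper names are hypothetical, and the standard pairing is illustrative only; it is not claimed to be how UNetConvNextV2 is implemented.

```python
def upstream_skip_usage(n_stages: int) -> list[str]:
    """Skips read by the upstream decoder loop, in access order."""
    # Placeholder labels; skips[0] stands for the finest-resolution features.
    skips = [f"skip{i}" for i in range(n_stages)]
    # Upstream pattern from the note: skips[-j] for j = 1, ..., N-1.
    return [skips[-j] for j in range(1, n_stages)]


def standard_skip_usage(n_stages: int) -> list[str]:
    """Illustrative standard UNet pairing: decoder stage j reads skips[-(j + 1)]."""
    skips = [f"skip{i}" for i in range(n_stages)]
    return [skips[-(j + 1)] for j in range(n_stages)]


def unused_skips(n_stages: int) -> list[str]:
    """Skip labels the upstream loop never reads."""
    used = set(upstream_skip_usage(n_stages))
    return [f"skip{i}" for i in range(n_stages) if f"skip{i}" not in used]


if __name__ == "__main__":
    print(upstream_skip_usage(4))  # ['skip3', 'skip2', 'skip1']
    print(unused_skips(4))         # ['skip0']
    print(standard_skip_usage(4))  # ['skip3', 'skip2', 'skip1', 'skip0']
```

With the default stages=4, the upstream loop reads three of the four skips, and only the standard pairing routes skip0 (the finest features) into the last decoder stage.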