UNetConvNextV2
class UNetConvNextV2(dim_in, dim_out, n_spatial_dims, spatial_resolution=None, stages=4, blocks_per_stage=1, blocks_at_neck=1, init_features=32, gradient_checkpointing=False)
Bases: Module

UNet-ConvNeXt with corrected skip connections.
Fixes the upstream bug where skips[0] (the finest-resolution encoder features) is never consumed. The decoder loop is identical to the original: with N = stages, the first stage upsamples without a skip, and stages 1..N-1 consume skips[-1], ..., skips[-(N-1)]. After all decoder stages, V2 additionally concatenates skips[0] (full resolution) and projects the result before out_proj.

Parameters:
dim_in (int) – Number of input channels.
dim_out (int) – Number of output channels.
n_spatial_dims (int) – 2 for images, 3 for volumes.
spatial_resolution (tuple[int, ...] | None) – Unused; kept for API compatibility (default None).
stages (int) – Number of encoder/decoder stages (default 4).
blocks_per_stage (int) – ConvNeXt blocks per stage (default 1).
blocks_at_neck (int) – ConvNeXt blocks at bottleneck (default 1).
init_features (int) – Feature map width at the first stage (default 32).
gradient_checkpointing (bool) – If True, use activation checkpointing, which lowers memory use by recomputing activations during the backward pass (default False).
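The skip bookkeeping described above can be checked with plain index arithmetic. This is a minimal sketch, assuming the encoder stores exactly one skip per stage (so `skips` has `stages` entries); the helper `decoder_skip_schedule` is hypothetical and not part of the class.

```python
from __future__ import annotations

from typing import Optional


def decoder_skip_schedule(stages: int) -> tuple[list[Optional[int]], int]:
    """Return (skip index per decoder stage, skip index of the final V2 concat).

    ASSUMPTION: the encoder stores exactly one skip per stage, so `skips`
    has `stages` entries and skips[-j] is the same tensor as skips[stages - j].
    """
    if stages < 1:
        raise ValueError("stages must be >= 1")
    # j = 0 upsamples without a skip; j >= 1 consumes skips[-j].
    per_stage = [None if j == 0 else stages - j for j in range(stages)]
    # V2 fix: one extra concatenation with skips[0] after the loop.
    return per_stage, 0


per_stage, final = decoder_skip_schedule(4)
print(per_stage)  # [None, 3, 2, 1]
print(final)  # 0
# The original loop alone never touches index 0:
print(0 in {i for i in per_stage if i is not None})  # False
```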
__init__(dim_in, dim_out, n_spatial_dims, spatial_resolution=None, stages=4, blocks_per_stage=1, blocks_at_neck=1, init_features=32, gradient_checkpointing=False)
Build the encoder, bottleneck, and decoder with the corrected skip wiring.
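Because the extra concatenation happens at full resolution, the constructor needs one more projection than the original. The sketch below plans the skip shapes and that projection's widths. The helpers `skip_shapes` and `final_projection_widths` are hypothetical, and the sketch assumes (this documentation does not state it) that each encoder stage halves every spatial dimension and doubles the channel count starting from `init_features`, and that the last decoder stage emits `init_features` channels.

```python
from __future__ import annotations


def skip_shapes(
    spatial: tuple[int, ...], stages: int, init_features: int
) -> list[tuple[int, ...]]:
    """Hypothetical shape plan: (channels, *spatial) of skips[0..stages-1].

    ASSUMPTIONS (not from the docs): skips[i] is recorded before the i-th
    downsampling, each encoder stage halves every spatial dim and doubles the
    channels, and the bottleneck sits at 1 / 2**stages of the input size.
    """
    factor = 2 ** stages
    if any(s % factor for s in spatial):
        raise ValueError(f"every spatial dim must be divisible by {factor}, got {spatial}")
    return [
        (init_features * 2 ** i, *(s // 2 ** i for s in spatial))
        for i in range(stages)
    ]


def final_projection_widths(init_features: int) -> tuple[int, int]:
    """(in, out) channels of the V2 projection applied before out_proj.

    ASSUMPTION: the last decoder stage and skips[0] both have init_features
    channels, so concatenating them doubles the width.
    """
    return 2 * init_features, init_features


print(skip_shapes((64, 64), stages=4, init_features=32))
# [(32, 64, 64), (64, 32, 32), (128, 16, 16), (256, 8, 8)]
print(final_projection_widths(32))  # (64, 32)
```

Under these assumptions the extra projection maps `2 * init_features` channels back to `init_features`, so `out_proj` can keep the same input width it had in the original model.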
forward(x)
Forward pass that uses every skip connection.
The decoder loop is identical to the original: j=0 upsamples without a skip, and j=1..N-1 consume skips[-1]..skips[-(N-1)]. The V2 fix adds a final concatenation with skips[0] at the original resolution.
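The forward wiring can be exercised end to end with a NumPy toy model. This is not the real network: the ConvNeXt stages are replaced by hypothetical stand-ins (strided subsampling, nearest-neighbour upsampling, and random 1x1 channel mixing), only `n_spatial_dims=2` inputs of shape `(C, H, W)` are handled, and it assumes each encoder stage halves H and W and doubles the channels. Mismatched skip shapes would make `np.concatenate` raise, so a successful run shows that the skip indices line up.

```python
import numpy as np

rng = np.random.default_rng(0)


def conv1x1(x, out_ch):
    """Random 1x1 convolution (pure channel mixing); stand-in for a learned layer."""
    w = rng.standard_normal((out_ch, x.shape[0])) / np.sqrt(x.shape[0])
    return np.einsum("oc,chw->ohw", w, x)


def down(x):
    """Toy encoder stage: 2x spatial subsampling, then double the channels."""
    return conv1x1(x[:, ::2, ::2], 2 * x.shape[0])


def up(x, out_ch):
    """Toy decoder stage: nearest-neighbour 2x upsampling, then project channels."""
    return conv1x1(x.repeat(2, axis=1).repeat(2, axis=2), out_ch)


def toy_forward(x, dim_out, stages=4, init_features=8):
    """Return (output, skip indices consumed in order) for a (C, H, W) input."""
    x = conv1x1(x, init_features)  # in_proj
    skips, used = [], []
    for _ in range(stages):  # encoder: record the skip, then downsample
        skips.append(x)
        x = down(x)
    # The bottleneck is omitted; this toy treats it as shape-preserving.
    for j in range(stages):  # decoder loop, unchanged from the original
        if j == 0:
            x = up(x, x.shape[0] // 2)  # no skip at the first stage
        else:
            x = np.concatenate([x, skips[-j]], axis=0)  # consume skips[-j]
            used.append(stages - j)  # positive index of skips[-j]
            x = up(x, x.shape[0] // 4)
    # V2 fix: concatenate the full-resolution skips[0], then project.
    x = np.concatenate([x, skips[0]], axis=0)
    used.append(0)
    x = conv1x1(x, init_features)  # extra projection added by V2
    return conv1x1(x, dim_out), used  # out_proj


x = rng.standard_normal((3, 32, 32))
y, used = toy_forward(x, dim_out=2)
print(y.shape)  # (2, 32, 32)
print(used)  # [3, 2, 1, 0]
```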