WellUNetConvNext

class WellUNetConvNext(**kwargs)

Bases: Module

A UNet-ConvNeXt model exposed through the dict-based, channels-last interface expected by WELLRegressionWrapper.

Input: {"input": [B, *spatial, C_in], "condition": None}

Output: {"logits": [B, *spatial, C_out]}

All internal computation is channels-first; the wrapper only moves the channel axis at the input and output boundaries.
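The boundary transposes can be sketched with `torch.movedim`, which works for any number of spatial dimensions (1D, 2D, or 3D grids). This is a minimal sketch only: the helper names `to_channels_first` and `to_channels_last` are hypothetical and not part of this module's API, and the actual wrapper may use `permute` or another mechanism.

```python
import torch


def to_channels_first(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [B, *spatial, C] -> [B, C, *spatial]."""
    # Move the last (channel) axis to position 1, right after the batch axis.
    return torch.movedim(x, -1, 1)


def to_channels_last(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: [B, C, *spatial] -> [B, *spatial, C]."""
    # Move the channel axis at position 1 back to the end.
    return torch.movedim(x, 1, -1)


x = torch.randn(2, 16, 32, 3)  # B=2, spatial=(16, 32), C_in=3
y = to_channels_first(x)
print(tuple(y.shape))  # (2, 3, 16, 32)
print(torch.equal(to_channels_last(y), x))  # True
```

Because `movedim` only addresses the first and last axes, the same two helpers cover every spatial rank without special-casing.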

All constructor keyword arguments are forwarded to UNetConvNext.

__init__(**kwargs)

Initialize by forwarding all kwargs to UNetConvNext.

forward(input_and_condition)

Transpose the input to channels-first, run the UNet, and transpose the output back to channels-last.

Parameters:

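input_and_condition (dict[str, Tensor]) – dictionary holding an "input" tensor of shape [B, *spatial, C_in]; the "condition" entry is None.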

Return type:

dict[str, Tensor]
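A minimal sketch of the same wrapper pattern, assuming PyTorch. `TinyChannelsFirstNet` and its `dim_in`/`dim_out` arguments are hypothetical stand-ins for UNetConvNext, whose real constructor arguments are not documented here; `ChannelsLastDictWrapper` is likewise a hypothetical name. The sketch only illustrates kwargs forwarding and the transposes around a channels-first network.

```python
from typing import Dict, Optional

import torch
from torch import nn


class TinyChannelsFirstNet(nn.Module):
    """Hypothetical stand-in for UNetConvNext: a 1x1 conv on 2D channels-first data."""

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ChannelsLastDictWrapper(nn.Module):
    """Hypothetical sketch of a dict-in / dict-out, channels-last wrapper."""

    def __init__(self, **kwargs):
        super().__init__()
        # Forward every constructor kwarg to the inner channels-first network.
        self.model = TinyChannelsFirstNet(**kwargs)

    def forward(
        self, input_and_condition: Dict[str, Optional[torch.Tensor]]
    ) -> Dict[str, torch.Tensor]:
        x = input_and_condition["input"]  # [B, *spatial, C_in]
        x = torch.movedim(x, -1, 1)  # -> [B, C_in, *spatial]
        y = self.model(x)  # [B, C_out, *spatial]
        return {"logits": torch.movedim(y, 1, -1)}  # -> [B, *spatial, C_out]


wrapper = ChannelsLastDictWrapper(dim_in=3, dim_out=5)
out = wrapper({"input": torch.randn(2, 16, 32, 3), "condition": None})
print(tuple(out["logits"].shape))  # (2, 16, 32, 5)
```

Keeping the transposes in the wrapper lets the inner network stay purely channels-first, which is the layout `nn.Conv2d` and similar layers expect.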