WellUNetConvNextV2

class WellUNetConvNextV2(**kwargs)

Bases: Module

Like WellUNetConvNext but with fixed skip connections.

Input: {"input": [B, *spatial, C_in], "condition": None}
Output: {"logits": [B, *spatial, C_out]}
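Both tensors use a channels-last layout: the channel axis comes after the spatial axes that `*spatial` stands for (for instance `[B, H, W, C_in]` for a 2D field). PyTorch convolution layers expect channels-first `[B, C, *spatial]`. The pure-Python sketch below computes the axis permutations between the two layouts for any number of spatial axes. The helper names are hypothetical and not part of this API. In PyTorch, the returned tuples could be passed to `Tensor.permute`.

```python
def channels_first_perm(ndim: int) -> tuple[int, ...]:
    """Hypothetical helper: axis order mapping [B, *spatial, C] -> [B, C, *spatial]."""
    if ndim < 3:
        raise ValueError("expected at least [B, spatial, C]")
    # Keep batch first, move the last (channel) axis to position 1, then the spatial axes.
    return (0, ndim - 1, *range(1, ndim - 1))


def channels_last_perm(ndim: int) -> tuple[int, ...]:
    """Hypothetical helper: inverse order mapping [B, C, *spatial] -> [B, *spatial, C]."""
    if ndim < 3:
        raise ValueError("expected at least [B, C, spatial]")
    # Keep batch first, then the spatial axes, then the channel axis (position 1) last.
    return (0, *range(2, ndim), 1)


def permute_shape(shape: tuple[int, ...], perm: tuple[int, ...]) -> tuple[int, ...]:
    """Apply an axis permutation to a shape tuple, as Tensor.permute would."""
    return tuple(shape[i] for i in perm)


# 2D example: [B=2, H=32, W=32, C_in=3] -> [2, 3, 32, 32] and back.
cf = permute_shape((2, 32, 32, 3), channels_first_perm(4))
cl = permute_shape(cf, channels_last_perm(4))
```

Because `channels_last_perm(n)` is the inverse of `channels_first_perm(n)`, applying one after the other restores the original layout. This is why the output keeps the same `[B, *spatial, ...]` ordering as the input.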

__init__(**kwargs)

Initialize by forwarding all kwargs to UNetConvNextV2.

forward(input_and_condition)

Transpose the input from channels-last to channels-first, run the wrapped UNetConvNextV2, and transpose the result back to channels-last.
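The transpose, run, and transpose-back pattern can be sketched in PyTorch as follows. This is a minimal illustration, not the library's implementation. `StandInChannelsFirstNet` is a hypothetical placeholder for UNetConvNextV2 (a single 1x1 `nn.Conv2d`, so it only handles 2D fields). `ChannelsLastWrapper` is likewise a hypothetical name. The sketch uses `Tensor.movedim` to relocate the channel axis. The real module may use a different transpose call.

```python
import torch
from torch import nn


class StandInChannelsFirstNet(nn.Module):
    """Hypothetical placeholder for UNetConvNextV2: maps [B, C_in, H, W] -> [B, C_out, H, W]."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ChannelsLastWrapper(nn.Module):
    """Hypothetical sketch of a channels-last wrapper around a channels-first network."""

    def __init__(self, **kwargs):
        super().__init__()
        # Forward every keyword argument to the wrapped channels-first network.
        self.net = StandInChannelsFirstNet(**kwargs)

    def forward(self, input_and_condition: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        x = input_and_condition["input"]     # [B, *spatial, C_in]; "condition" is unused here
        x = x.movedim(-1, 1)                 # -> [B, C_in, *spatial]
        y = self.net(x)                      # -> [B, C_out, *spatial]
        return {"logits": y.movedim(1, -1)}  # -> [B, *spatial, C_out]


model = ChannelsLastWrapper(in_channels=3, out_channels=5)
x = torch.randn(2, 8, 8, 3)
out = model({"input": x, "condition": None})
```

Moving only the channel axis with `movedim(-1, 1)` and `movedim(1, -1)` works for any number of spatial axes. This avoids spelling out a full permutation tuple for each dimensionality.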

Parameters:

input_and_condition (dict[str, Tensor]) - Dictionary with an "input" tensor of shape [B, *spatial, C_in] and a "condition" entry (None).

Return type:

dict[str, Tensor]