ClassificationResNet

class ClassificationResNet(
in_channels,
out_channels,
num_blocks,
hidden_dim,
data_dim,
in_proj_cfg,
out_proj_cfg,
norm_cfg,
block_cfg,
dropout_in_cfg,
condition_in_proj_cfg=None,
target_size=None,
gradient_checkpointing=False,
)

Bases: ResidualNetwork

Residual network with global-average-pool readout for classification.

Inherits the full constructor and backbone from ResidualNetwork. Overrides only forward() to replace the spatial output with a single class-logit vector via global average pooling.

Output shape: [B, out_channels] regardless of input spatial size, so the model is resolution-agnostic at inference time.
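To illustrate why pooling makes the output shape independent of resolution, here is a minimal sketch in plain PyTorch. The helper global_average_pool is a hypothetical name introduced for this example and is not part of the library; it assumes the channels-last layout [B, *spatial, C] used by forward().

```python
import torch


def global_average_pool(features: torch.Tensor) -> torch.Tensor:
    """Average a channels-last tensor [B, *spatial, C] over its spatial axes.

    Hypothetical helper for illustration only; the library's own pooling
    code may differ.
    """
    spatial_axes = tuple(range(1, features.ndim - 1))
    if not spatial_axes:
        # Already [B, C]: nothing to pool.
        return features
    return features.mean(dim=spatial_axes)


# Two batches with different spatial sizes but the same channel count.
small = torch.randn(2, 16, 16, 8)
large = torch.randn(2, 64, 48, 8)

pooled_small = global_average_pool(small)
pooled_large = global_average_pool(large)

# Both results have shape [2, 8]: the spatial axes are gone.
print(pooled_small.shape, pooled_large.shape)
```

Because every spatial axis is reduced to a single mean, any layers placed after the pooling step see a fixed [B, C] input whatever the original resolution was.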

No target_size: the inherited target_size attribute is ignored; global average pooling (GAP) serves as the spatial aggregation step.

All constructor arguments are documented in ResidualNetwork. The typical value for out_channels here is the number of classes.

forward(input_and_condition)

Run classification forward pass: backbone → GAP → norm → projection.

Parameters:

input_and_condition (dict[str, Tensor]) –

Dictionary with two keys:

  • "input" — signal tensor of shape [B, *spatial, in_channels].

  • "condition" — optional conditioning tensor of shape [B, *spatial_cond, hidden_dim], or None.

Returns:

Single-key dict:

  • "logits" — shape [B, out_channels] (one logit vector per sample, all spatial information collapsed by global average pooling).

Return type:

dict[str, torch.Tensor]
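The input and output contract of forward() can be sketched with a small stand-in module. ToyGapClassifier below is hypothetical and is not the library's ClassificationResNet: it uses a single pointwise linear layer in place of the residual backbone, assumes LayerNorm for the norm step, and ignores the "condition" entry. It only mirrors the dict-in, dict-out interface and the backbone → GAP → norm → projection order described above.

```python
import torch
from torch import nn


class ToyGapClassifier(nn.Module):
    """Hypothetical stand-in that mimics the documented forward() contract."""

    def __init__(self, in_channels: int, hidden_dim: int, out_channels: int):
        super().__init__()
        # Stand-in for the residual backbone: acts on the last (channel) axis.
        self.backbone = nn.Linear(in_channels, hidden_dim)
        # Assumed normalisation; the real layer is chosen by norm_cfg.
        self.norm = nn.LayerNorm(hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, input_and_condition):
        x = input_and_condition["input"]  # [B, *spatial, in_channels]
        # The optional "condition" entry is read but unused in this toy model.
        _ = input_and_condition.get("condition")
        h = self.backbone(x)  # [B, *spatial, hidden_dim]
        spatial_axes = tuple(range(1, h.ndim - 1))
        if spatial_axes:
            h = h.mean(dim=spatial_axes)  # GAP -> [B, hidden_dim]
        logits = self.out_proj(self.norm(h))  # [B, out_channels]
        return {"logits": logits}


model = ToyGapClassifier(in_channels=3, hidden_dim=32, out_channels=10)
batch = {"input": torch.randn(4, 28, 28, 3), "condition": None}

out = model(batch)
predicted = out["logits"].argmax(dim=-1)  # one class index per sample
```

With the real class, the call pattern would presumably be the same: pass a dict with "input" (and "condition" or None), then read out["logits"] and take an argmax or softmax over the last axis.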