ClassificationResNet
- class ClassificationResNet(
- in_channels,
- out_channels,
- num_blocks,
- hidden_dim,
- data_dim,
- in_proj_cfg,
- out_proj_cfg,
- norm_cfg,
- block_cfg,
- dropout_in_cfg,
- condition_in_proj_cfg=None,
- target_size=None,
- gradient_checkpointing=False,
- )
Bases: ResidualNetwork
Residual network with global-average-pool readout for classification.
Inherits the full constructor and backbone from ResidualNetwork. Overrides only forward() to replace the spatial output with a single class-logit vector via global average pooling (GAP).
Output shape: [B, out_channels] regardless of input spatial size, so the model is resolution-agnostic at inference time.
No target_size: the inherited target_size attribute is ignored; GAP serves as the spatial aggregation step.
All constructor arguments are documented in ResidualNetwork. The typical value for out_channels here is the number of classes.
- Parameters:
in_channels (int)
out_channels (int)
num_blocks (int)
hidden_dim (int)
data_dim (int)
in_proj_cfg (LazyConfig)
out_proj_cfg (LazyConfig)
norm_cfg (LazyConfig)
block_cfg (LazyConfig)
dropout_in_cfg (LazyConfig)
condition_in_proj_cfg (LazyConfig | None)
gradient_checkpointing (bool)
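The resolution-agnostic behaviour described above follows directly from how global average pooling works. The snippet below is a minimal sketch in NumPy rather than the library's own code: `global_average_pool` is a hypothetical helper, and the channel-last layout `[B, *spatial, C]` is taken from the shapes documented for `forward()`.

```python
import numpy as np


def global_average_pool(features: np.ndarray) -> np.ndarray:
    """Average a channel-last array [B, *spatial, C] over all spatial axes.

    Hypothetical helper for illustration; not part of the documented API.
    """
    # Every axis between the batch axis (0) and the channel axis (-1) is spatial.
    spatial_axes = tuple(range(1, features.ndim - 1))
    return features.mean(axis=spatial_axes)


rng = np.random.default_rng(0)
small = rng.standard_normal((2, 16, 16, 8))  # [B, H, W, C]
large = rng.standard_normal((2, 64, 48, 8))  # same B and C, different H and W

pooled_small = global_average_pool(small)
pooled_large = global_average_pool(large)

# Both results have shape (2, 8): the spatial size no longer appears.
print(pooled_small.shape, pooled_large.shape)
```

Because the pooled shape depends only on the batch and channel sizes, any projection applied afterwards sees the same input width whatever resolution the backbone processed.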
- forward(input_and_condition)
Run the classification forward pass: backbone → GAP → norm → projection.
- Parameters:
input_and_condition (dict[str, Tensor]) –
Dictionary with two keys:
"input": signal tensor of shape [B, *spatial, in_channels].
"condition": optional conditioning tensor of shape [B, *spatial_cond, hidden_dim], or None.
- Returns:
Single-key dict:
"logits": shape [B, out_channels] (one logit vector per sample, with all spatial information collapsed by global average pooling).
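The dictionary-in, dictionary-out contract and the backbone → GAP → norm → projection order can be sketched as follows. This is an illustrative NumPy stand-in, not the actual implementation: `classification_forward_sketch`, `w_in`, and `w_out` are hypothetical names, the backbone is reduced to a single pointwise layer, the normalisation is assumed to be a plain layer norm over channels (the real one is chosen by `norm_cfg`), and the optional condition is read but not used.

```python
import numpy as np


def classification_forward_sketch(input_and_condition, w_in, w_out, eps=1e-5):
    """Hypothetical stand-in for the documented forward pass."""
    x = input_and_condition["input"]  # [B, *spatial, in_channels]
    # Optional conditioning tensor; may be None. Ignored in this sketch.
    condition = input_and_condition.get("condition")

    # Stand-in backbone: pointwise projection + ReLU -> [B, *spatial, hidden_dim]
    hidden = np.maximum(x @ w_in, 0.0)

    # GAP: average over every axis between batch and channel -> [B, hidden_dim]
    pooled = hidden.mean(axis=tuple(range(1, hidden.ndim - 1)))

    # Assumed normalisation: layer norm over the channel axis.
    mean = pooled.mean(axis=-1, keepdims=True)
    var = pooled.var(axis=-1, keepdims=True)
    normed = (pooled - mean) / np.sqrt(var + eps)

    # Output projection -> [B, out_channels]
    return {"logits": normed @ w_out}


rng = np.random.default_rng(0)
in_channels, hidden_dim, num_classes = 3, 16, 10
w_in = rng.standard_normal((in_channels, hidden_dim))
w_out = rng.standard_normal((hidden_dim, num_classes))

batch = {"input": rng.standard_normal((4, 32, 32, in_channels)), "condition": None}
out = classification_forward_sketch(batch, w_in, w_out)

# A different spatial size yields logits of the same shape.
other = {"input": rng.standard_normal((4, 20, 50, in_channels)), "condition": None}
out_other = classification_forward_sketch(other, w_in, w_out)
print(out["logits"].shape, out_other["logits"].shape)
```

In this sketch, passing `"condition": None` and omitting the key behave the same; whether the real model accepts a missing key is not stated in the documentation above.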
- Return type: