ResidualNetwork
- class ResidualNetwork(in_channels, out_channels, num_blocks, hidden_dim, data_dim, in_proj_cfg, out_proj_cfg, norm_cfg, block_cfg, dropout_in_cfg, condition_in_proj_cfg=None, target_size=None, gradient_checkpointing=False)
Bases:
Module
General-purpose residual network backbone (see module docstring for architecture).
All sub-modules (projections, norm, blocks) are instantiated from LazyConfig objects, so the architecture can be fully configured from YAML/JSON without subclassing.
Tensor layout
All tensors are in channels-last format: [B, *spatial, C]. The data_dim argument records the number of spatial axes (1 for sequences, 2 for images, 3 for volumes) and is used to convert a scalar target_size into a per-axis tuple.
Output format
forward() always returns a dict with key "logits" whose value has shape [B, *spatial_out, out_channels]. When target_size=None, spatial_out = spatial; otherwise it is the cropped target region.
- hidden_dim
Internal feature dimension used throughout the trunk.
- Type:
int
- dropout_in
Input dropout / augmentation applied first.
- Type:
nn.Module
- in_proj
in_channels → hidden_dim linear projection.
- Type:
nn.Module
- condition_in_proj
Optional hidden_dim → hidden_dim projection for the conditioning signal.
- Type:
nn.Module | None
- blocks
Stack of num_blocks residual blocks.
- Type:
nn.ModuleList
- out_norm
Post-trunk normalisation layer (its parameters are excluded from weight decay).
- Type:
nn.Module
- out_proj
hidden_dim → out_channels readout projection.
- Type:
nn.Module
- Parameters:
in_channels (int) – Number of input signal channels.
out_channels (int) – Number of output channels (e.g. vocabulary / class count).
num_blocks (int) – Depth of the residual tower.
hidden_dim (int) – Width of the residual tower.
data_dim (int) – Spatial dimensionality (1, 2, or 3).
in_proj_cfg (LazyConfig) – LazyConfig for the input projection (typically nn.Linear).
out_proj_cfg (LazyConfig) – LazyConfig for the output projection.
norm_cfg (LazyConfig) – LazyConfig for the output normalisation layer.
block_cfg (LazyConfig) – LazyConfig for each residual block; instantiated num_blocks times.
dropout_in_cfg (LazyConfig) – LazyConfig for the input dropout layer.
condition_in_proj_cfg (LazyConfig | None) – Optional LazyConfig for the condition projection. Pass None for unconditional networks.
target_size (int | Sequence[int] | None) – Readout crop size. int → same size on every spatial axis. tuple → per-axis sizes (use 1 to squeeze that axis). None → return the full output.
gradient_checkpointing (bool) – Enable activation recomputation in forward() to reduce peak memory at the cost of extra compute.
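The interplay of data_dim, target_size and the channels-last output shape described above can be sketched in plain Python. This is only an illustration under stated assumptions: normalize_target_size and expected_logits_shape are hypothetical helper names that do not appear in the documented API, and the sketch does not model where the crop is taken from or whether a size of 1 keeps or drops an axis.

```python
from collections.abc import Sequence
from typing import Optional, Tuple, Union


def normalize_target_size(
    target_size: Union[int, Sequence[int], None],
    data_dim: int,
) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: expand target_size to one entry per spatial axis."""
    if target_size is None:
        return None  # no crop requested: the full output is returned
    if isinstance(target_size, int):
        return (target_size,) * data_dim  # same size on every spatial axis
    sizes = tuple(int(s) for s in target_size)
    if len(sizes) != data_dim:
        raise ValueError(f"expected {data_dim} sizes, got {len(sizes)}")
    return sizes


def expected_logits_shape(
    input_shape: Tuple[int, ...],
    out_channels: int,
    target_size: Union[int, Sequence[int], None],
    data_dim: int,
) -> Tuple[int, ...]:
    """Hypothetical helper: predict the "logits" shape for a [B, *spatial, C] input."""
    batch, *spatial, _in_channels = input_shape
    if len(spatial) != data_dim:
        raise ValueError(f"expected {data_dim} spatial axes, got {len(spatial)}")
    sizes = normalize_target_size(target_size, data_dim)
    spatial_out = tuple(spatial) if sizes is None else sizes
    return (batch, *spatial_out, out_channels)


# A batch of two 32x32 images with 3 channels, read out into 10 classes.
print(expected_logits_shape((2, 32, 32, 3), 10, None, data_dim=2))  # (2, 32, 32, 10)
print(expected_logits_shape((2, 32, 32, 3), 10, 8, data_dim=2))     # (2, 8, 8, 10)
```

A scalar target_size is only meaningful once data_dim is known, which is presumably why the constructor takes data_dim as an explicit argument rather than inferring it from the first input.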
- __init__(in_channels, out_channels, num_blocks, hidden_dim, data_dim, in_proj_cfg, out_proj_cfg, norm_cfg, block_cfg, dropout_in_cfg, condition_in_proj_cfg=None, target_size=None, gradient_checkpointing=False)
Instantiate all sub-modules from LazyConfig objects.
- Parameters:
in_channels (int) – Number of input signal channels.
out_channels (int) – Number of output channels.
num_blocks (int) – Number of residual blocks to stack.
hidden_dim (int) – Internal feature width.
data_dim (int) – Spatial dimensionality (1 / 2 / 3).
in_proj_cfg (LazyConfig) – Config for the input projection.
out_proj_cfg (LazyConfig) – Config for the output projection.
norm_cfg (LazyConfig) – Config for the output norm layer.
block_cfg (LazyConfig) – Config for each residual block (instantiated N times).
dropout_in_cfg (LazyConfig) – Config for input dropout.
condition_in_proj_cfg (LazyConfig | None) – Optional config for condition projection.
target_size (int | Sequence[int] | None) – Readout crop specification (see class docstring).
gradient_checkpointing (bool) – Recompute activations during backward pass.
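The idea of building every sub-module from a config, and of instantiating block_cfg once per block, can be illustrated with a small stand-in. The actual LazyConfig API is not shown on this page, so SketchLazyConfig, its instantiate() method, ToyBlock and build_blocks are all hypothetical names used only for this sketch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List


@dataclass
class SketchLazyConfig:
    """Hypothetical stand-in for LazyConfig: a callable plus its keyword arguments."""

    target: Callable[..., Any]
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def instantiate(self) -> Any:
        # Build a fresh object on every call.
        return self.target(**self.kwargs)


class ToyBlock:
    """Placeholder for a residual block; each instance owns its own state."""

    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim
        self.weights = [0.0] * hidden_dim


def build_blocks(block_cfg: SketchLazyConfig, num_blocks: int) -> List[Any]:
    # One instantiate() call per block, so no two blocks share state.
    return [block_cfg.instantiate() for _ in range(num_blocks)]


blocks = build_blocks(SketchLazyConfig(ToyBlock, {"hidden_dim": 8}), num_blocks=3)
print(len(blocks))             # 3
print(blocks[0] is blocks[1])  # False
```

Deferring construction this way lets a single block description be reused for every layer of the tower while each layer still gets its own object.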
- forward(input_and_condition)
Run the full forward pass: project → blocks → norm → project → crop.
- Parameters:
input_and_condition (dict[str, Tensor]) –
Dictionary with two keys:
"input": signal tensor of shape [B, *spatial, in_channels].
"condition": optional conditioning tensor of shape [B, *spatial_cond, hidden_dim], or None when condition_in_proj_cfg was not provided.
- Returns:
Single-key dict:
"logits": shape [B, *spatial_out, out_channels], where spatial_out equals spatial unless target_size is set, in which case it is the cropped readout region.
- Return type: