ResidualNetwork

class ResidualNetwork(
in_channels,
out_channels,
num_blocks,
hidden_dim,
data_dim,
in_proj_cfg,
out_proj_cfg,
norm_cfg,
block_cfg,
dropout_in_cfg,
condition_in_proj_cfg=None,
target_size=None,
gradient_checkpointing=False,
)

Bases: Module

General-purpose residual network backbone (see module docstring for architecture).

All sub-modules (projections, norm, blocks) are instantiated from LazyConfig objects so the architecture can be fully configured from YAML/JSON without subclassing.
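The config-driven construction described above can be sketched as follows. This is only an illustration of the pattern: `LazySpec`, its `instantiate` method, and `ToyBlock` are hypothetical stand-ins invented for this example and are not the library's `LazyConfig` API. The sketch assumes that instantiating one block config several times yields independent block objects.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class LazySpec:
    """Hypothetical stand-in for a lazy config: a target plus keyword arguments."""

    target: Callable[..., Any]
    kwargs: dict = field(default_factory=dict)

    def instantiate(self, **overrides: Any) -> Any:
        # Build a fresh object on every call; overrides win over stored kwargs.
        return self.target(**{**self.kwargs, **overrides})


class ToyBlock:
    """Placeholder for a residual block; only records its width."""

    def __init__(self, hidden_dim: int) -> None:
        self.hidden_dim = hidden_dim


# One spec, instantiated num_blocks times, gives num_blocks separate objects.
num_blocks = 3
block_spec = LazySpec(ToyBlock, {"hidden_dim": 64})
blocks = [block_spec.instantiate() for _ in range(num_blocks)]
```

Because the spec only stores a target and its arguments, the same description can be loaded from a YAML or JSON file and reused without writing a subclass.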

Tensor layout

All tensors are in channels-last format: [B, *spatial, C]. The data_dim argument records the number of spatial axes (1 for sequences, 2 for images, 3 for volumes) and is used to convert a scalar target_size into a per-axis tuple.
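Data stored channels-first ([B, C, *spatial]) has to be permuted before it matches this layout. The helper below is a minimal sketch that computes the axis order for a given data_dim. The function names are hypothetical and not part of this class. The resulting tuple is what one would pass to `torch.Tensor.permute`.

```python
def channels_last_permutation(data_dim: int) -> tuple:
    """Axis order that maps [B, C, *spatial] to [B, *spatial, C]."""
    if data_dim not in (1, 2, 3):
        raise ValueError("data_dim must be 1, 2, or 3")
    # Keep batch first, move the spatial axes forward, put channels last.
    return (0, *range(2, 2 + data_dim), 1)


def permute_shape(shape: tuple, order: tuple) -> tuple:
    """Apply an axis order to a shape tuple (mirrors what permute does to a tensor)."""
    return tuple(shape[i] for i in order)


order_2d = channels_last_permutation(2)             # (0, 2, 3, 1)
image_shape = permute_shape((8, 3, 32, 32), order_2d)  # (8, 32, 32, 3)
```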

Output format

forward() always returns a dict with key "logits" whose value has shape [B, *spatial_out, out_channels]. When target_size=None, spatial_out = spatial; otherwise it is the cropped target region.
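The shape rule above can be written out as a small helper. This is a sketch under stated assumptions, not the class's own code: `expected_logits_shape` is a hypothetical name, target_size is taken to be already expanded to one entry per spatial axis, a cropped axis of size 1 is kept as a length-1 axis, and a crop larger than the input extent is treated as an error.

```python
def expected_logits_shape(input_shape, out_channels, target_size=None):
    """Predict the "logits" shape for an input of shape [B, *spatial, in_channels]."""
    batch, *spatial, _in_channels = input_shape
    if target_size is None:
        # No crop: the spatial extent is passed through unchanged.
        spatial_out = tuple(spatial)
    else:
        if len(target_size) != len(spatial):
            raise ValueError("target_size needs one entry per spatial axis")
        if any(t > s for t, s in zip(target_size, spatial)):
            raise ValueError("crop size exceeds the input extent (assumed invalid)")
        spatial_out = tuple(target_size)
    return (batch, *spatial_out, out_channels)


full = expected_logits_shape((8, 128, 3), out_channels=10)
cropped = expected_logits_shape((4, 32, 32, 3), out_channels=5, target_size=(16, 16))
```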

in_channels

Input channel count.

Type:

int

out_channels

Output channel count / number of classes.

Type:

int

num_blocks

Number of stacked residual blocks.

Type:

int

hidden_dim

Internal feature dimension used throughout the trunk.

Type:

int

data_dim

Number of spatial axes (1/2/3).

Type:

int

gradient_checkpointing

Whether to recompute activations during the backward pass instead of storing them.

Type:

bool

target_size

Per-axis readout crop size, or None.

Type:

tuple | None

dropout_in

Input dropout / augmentation applied first.

Type:

nn.Module

in_proj

in_channels → hidden_dim linear projection.

Type:

nn.Module

condition_in_proj

Optional hidden_dim → hidden_dim projection for the conditioning signal.

Type:

nn.Module | None

blocks

Stack of num_blocks residual blocks.

Type:

nn.ModuleList

out_norm

Post-trunk normalisation layer; its parameters are excluded from weight decay.

Type:

nn.Module

out_proj

hidden_dim → out_channels readout projection.

Type:

nn.Module

Parameters:
  • in_channels (int) – Number of input signal channels.

  • out_channels (int) – Number of output channels (e.g. vocabulary / class count).

  • num_blocks (int) – Depth of the residual tower.

  • hidden_dim (int) – Width of the residual tower.

  • data_dim (int) – Spatial dimensionality (1, 2, or 3).

  • in_proj_cfg (LazyConfig) – LazyConfig for the input projection (typically nn.Linear).

  • out_proj_cfg (LazyConfig) – LazyConfig for the output projection.

  • norm_cfg (LazyConfig) – LazyConfig for the output normalisation layer.

  • block_cfg (LazyConfig) – LazyConfig for each residual block; instantiated num_blocks times.

  • dropout_in_cfg (LazyConfig) – LazyConfig for the input dropout layer.

  • condition_in_proj_cfg (LazyConfig | None) – Optional LazyConfig for the condition projection. Pass None for unconditional networks.

  • target_size (int | Sequence[int] | None) – Readout crop size. int → same size on every spatial axis. tuple → per-axis sizes (use 1 to squeeze that axis). None → return the full output.

  • gradient_checkpointing (bool) – Enable activation recomputation in forward() to reduce peak memory at the cost of extra compute.
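The three accepted forms of target_size can be reduced to a single per-axis tuple using data_dim. The function below is a minimal sketch of that conversion. `normalize_target_size` is a hypothetical name, and the length check is an assumption about how mismatched inputs would be handled.

```python
def normalize_target_size(target_size, data_dim):
    """Return a per-axis tuple of crop sizes, or None for "no crop"."""
    if target_size is None:
        return None
    if isinstance(target_size, int):
        # A scalar applies the same crop size to every spatial axis.
        return (target_size,) * data_dim
    sizes = tuple(int(s) for s in target_size)
    if len(sizes) != data_dim:
        raise ValueError(
            f"expected {data_dim} crop sizes, got {len(sizes)}"
        )
    return sizes


scalar_case = normalize_target_size(16, data_dim=2)        # (16, 16)
per_axis_case = normalize_target_size([8, 1], data_dim=2)  # (8, 1)
no_crop_case = normalize_target_size(None, data_dim=3)     # None
```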

__init__(
in_channels,
out_channels,
num_blocks,
hidden_dim,
data_dim,
in_proj_cfg,
out_proj_cfg,
norm_cfg,
block_cfg,
dropout_in_cfg,
condition_in_proj_cfg=None,
target_size=None,
gradient_checkpointing=False,
)

Instantiate all sub-modules from LazyConfig objects.

Parameters:
  • in_channels (int) – Number of input signal channels.

  • out_channels (int) – Number of output channels.

  • num_blocks (int) – Number of residual blocks to stack.

  • hidden_dim (int) – Internal feature width.

  • data_dim (int) – Spatial dimensionality (1 / 2 / 3).

  • in_proj_cfg (LazyConfig) – Config for the input projection.

  • out_proj_cfg (LazyConfig) – Config for the output projection.

  • norm_cfg (LazyConfig) – Config for the output norm layer.

  • block_cfg (LazyConfig) – Config for each residual block (instantiated N times).

  • dropout_in_cfg (LazyConfig) – Config for input dropout.

  • condition_in_proj_cfg (LazyConfig | None) – Optional config for condition projection.

  • target_size (int | Sequence[int] | None) – Readout crop specification (see class docstring).

  • gradient_checkpointing (bool) – Recompute activations during backward pass.
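Activation recomputation over a stack of residual blocks can be sketched with PyTorch's `torch.utils.checkpoint.checkpoint`. This illustrates the general technique only: `TinyResidualBlock` and `run_trunk` are hypothetical stand-ins, and how this class wires checkpointing internally is not documented here.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class TinyResidualBlock(nn.Module):
    """Hypothetical stand-in block operating on channels-last tensors."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        self.ff = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(self.norm(x))


def run_trunk(blocks, x, gradient_checkpointing=False):
    """Apply the blocks in order, optionally recomputing activations on backward."""
    for block in blocks:
        if gradient_checkpointing and torch.is_grad_enabled():
            # Intermediate activations are dropped here and rebuilt during backward.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x


torch.manual_seed(0)
blocks = nn.ModuleList([TinyResidualBlock(16) for _ in range(3)])
x = torch.randn(2, 5, 16, requires_grad=True)  # [B, *spatial, hidden_dim]
plain = run_trunk(blocks, x)
recomputed = run_trunk(blocks, x, gradient_checkpointing=True)
```

Both paths produce the same output; the checkpointed path trades a second forward computation of each block for lower peak memory.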

forward(input_and_condition)

Run the full forward pass: input dropout → input projection → residual blocks → output norm → output projection → readout crop.

Parameters:

input_and_condition (dict[str, Tensor]) –

Dictionary with two keys:

  • "input" — signal tensor of shape [B, *spatial, in_channels].

  • "condition" — optional conditioning tensor of shape [B, *spatial_cond, hidden_dim], or None when condition_in_proj_cfg was not provided.

Returns:

Single-key dict:

  • "logits" — shape [B, *spatial_out, out_channels] where spatial_out equals spatial unless target_size is set, in which case it is the cropped readout region.

Return type:

dict[str, torch.Tensor]