ViT5ClassificationNet

class ViT5ClassificationNet(
in_channels,
num_classes,
hidden_dim,
num_blocks,
patch_size,
image_size,
num_registers,
norm_cfg,
readout,
block_cfg=None,
dropout_rate=0.0,
neck_compression_ratio=None,
reg_init='trunc_normal',
layer_pattern=None,
layer_types=None,
padding_types=None,
max_drop_path_rate=0.0,
drop_path_schedule='constant',
)

Bases: Module

ViT-5 classification network.

Token layout: [patches (H*W), CLS (1), registers (R), padding (pad_size)].

Padding tokens are appended so that the total token count T satisfies T % grid_w == 0, as required by 2D spatial mixers (Hyena). Attention blocks receive the sequence with padding stripped; Hyena blocks receive the full padded sequence.
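The padding rule above can be sketched in plain Python. This is an illustrative calculation, not the class's actual code: the helper name compute_pad_size and the has_cls flag are hypothetical, and it assumes grid_w = image_size // patch_size and that the non-padding tokens are patches, an optional CLS token, and the registers.

```python
def compute_pad_size(image_size: int, patch_size: int,
                     num_registers: int, has_cls: bool) -> int:
    """Smallest number of padding tokens so that T % grid_w == 0.

    Hypothetical helper: the real module may compute this differently.
    """
    grid_w = image_size // patch_size           # assumed patch-grid width
    num_patches = grid_w * grid_w               # square image assumed
    num_non_pad = num_patches + int(has_cls) + num_registers
    # (-n) % m gives the distance from n up to the next multiple of m.
    return (-num_non_pad) % grid_w


# 224 / 16 = 14 patches per side, 196 patches, plus CLS and 4 registers = 201.
pad = compute_pad_size(image_size=224, patch_size=16, num_registers=4, has_cls=True)
total = 196 + 1 + 4 + pad
assert total % 14 == 0
```

With no CLS token and no registers, the patch count is already a multiple of grid_w, so no padding is needed.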

Supports two block-stacking modes (mutually exclusive):

  1. Homogeneous (default): a single block_cfg is replicated num_blocks times.

  2. Hybrid / interleaved: layer_pattern + layer_types define per-layer block types. For example, layer_pattern="HA" * 6 with layer_types={"H": hyena_cfg, "A": attn_cfg} creates 12 blocks alternating between Hyena and Attention.
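A minimal sketch of how a pattern string could be expanded into per-layer configs and padding flags. The function expand_layer_pattern is hypothetical and uses plain strings in place of LazyConfig objects; the checks mirror the constraints documented on this page (pattern length equals num_blocks, every character has an entry in layer_types, padding_types defaults to {"H"}).

```python
def expand_layer_pattern(layer_pattern, layer_types, num_blocks, padding_types=None):
    """Return (per-layer configs, per-layer needs-padding flags).

    Hypothetical helper for illustration only.
    """
    if padding_types is None:
        padding_types = {"H"}                   # documented default
    if len(layer_pattern) != num_blocks:
        raise ValueError("layer_pattern length must equal num_blocks")
    unknown = set(layer_pattern) - set(layer_types)
    if unknown:
        raise ValueError(f"pattern characters missing from layer_types: {sorted(unknown)}")
    cfgs = [layer_types[ch] for ch in layer_pattern]
    needs_padding = [ch in padding_types for ch in layer_pattern]
    return cfgs, needs_padding


# Placeholder strings stand in for the Hyena and Attention block configs.
cfgs, flags = expand_layer_pattern(
    "HA" * 6, {"H": "hyena_cfg", "A": "attn_cfg"}, num_blocks=12
)
```

Here the even-indexed layers are Hyena blocks that receive the padded sequence, and the odd-indexed layers are Attention blocks that receive it with padding stripped.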

Parameters:
  • in_channels (int) – Number of input channels (3 for RGB).

  • num_classes (int) – Number of output classes.

  • hidden_dim (int) – Transformer hidden dimension.

  • num_blocks (int) – Number of transformer blocks.

  • patch_size (int) – Patch size for patchification.

  • image_size (int) – Input image size (assumes square).

  • num_registers (int) – Number of learnable register tokens.

  • block_cfg (LazyConfig | None) – LazyConfig for ViT5ResidualBlock (homogeneous mode). Mutually exclusive with layer_pattern/layer_types.

  • norm_cfg (LazyConfig) – LazyConfig for the normalization layer (RMSNorm).

  • dropout_rate (float) – Dropout rate applied before the classification head.

  • readout (Literal['cls', 'gap', 'register_concat']) – Classification readout strategy. "cls": append a learnable CLS token after patches and read it out. "gap": global average pooling over patch tokens. "register_concat": gather register tokens after all blocks, compress each via a shared neck linear, concatenate, and project.

  • neck_compression_ratio (int | None) – Compression ratio for register_concat readout. Required when readout="register_concat".

  • reg_init (Literal['trunc_normal', 'zeros']) – Initialization strategy for register tokens. "trunc_normal" (default) or "zeros".

  • layer_pattern (str | None) – Pattern string defining per-layer block types (hybrid mode). Each character maps to a key in layer_types. Length must equal num_blocks. Example: "HA" * 6.

  • layer_types (dict[str, LazyConfig] | None) – Dict mapping pattern characters to block LazyConfigs. Required when layer_pattern is set.

  • padding_types (set[str] | None) – Set of layer_pattern characters whose blocks need the full padded sequence (e.g. Hyena). Blocks whose character is NOT in this set receive the sequence with padding stripped. Only relevant when layer_pattern is used and pad_size > 0. Default: {"H"}.

  • max_drop_path_rate (float) – Maximum stochastic depth drop probability. Per-layer rates are computed according to drop_path_schedule and injected into each block config at construction time.

  • drop_path_schedule (Literal['constant', 'linear']) – How drop path rates are distributed across depth. "constant": every layer gets max_drop_path_rate. "linear": ramp from 0 to max_drop_path_rate across depth.
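The two drop path schedules can be illustrated as follows. This is a sketch under assumptions: the helper name drop_path_rates is hypothetical, the "linear" ramp is assumed to be evenly spaced with the first layer at 0 and the last at max_drop_path_rate, and a single-block network is assumed to get rate 0. The actual implementation may space the rates differently.

```python
def drop_path_rates(num_blocks: int, max_drop_path_rate: float,
                    schedule: str = "constant") -> list:
    """Per-layer stochastic depth rates (illustrative, not the library's code)."""
    if schedule == "constant":
        # Every layer gets the same rate.
        return [max_drop_path_rate] * num_blocks
    if schedule == "linear":
        if num_blocks == 1:
            return [0.0]                        # assumption: a ramp of length 1 starts at 0
        step = max_drop_path_rate / (num_blocks - 1)
        return [step * i for i in range(num_blocks)]
    raise ValueError(f"unknown drop_path_schedule: {schedule!r}")


constant = drop_path_rates(4, 0.1, "constant")
linear = drop_path_rates(5, 0.2, "linear")
```

Under the linear schedule, early blocks are dropped rarely and later blocks more often, which is the usual motivation for ramping stochastic depth with depth.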

__init__(
in_channels,
num_classes,
hidden_dim,
num_blocks,
patch_size,
image_size,
num_registers,
norm_cfg,
readout,
block_cfg=None,
dropout_rate=0.0,
neck_compression_ratio=None,
reg_init='trunc_normal',
layer_pattern=None,
layer_types=None,
padding_types=None,
max_drop_path_rate=0.0,
drop_path_schedule='constant',
)

Construct patch embedding, positional embeddings, token buffers, and transformer blocks.

Validates the readout and layer_pattern constraints, computes the zero-padding size so that T % grid_w == 0, and builds the per-layer _block_needs_padding flag list. Each block is then instantiated via instantiate() with its per-layer drop_path_rate and register_start_idx injected. Finally, all parameters are initialised with a truncated normal distribution (std 0.02).

Parameters:
  • in_channels (int) – Input image channels (3 for RGB).

  • num_classes (int) – Number of output logits / classes.

  • hidden_dim (int) – Transformer hidden width D.

  • num_blocks (int) – Number of transformer blocks N.

  • patch_size (int) – Non-overlapping patch stride P; patches are P×P.

  • image_size (int) – Square input resolution H = W. Produces (H/P)² patch tokens.

  • num_registers (int) – Number of learnable register tokens R appended after the CLS token.

  • norm_cfg (LazyConfig) – LazyConfig for the output normalisation layer.

  • readout (Literal['cls', 'gap', 'register_concat']) – Token aggregation strategy — "cls", "gap", or "register_concat" (see class docstring).

  • block_cfg (LazyConfig | None) – Single LazyConfig replicated N times (homogeneous mode). Mutually exclusive with layer_pattern.

  • dropout_rate (float) – Dropout probability applied between the norm and the classification head. 0.0 disables dropout.

  • neck_compression_ratio (int | None) – Compression factor for register_concat readout; neck_dim = hidden_dim // neck_compression_ratio. Required when readout="register_concat".

  • reg_init (Literal['trunc_normal', 'zeros']) – Register-token initialisation — "trunc_normal" (std 0.02, default) or "zeros".

  • layer_pattern (str | None) – Per-layer type string of length num_blocks (hybrid mode). Each character maps to a key in layer_types.

  • layer_types (dict[str, LazyConfig] | None) – Dict mapping pattern characters to block LazyConfigs. Required when layer_pattern is set.

  • padding_types (set[str] | None) – Set of pattern characters whose blocks receive the full padded sequence. Default {"H"} (Hyena blocks).

  • max_drop_path_rate (float) – Peak stochastic depth probability distributed across blocks according to drop_path_schedule.

  • drop_path_schedule (Literal['constant', 'linear']) – "constant" (uniform) or "linear" (ramp 0 → max_drop_path_rate with depth).
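For the register_concat readout, the dimension bookkeeping implied by neck_compression_ratio can be sketched as below. The neck_dim formula is the one given above. The head input width of num_registers * neck_dim is an assumption inferred from "compress each via a shared neck linear, concatenate, and project", and register_concat_dims is a hypothetical helper name.

```python
def register_concat_dims(hidden_dim: int, num_registers: int,
                         neck_compression_ratio) -> tuple:
    """Return (neck_dim, assumed head input dim) for readout="register_concat"."""
    if neck_compression_ratio is None:
        raise ValueError('neck_compression_ratio is required when readout="register_concat"')
    neck_dim = hidden_dim // neck_compression_ratio   # documented formula
    head_in_dim = num_registers * neck_dim            # assumption: R compressed registers concatenated
    return neck_dim, head_in_dim


neck_dim, head_in_dim = register_concat_dims(
    hidden_dim=384, num_registers=4, neck_compression_ratio=4
)
```

With these example values each register is compressed from 384 to 96 features, so the concatenated vector has the same width as a single hidden token.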

flop_count(inference=False)

Count FLOPs for a full ViT-5 classification forward pass (one sample).

Pipeline:
  1. Patch embedding (Conv2d): 2 * in_ch * D * P^2 * num_patches

  2. Positional embedding add: num_patches * D

  3. Transformer blocks: sum of block.flop_count(T)

  4. Output norm: self.out_norm.flop_count(1)

  5. Classification head: 2 * head_dim * num_classes

Token count T = num_non_pad + pad_size. Attention blocks see T_attn = num_non_pad (padding stripped).
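The block-independent terms of the pipeline (steps 1, 2, and 5) can be reproduced with simple arithmetic. The sketch below is not the method's implementation: fixed_flops is a hypothetical helper, head_dim is taken as an explicit argument because its value depends on the readout, and the transformer block and output norm terms are omitted because they come from block.flop_count(T) and self.out_norm.flop_count(1).

```python
def fixed_flops(in_channels: int, hidden_dim: int, patch_size: int,
                image_size: int, num_classes: int, head_dim: int) -> dict:
    """FLOPs of the patch embedding, positional add, and classification head."""
    num_patches = (image_size // patch_size) ** 2     # square image assumed
    return {
        # Step 1: Conv2d patch embedding, 2 FLOPs per multiply-accumulate.
        "patch_embed": 2 * in_channels * hidden_dim * patch_size ** 2 * num_patches,
        # Step 2: one add per patch-token feature.
        "pos_embed": num_patches * hidden_dim,
        # Step 5: linear classification head on a single readout vector.
        "head": 2 * head_dim * num_classes,
    }


flops = fixed_flops(in_channels=3, hidden_dim=384, patch_size=16,
                    image_size=224, num_classes=1000, head_dim=384)
```

For this example configuration the patch embedding dominates the fixed terms at roughly 1.16e8 FLOPs, while the positional add and head contribute well under 1e6 each.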

Parameters:

inference (bool) – Passed through to each block for kernel caching decisions.

Returns:

Total FLOPs as an integer.

Return type:

int

forward(input_and_condition)

Forward pass.

Token layout: [patches (H*W), CLS (optional, 1), registers (R), padding (pad_size)]. Attention blocks see [patches, CLS, registers] (padding stripped). Hyena blocks see the full padded sequence.
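The layout and the padding-stripping step can be visualised with a list of token labels. This is purely illustrative: build_token_layout and strip_padding are hypothetical helpers operating on Python lists, whereas the real forward pass works on tensors.

```python
def build_token_layout(num_patches: int, has_cls: bool,
                       num_registers: int, pad_size: int) -> list:
    """Label each position in the order patches, CLS, registers, padding."""
    return (["patch"] * num_patches
            + ["cls"] * int(has_cls)
            + ["reg"] * num_registers
            + ["pad"] * pad_size)


def strip_padding(tokens: list, pad_size: int) -> list:
    """View for blocks that do not take padding (e.g. Attention)."""
    # Slice by explicit end index: tokens[:-0] would return an empty list.
    return tokens[:len(tokens) - pad_size]


full = build_token_layout(num_patches=4, has_cls=True, num_registers=2, pad_size=1)
attn_view = strip_padding(full, pad_size=1)   # what an Attention block would see
hyena_view = full                             # a Hyena block sees everything
```

Because padding sits at the very end of the sequence, stripping it is a single slice and leaves the patch, CLS, and register positions unchanged.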

Parameters:

input_and_condition (dict[str, Tensor]) – Dict with keys "input" (images of shape [B, H, W, C]) and "condition" (unused).

Returns:

Dict with key "logits" of shape [B, num_classes].

Return type:

dict[str, Tensor]