ViT5ClassificationNet
- class ViT5ClassificationNet(in_channels, num_classes, hidden_dim, num_blocks, patch_size, image_size, num_registers, norm_cfg, readout, block_cfg=None, dropout_rate=0.0, neck_compression_ratio=None, reg_init='trunc_normal', layer_pattern=None, layer_types=None, padding_types=None, max_drop_path_rate=0.0, drop_path_schedule='constant')
Bases: Module
ViT-5 classification network.
Token layout:
[patches (H*W), CLS (1), registers (R), padding (pad_size)]. Padding makes T % grid_w == 0 for 2D spatial mixers (Hyena). Attention blocks receive the sequence with padding stripped; Hyena blocks receive the full padded sequence.
Supports two block-stacking modes (mutually exclusive):
- Homogeneous (default): a single block_cfg is replicated num_blocks times.
- Hybrid / interleaved: layer_pattern + layer_types define per-layer block types. For example, layer_pattern="HA" * 6 with layer_types={"H": hyena_cfg, "A": attn_cfg} creates 12 blocks alternating between Hyena and Attention.
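To make the token layout concrete, the sketch below computes the token counts for one configuration. It is a minimal illustration, not the class's code: the helper name token_layout is hypothetical, and it assumes grid_w = image_size // patch_size and that pad_size is the smallest non-negative value that makes T a multiple of grid_w.

```python
def token_layout(image_size: int, patch_size: int, num_registers: int,
                 has_cls: bool) -> tuple:
    """Hypothetical helper: token counts for [patches, CLS?, registers, padding].

    Assumes the minimal padding that makes T a multiple of grid_w.
    """
    grid_w = image_size // patch_size           # patches per row (square image)
    num_patches = grid_w * grid_w               # H*W patch tokens
    num_non_pad = num_patches + int(has_cls) + num_registers
    pad_size = (-num_non_pad) % grid_w          # smallest pad with T % grid_w == 0
    total = num_non_pad + pad_size              # T, the length Hyena blocks see
    return num_patches, num_non_pad, pad_size, total


# 224x224 image, 16x16 patches, CLS readout, 4 registers.
print(token_layout(224, 16, num_registers=4, has_cls=True))  # (196, 201, 9, 210)
```

With 14 patches per row, the 201 non-padding tokens need 9 padding tokens to reach 210 = 14 * 15, which is the length a Hyena block would see; attention blocks would see only the 201.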
- Parameters:
in_channels (int) – Number of input channels (3 for RGB).
num_classes (int) – Number of output classes.
hidden_dim (int) – Transformer hidden dimension.
num_blocks (int) – Number of transformer blocks.
patch_size (int) – Patch size for patchification.
image_size (int) – Input image size (assumes square).
num_registers (int) – Number of learnable register tokens.
block_cfg (LazyConfig | None) – LazyConfig for ViT5ResidualBlock (homogeneous mode). Mutually exclusive with layer_pattern/layer_types.
norm_cfg (LazyConfig) – LazyConfig for the normalization layer (RMSNorm).
dropout_rate (float) – Dropout rate applied before the classification head.
readout (Literal['cls', 'gap', 'register_concat']) – Classification readout strategy: "cls" appends a learnable CLS token after the patches and reads it out; "gap" applies global average pooling over the patch tokens; "register_concat" gathers the register tokens after all blocks, compresses each via a shared neck linear layer, concatenates the results, and projects them.
neck_compression_ratio (int | None) – Compression ratio for the register_concat readout. Required when readout="register_concat".
reg_init (Literal['trunc_normal', 'zeros']) – Initialization strategy for register tokens: "trunc_normal" (default) or "zeros".
layer_pattern (str | None) – Pattern string defining per-layer block types (hybrid mode). Each character maps to a key in layer_types. Length must equal num_blocks. Example: "HA" * 6.
layer_types (dict[str, LazyConfig] | None) – Dict mapping pattern characters to block LazyConfigs. Required when layer_pattern is set.
padding_types (set[str] | None) – Set of layer_pattern characters whose blocks need the full padded sequence (e.g. Hyena). Blocks whose character is NOT in this set receive the sequence with padding stripped. Only relevant when layer_pattern is used and pad_size > 0. Default: {"H"}.
max_drop_path_rate (float) – Maximum stochastic depth drop probability. Per-layer rates are computed according to drop_path_schedule and injected into each block config at construction time.
drop_path_schedule (Literal['constant', 'linear']) – How drop path rates are distributed across depth: "constant" gives every layer max_drop_path_rate; "linear" ramps from 0 to max_drop_path_rate across depth.
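The hybrid-mode parameters above combine into one setting per layer. The following is a sketch of that expansion under stated assumptions, not the constructor's code: build_layer_plan is a hypothetical name, config objects are stand-in strings, and the "linear" schedule is assumed to be an endpoint-inclusive ramp (first layer 0, last layer max_drop_path_rate).

```python
from __future__ import annotations

from typing import Any


def build_layer_plan(num_blocks: int, layer_pattern: str, layer_types: dict[str, Any],
                     padding_types: set[str] | None = None,
                     max_drop_path_rate: float = 0.0,
                     drop_path_schedule: str = "constant") -> list[tuple[Any, bool, float]]:
    """Hypothetical helper: one (block_cfg, needs_padding, drop_path_rate) per layer."""
    if padding_types is None:
        padding_types = {"H"}                       # documented default (Hyena blocks)
    if len(layer_pattern) != num_blocks:
        raise ValueError("len(layer_pattern) must equal num_blocks")
    if drop_path_schedule == "constant":
        rates = [max_drop_path_rate] * num_blocks   # every layer gets the peak rate
    elif drop_path_schedule == "linear":
        denom = max(num_blocks - 1, 1)              # assumed endpoint-inclusive ramp
        rates = [max_drop_path_rate * (i / denom) for i in range(num_blocks)]
    else:
        raise ValueError(f"unknown drop_path_schedule: {drop_path_schedule!r}")
    # A character missing from layer_types raises KeyError here.
    return [(layer_types[ch], ch in padding_types, rate)
            for ch, rate in zip(layer_pattern, rates)]


plan = build_layer_plan(12, "HA" * 6, {"H": "hyena_cfg", "A": "attn_cfg"},
                        max_drop_path_rate=0.1, drop_path_schedule="linear")
print([needs_pad for _, needs_pad, _ in plan[:4]])  # [True, False, True, False]
print(plan[0][2], plan[-1][2])                      # 0.0 0.1
```

In homogeneous mode the same plan would simply repeat block_cfg num_blocks times.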
- __init__(in_channels, num_classes, hidden_dim, num_blocks, patch_size, image_size, num_registers, norm_cfg, readout, block_cfg=None, dropout_rate=0.0, neck_compression_ratio=None, reg_init='trunc_normal', layer_pattern=None, layer_types=None, padding_types=None, max_drop_path_rate=0.0, drop_path_schedule='constant')
Construct patch embedding, positional embeddings, token buffers, and transformer blocks.
Validates readout/layer_pattern constraints, computes the zero-padding size so that T % grid_w == 0, builds the per-layer _block_needs_padding flag list, instantiates each block via instantiate() with per-layer drop_path_rate and register_start_idx injected, and initialises all parameters with truncated-normal (std 0.02).
- Parameters:
in_channels (int) – Input image channels (3 for RGB).
num_classes (int) – Number of output logits / classes.
hidden_dim (int) – Transformer hidden width D.
num_blocks (int) – Number of transformer blocks N.
patch_size (int) – Non-overlapping patch stride P; patches are P×P.
image_size (int) – Square input resolution H = W. Produces (H/P)² patch tokens.
num_registers (int) – Number of learnable register tokens R appended after the CLS token.
norm_cfg (LazyConfig) – LazyConfig for the output normalisation layer.
readout (Literal['cls', 'gap', 'register_concat']) – Token aggregation strategy: "cls", "gap", or "register_concat" (see class docstring).
block_cfg (LazyConfig | None) – Single LazyConfig replicated N times (homogeneous mode). Mutually exclusive with layer_pattern.
dropout_rate (float) – Dropout probability applied between the norm and the classification head. 0.0 disables dropout.
neck_compression_ratio (int | None) – Compression factor for the register_concat readout; neck_dim = hidden_dim // neck_compression_ratio. Required when readout="register_concat".
reg_init (Literal['trunc_normal', 'zeros']) – Register-token initialisation: "trunc_normal" (std 0.02, default) or "zeros".
layer_pattern (str | None) – Per-layer type string of length num_blocks (hybrid mode). Each character maps to a key in layer_types.
layer_types (dict[str, LazyConfig] | None) – Dict mapping pattern characters to block LazyConfigs. Required when layer_pattern is set.
padding_types (set[str] | None) – Set of pattern characters whose blocks receive the full padded sequence. Default: {"H"} (Hyena blocks).
max_drop_path_rate (float) – Peak stochastic depth probability, distributed across blocks according to drop_path_schedule.
drop_path_schedule (Literal['constant', 'linear']) – "constant" (uniform) or "linear" (ramp from 0 to max_drop_path_rate with depth).
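The validation step and the neck_dim formula can be sketched as follows. This is an illustration under assumptions, not the constructor's logic: check_config is a hypothetical name, it assumes exactly one of block_cfg / layer_pattern must be given, and it assumes the register_concat head input is num_registers * neck_dim (R compressed register tokens concatenated), while "cls" and "gap" feed a single hidden_dim-wide vector to the head.

```python
from __future__ import annotations

from typing import Any

_READOUTS = {"cls", "gap", "register_concat"}


def check_config(num_blocks: int, hidden_dim: int, num_registers: int, readout: str,
                 block_cfg: Any = None, layer_pattern: str | None = None,
                 layer_types: dict[str, Any] | None = None,
                 neck_compression_ratio: int | None = None) -> int:
    """Hypothetical validation sketch; returns the assumed head input width."""
    if readout not in _READOUTS:
        raise ValueError(f"unknown readout: {readout!r}")
    # Homogeneous vs hybrid: assume exactly one of the two modes is selected.
    if (block_cfg is None) == (layer_pattern is None):
        raise ValueError("pass exactly one of block_cfg or layer_pattern")
    if layer_pattern is not None:
        if layer_types is None:
            raise ValueError("layer_types is required when layer_pattern is set")
        if len(layer_pattern) != num_blocks:
            raise ValueError("len(layer_pattern) must equal num_blocks")
        missing = set(layer_pattern) - set(layer_types)
        if missing:
            raise ValueError(f"pattern characters without a config: {sorted(missing)}")
    if readout == "register_concat":
        if neck_compression_ratio is None:
            raise ValueError('neck_compression_ratio is required for readout="register_concat"')
        neck_dim = hidden_dim // neck_compression_ratio   # formula from the docstring
        return num_registers * neck_dim                   # assumed: R neck outputs concatenated
    return hidden_dim                                     # "cls" / "gap": one D-wide vector


print(check_config(12, 768, 4, "register_concat", block_cfg="cfg",
                   neck_compression_ratio=8))  # 384
```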
- flop_count(inference=False)
Count FLOPs for a full ViT-5 classification forward pass (one sample).
- Pipeline:
Patch embedding (Conv2d): 2 * in_ch * D * P^2 * num_patches
Positional embedding add: num_patches * D
Transformer blocks: sum of block.flop_count(T)
Output norm: self.out_norm.flop_count(1)
Classification head: 2 * head_dim * num_classes
Token count T = num_non_pad + pad_size. Attention blocks see T_attn = num_non_pad (padding stripped).
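The pipeline above can be restated as a small function. This is a hedged sketch, not flop_count itself: classification_flops is a hypothetical name, per-block costs are passed in as callables standing in for block.flop_count, the output-norm cost is passed in as a number, and the inference flag is not modelled.

```python
from typing import Callable, Sequence


def classification_flops(in_ch: int, hidden_dim: int, patch_size: int, num_patches: int,
                         head_dim: int, num_classes: int,
                         block_flop_fns: Sequence[Callable[[int], int]],
                         needs_padding: Sequence[bool], num_non_pad: int, pad_size: int,
                         out_norm_flops: int) -> dict:
    """Hypothetical restatement of the per-sample FLOP pipeline."""
    T = num_non_pad + pad_size                 # padded length (e.g. Hyena blocks)
    T_attn = num_non_pad                       # padding stripped (attention blocks)
    parts = {
        "patch_embed": 2 * in_ch * hidden_dim * patch_size**2 * num_patches,
        "pos_embed": num_patches * hidden_dim,
        "blocks": sum(fn(T if pad else T_attn)
                      for fn, pad in zip(block_flop_fns, needs_padding)),
        "out_norm": out_norm_flops,            # stands in for out_norm.flop_count(1)
        "head": 2 * head_dim * num_classes,
    }
    parts["total"] = sum(parts.values())
    return parts


# Two dummy blocks whose "cost" is just the token count they receive.
flops = classification_flops(3, 768, 16, 196, 768, 1000, [lambda t: t, lambda t: t],
                             [True, False], num_non_pad=201, pad_size=9, out_norm_flops=0)
print(flops["blocks"], flops["total"])  # 411 232897947
```

The dummy blocks make the T versus T_attn split visible: the padded block is charged for 210 tokens and the stripped one for 201.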
- forward(input_and_condition)
Forward pass.
Token layout:
[patches (H*W), CLS (1, optional), registers (R), padding (pad_size)]. Attention blocks see [patches, CLS, registers] (padding stripped). Hyena blocks see the full padded sequence.
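The per-block routing described above can be illustrated with plain Python lists standing in for token tensors. This is a sketch under assumptions: route_through_blocks and probe are hypothetical names, and it assumes that for stripped blocks the padding tokens are re-attached unchanged after the block runs, which the docstring does not specify.

```python
from typing import Callable, List, Sequence

Tokens = List[str]
Block = Callable[[Tokens], Tokens]


def route_through_blocks(tokens: Tokens, blocks: Sequence[Block],
                         needs_padding: Sequence[bool], num_non_pad: int) -> Tokens:
    """Hypothetical sketch of per-block sequence routing in forward()."""
    for block, pad in zip(blocks, needs_padding):
        if pad:
            # e.g. Hyena: sees [patches, CLS, registers, padding]
            tokens = block(tokens)
        else:
            # e.g. Attention: sees [patches, CLS, registers]; padding re-attached unchanged
            tokens = block(tokens[:num_non_pad]) + tokens[num_non_pad:]
    return tokens


seen: List[int] = []


def probe(ts: Tokens) -> Tokens:
    seen.append(len(ts))   # record how many tokens this block received
    return ts


layout = ["patch"] * 196 + ["cls"] + ["reg"] * 4 + ["pad"] * 9
out = route_through_blocks(layout, [probe, probe], [True, False], num_non_pad=201)
print(seen, len(out))  # [210, 201] 210
```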