LightningWrapperBase#
- class LightningWrapperBase(*args, **kwargs)#
Bases: LightningModule
Base PyTorch Lightning module shared by all nvSubquadratic task wrappers.
Handles everything that is common across tasks: optimizer/scheduler construction, checkpoint resume key alignment, optional CUDA profiling, gradient norm tracking, and FLOP logging. Subclasses implement:
training_step / validation_step / test_step: task-specific loss computation and metric logging.
Parameter grouping (see _build_param_groups()):
Parameters tagged _no_weight_decay = True are placed in a zero-decay group. Parameters tagged _lr_scale = <float> receive a per-parameter LR multiplier, applied by scaling the base LR before the group is passed to the optimizer.
Scheduler:
_build_lr_scheduler() chains a linear-warmup LinearLR with the main schedule (cosine, WSD, or constant) via ResumableSequentialLR, which fixes the PyTorch ≤ 2.10 checkpoint-resume LR bug.
- network#
The wrapped model.
- Type:
- optimizer_cfg#
Optimizer config from ExperimentConfig.
- scheduler_cfg#
Scheduler config from ExperimentConfig.
- Parameters:
network (Module) – The neural network to train.
cfg (ExperimentConfig) – Full experiment configuration.
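The parameter-grouping rule described in the class overview can be pictured with a short sketch. This is not the actual _build_param_groups() implementation: the function name build_param_groups, its arguments, and the SimpleNamespace stand-ins for parameters are assumptions made so that the example runs without PyTorch. Only the tag names _no_weight_decay and _lr_scale come from the documentation.

```python
from types import SimpleNamespace


def build_param_groups(named_params, base_lr, weight_decay):
    """Split parameters into optimizer groups based on per-parameter tags.

    `named_params` is any iterable of (name, param) pairs, for example the
    result of `model.named_parameters()`. Hypothetical helper, not the
    library's own code.
    """
    groups = {}  # (weight_decay, lr) -> list of params sharing those settings
    for _name, param in named_params:
        # Tagged parameters go to a zero-decay group.
        decay = 0.0 if getattr(param, "_no_weight_decay", False) else weight_decay
        # The LR multiplier is applied by scaling the base LR up front.
        lr = base_lr * getattr(param, "_lr_scale", 1.0)
        groups.setdefault((decay, lr), []).append(param)
    return [
        {"params": params, "weight_decay": decay, "lr": lr}
        for (decay, lr), params in groups.items()
    ]


# Stand-ins for tagged parameters, so the sketch runs without torch.
weight = SimpleNamespace()
bias = SimpleNamespace(_no_weight_decay=True)
slow = SimpleNamespace(_lr_scale=0.1)

param_groups = build_param_groups(
    [("weight", weight), ("bias", bias), ("slow", slow)],
    base_lr=1e-3,
    weight_decay=0.01,
)
for group in param_groups:
    print(len(group["params"]), group["weight_decay"], group["lr"])
```

A list of dicts in this shape, each with a "params" entry plus per-group options, is the form PyTorch optimizers accept as their first argument.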
- __init__(network, cfg)#
Initialise the wrapper and log parameter count and FLOPs.
- Parameters:
network (Module) – The neural network to train.
cfg (ExperimentConfig) – Full experiment configuration; cfg.optimizer and cfg.scheduler are stored for later use in configure_optimizers().
- property logger#
Return the first logger (Lightning 2.x uses a loggers list).
- on_load_checkpoint(checkpoint)#
Patch checkpoint for cross-optimizer and compiled/non-compiled resume.
Handles three mismatch scenarios:
state_dict key prefixes: torch.compile wraps modules under _orig_mod, so checkpoint keys may differ from the live model.
current_model_state key prefixes: when EMA is active, Lightning saves the raw training weights under current_model_state. The EMA callback later calls pl_module.load_state_dict(...) with these keys, so they must also be remapped.
optimizer param-group keys: resuming with a different optimizer (e.g. Apex FusedLAMB vs torch_optimizer.Lamb) may require injecting default values for keys that the new optimizer expects but the old checkpoint lacks (like bias_correction, adam_w_mode, etc.).
- Parameters:
checkpoint (dict)
- Return type:
None
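As a rough illustration of the first and third scenarios, the sketch below patches a checkpoint dict in place. The helper names align_keys and inject_group_defaults, the example key names, and the default values are all hypothetical; the real hook may handle more cases. The "state_dict" and "optimizer_states" entries are standard Lightning checkpoint keys.

```python
def align_keys(saved, live_keys, marker="_orig_mod."):
    """Rename keys in `saved` so they match `live_keys`, ignoring `marker`."""
    # Map each live key, with the marker removed, back to its real name.
    canonical_to_live = {k.replace(marker, ""): k for k in live_keys}
    aligned = {}
    for key, value in saved.items():
        canonical = key.replace(marker, "")
        # Fall back to the original key if the live model has no counterpart.
        aligned[canonical_to_live.get(canonical, key)] = value
    return aligned


def inject_group_defaults(param_groups, defaults):
    """Add keys missing from each saved optimizer param group, in place."""
    for group in param_groups:
        for key, value in defaults.items():
            group.setdefault(key, value)


# A checkpoint saved from a compiled model, resumed into a non-compiled one.
checkpoint = {
    "state_dict": {
        "network._orig_mod.linear.weight": 1,
        "network._orig_mod.linear.bias": 2,
    },
    "optimizer_states": [
        {"state": {}, "param_groups": [{"lr": 1e-3, "params": [0, 1]}]}
    ],
}
live_keys = ["network.linear.weight", "network.linear.bias"]

checkpoint["state_dict"] = align_keys(checkpoint["state_dict"], live_keys)
for opt_state in checkpoint["optimizer_states"]:
    # Assumed defaults; the right values depend on the target optimizer.
    inject_group_defaults(
        opt_state["param_groups"],
        {"bias_correction": True, "adam_w_mode": True},
    )
```

Because align_keys compares keys with the marker removed on both sides, the same function covers the opposite direction, where a non-compiled checkpoint is loaded into a compiled model.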
- forward(input_and_condition)#
Forward pass of the network.
- on_before_backward(loss)#
Called before the backward pass; records the forward end time.
- Parameters:
loss (Tensor)
- Return type:
None
- on_after_backward()#
Called after the backward pass; records the timing and logs it.
- Return type:
None
- configure_optimizers()#
Configure the optimizer and scheduler for training.
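To give a feel for the learning-rate curve that a linear warmup chained with a cosine schedule produces, here is a pure-Python sketch of the LR multiplier per step. It is an assumption-laden illustration, not the wrapper's code: the wrapper uses LinearLR and ResumableSequentialLR, while the function warmup_cosine_factor and its parameters (start_factor, min_factor) are hypothetical names chosen for this example.

```python
import math


def warmup_cosine_factor(step, warmup_steps, total_steps,
                         start_factor=0.1, min_factor=0.0):
    """LR multiplier at `step`: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Linear ramp from start_factor to 1.0 over warmup_steps.
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    # Cosine decay from 1.0 down to min_factor over the remaining steps.
    remaining = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / remaining)
    return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Multiplier for every step of a 110-step run with 10 warmup steps.
factors = [
    warmup_cosine_factor(s, warmup_steps=10, total_steps=110)
    for s in range(111)
]
print(factors[0], factors[10], factors[60], factors[110])
```

A function of this shape could be passed to torch.optim.lr_scheduler.LambdaLR as a single stateless schedule. That is a different design from chaining two schedulers, but it yields the same kind of curve.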
- on_before_optimizer_step(optimizer)#
Log the gradient norm before the optimizer step every grad_norm_interval steps.
- on_fit_start()#
Log the model architecture and parameter count to Weights & Biases once training starts.