LightningWrapperBase

class LightningWrapperBase(*args, **kwargs)

Bases: LightningModule

Base PyTorch Lightning module shared by all nvSubquadratic task wrappers.

Handles everything that is common across tasks: optimizer/scheduler construction, checkpoint resume key alignment, optional CUDA profiling, gradient norm tracking, and FLOP logging. Subclasses implement:

  • training_step / validation_step / test_step

  • Task-specific loss computation and metric logging

Parameter grouping (see _build_param_groups()):

Parameters tagged _no_weight_decay = True are placed in a zero-decay group. Parameters tagged _lr_scale = <float> receive a per-parameter LR multiplier: the base LR is scaled by that factor before the group is passed to the optimizer.
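The grouping rule can be illustrated with a minimal sketch. It is not the code of _build_param_groups(): the function name build_param_groups, its arguments, and the SimpleNamespace stand-ins for tensors are assumptions made so the example runs without PyTorch. Only the _no_weight_decay and _lr_scale tags come from the description above.

```python
from types import SimpleNamespace


def build_param_groups(named_params, base_lr, weight_decay):
    """Bucket parameters by their effective (lr, weight_decay) pair."""
    groups = {}
    for _name, param in named_params:
        if not getattr(param, "requires_grad", True):
            continue  # frozen parameters are not optimized
        decay = 0.0 if getattr(param, "_no_weight_decay", False) else weight_decay
        lr = base_lr * getattr(param, "_lr_scale", 1.0)  # scale the base LR up front
        groups.setdefault((lr, decay), []).append(param)
    # Same shape as the per-group dicts accepted by torch.optim optimizers.
    return [
        {"params": params, "lr": lr, "weight_decay": decay}
        for (lr, decay), params in groups.items()
    ]


# Hypothetical stand-ins for tagged parameters.
weight = SimpleNamespace(requires_grad=True)
bias = SimpleNamespace(requires_grad=True, _no_weight_decay=True)
embed = SimpleNamespace(requires_grad=True, _lr_scale=0.1)

param_groups = build_param_groups(
    [("weight", weight), ("bias", bias), ("embed", embed)],
    base_lr=1e-3,
    weight_decay=0.01,
)
```

With real modules, the same function should accept model.named_parameters(), since it only reads attributes from each parameter.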

Scheduler

_build_lr_scheduler() chains a linear-warmup LinearLR with the main schedule (cosine, WSD, or constant) via ResumableSequentialLR, which fixes the checkpoint-resume LR bug present in PyTorch ≤ 2.10.
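As a rough illustration of the shape such a chained schedule produces, the sketch below computes the LR multiplier for a linear warmup followed by a cosine decay in closed form. It does not use LinearLR or ResumableSequentialLR, and the names lr_factor, warmup_steps, total_steps, start_factor, and min_factor are hypothetical.

```python
import math


def lr_factor(step, warmup_steps, total_steps, start_factor=0.01, min_factor=0.0):
    """Multiplier applied to the base LR at a given optimizer step."""
    if step < warmup_steps:
        # Linear ramp from start_factor up to 1.0 over warmup_steps.
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    # Cosine decay from 1.0 down to min_factor over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))


base_lr = 1e-3
schedule = [base_lr * lr_factor(s, warmup_steps=10, total_steps=110) for s in range(111)]
```

Because this function depends only on the absolute step, evaluating it after a resume gives the same value as an uninterrupted run. That is the property a resumable scheduler has to preserve.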

network

The wrapped model.

Type:

torch.nn.Module

optimizer_cfg

Optimizer config from ExperimentConfig.

scheduler_cfg

Scheduler config from ExperimentConfig.

distributed

True when more than one GPU is visible.

Type:

bool

__init__(network, cfg)

Initialise the wrapper and log parameter count and FLOPs.

property logger

Return the first logger (Lightning 2.x uses a loggers list).

on_load_checkpoint(checkpoint)

Patch checkpoint for cross-optimizer and compiled/non-compiled resume.

Handles three mismatch scenarios:

  1. state_dict key prefixes: torch.compile wraps modules under _orig_mod, so checkpoint keys may differ from the live model.

  2. current_model_state key prefixes: when EMA is active, Lightning saves the raw training weights under current_model_state. The EMA callback later calls pl_module.load_state_dict(...) with these keys, so they must also be remapped.

  3. optimizer param-group keys: resuming with a different optimizer (e.g. Apex FusedLAMB vs torch_optimizer.Lamb) may require injecting default values for keys the new optimizer expects but the old checkpoint lacks (such as bias_correction and adam_w_mode).

Parameters:

checkpoint (dict)

Return type:

None
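The three mismatch scenarios handled by on_load_checkpoint can be sketched on plain dictionaries. This is an illustration under stated assumptions, not the wrapper's implementation: align_keys and patch_checkpoint are hypothetical helpers, current_model_state is assumed to sit at the top level of the checkpoint, and the default values passed for bias_correction and adam_w_mode are placeholders.

```python
def align_keys(saved, live_keys, marker="_orig_mod."):
    """Rename saved keys to the live model's spelling, ignoring the compile marker."""
    live_by_plain = {key.replace(marker, ""): key for key in live_keys}
    return {
        live_by_plain.get(key.replace(marker, ""), key): value
        for key, value in saved.items()
    }


def patch_checkpoint(checkpoint, live_keys, group_defaults):
    # Scenarios 1 and 2: both weight dictionaries get the same key remapping.
    for field in ("state_dict", "current_model_state"):
        if field in checkpoint:
            checkpoint[field] = align_keys(checkpoint[field], live_keys)
    # Scenario 3: add only the keys the old optimizer never saved.
    for opt_state in checkpoint.get("optimizer_states", []):
        for group in opt_state["param_groups"]:
            for key, value in group_defaults.items():
                group.setdefault(key, value)


# Checkpoint written by a compiled model, resumed into a non-compiled one.
checkpoint = {
    "state_dict": {"network._orig_mod.fc.weight": 1.0},
    "current_model_state": {"network._orig_mod.fc.weight": 2.0},
    "optimizer_states": [{"state": {}, "param_groups": [{"lr": 1e-3, "params": [0]}]}],
}
patch_checkpoint(
    checkpoint,
    live_keys=["network.fc.weight"],
    group_defaults={"bias_correction": True, "adam_w_mode": True},  # placeholder values
)
```

Matching on the key with the marker removed makes the remapping work in both directions (compiled to non-compiled and back). Using setdefault leaves any value already stored in the checkpoint untouched.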

forward(input_and_condition)

Forward pass of the network.

Parameters:

input_and_condition (dict[str, Tensor]) – A dictionary holding the network input and its condition under the keys "input" and "condition".

Returns:

The output of the network.

on_before_backward(loss)

Called before the backward pass; records the forward-pass end time.

Parameters:

loss (Tensor)

Return type:

None

on_after_backward()

Called after the backward pass; records the backward timing and logs it.

Return type:

None
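The interplay of on_before_backward and on_after_backward can be pictured with a small timer that stamps the clock in each hook. This is a sketch under assumptions: StepTimer, on_step_start, and the forward_s / backward_s names are hypothetical, and where the wrapper records the start of the forward pass is not stated above.

```python
import time


class StepTimer:
    """Splits one training step into forward and backward wall-clock time."""

    def __init__(self):
        self.step_start = None
        self.forward_end = None
        self.history = []

    def on_step_start(self):
        # Hypothetical hook: stamp the beginning of the forward pass.
        self.step_start = time.perf_counter()

    def on_before_backward(self, loss):
        # The forward pass has finished once the loss exists.
        self.forward_end = time.perf_counter()

    def on_after_backward(self):
        now = time.perf_counter()
        self.history.append(
            {
                "forward_s": self.forward_end - self.step_start,
                "backward_s": now - self.forward_end,
            }
        )


timer = StepTimer()
timer.on_step_start()
timer.on_before_backward(loss=0.0)
timer.on_after_backward()
timings = timer.history[-1]
```

On a GPU, kernels are launched asynchronously, so host-side stamps like these only reflect device time if the device is synchronized before each reading.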

configure_optimizers()

Configure the optimizer and scheduler for training.

on_before_optimizer_step(optimizer)

Log the gradient norm before the optimizer step every grad_norm_interval steps.
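The interval-gated logging can be sketched as follows. The assumptions are loud ones: flat lists of floats stand in for gradient tensors, the norm is taken to be the global L2 norm (the text above does not say which norm is used), and global_grad_norm, maybe_log_grad_norm, and the log callback are hypothetical names.

```python
import math


def global_grad_norm(grads):
    """L2 norm over all gradient entries, as if concatenated into one vector."""
    return math.sqrt(sum(value * value for grad in grads for value in grad))


def maybe_log_grad_norm(step, grads, grad_norm_interval, log):
    # Skip every step that is not a multiple of the interval.
    if grad_norm_interval <= 0 or step % grad_norm_interval != 0:
        return
    log("grad_norm", global_grad_norm(grads), step)


logged = []
for step in range(6):
    maybe_log_grad_norm(
        step,
        grads=[[3.0], [4.0]],  # stand-in gradients for two parameters
        grad_norm_interval=2,
        log=lambda name, value, s: logged.append((s, name, value)),
    )
```

Gating on the interval keeps the extra reduction over all gradients off the critical path for most steps.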

on_fit_start()

Log the model architecture and parameter count to Weights & Biases once training starts.