RegressionWrapper
- class RegressionWrapper(*args, **kwargs)
Bases: LightningWrapperBase
Lightning wrapper for regression tasks.
- Parameters:
network (Module)
cfg (ExperimentConfig)
metric (Literal['MAE', 'MSE'])
- __init__(network, cfg, metric)
Initialize the RegressionWrapper.
- Parameters:
network (Module) – Network to wrap.
cfg (ExperimentConfig) – Configuration.
metric (Literal['MAE', 'MSE']) – Metric to use. Must be either ‘MAE’ or ‘MSE’.
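The `metric` argument selects between two standard regression losses. The sketch below shows one plausible way such a choice could be validated and turned into a loss function. It is not the wrapper's actual implementation: `make_loss` is a hypothetical helper, and it works on plain Python lists so it runs without PyTorch. In PyTorch code, `torch.nn.L1Loss` (MAE) and `torch.nn.MSELoss` (MSE) are the usual equivalents.

```python
from typing import Callable, Literal, Sequence

Metric = Literal["MAE", "MSE"]
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def _check(preds: Sequence[float], targets: Sequence[float]) -> None:
    # Both sequences must be non-empty and aligned element by element.
    if len(preds) == 0 or len(preds) != len(targets):
        raise ValueError("preds and targets must be non-empty and of equal length")


def make_loss(metric: Metric) -> LossFn:
    """Return a loss function for the given metric name (hypothetical helper)."""
    if metric == "MAE":
        def loss(preds: Sequence[float], targets: Sequence[float]) -> float:
            _check(preds, targets)
            # Mean absolute error: average of |prediction - target|.
            return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)
    elif metric == "MSE":
        def loss(preds: Sequence[float], targets: Sequence[float]) -> float:
            _check(preds, targets)
            # Mean squared error: average of (prediction - target) ** 2.
            return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)
    else:
        # Mirror the documented constraint: only 'MAE' or 'MSE' are accepted.
        raise ValueError(f"metric must be 'MAE' or 'MSE', got {metric!r}")
    return loss
```

For example, with predictions `[1.0, 2.0]` and targets `[1.5, 1.0]`, MAE gives 0.75 and MSE gives 0.625. Failing fast on an unknown name keeps a typo in the configuration from surfacing only later, mid-training.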
- training_step(batch, batch_idx)
Perform a training step and log the training loss.
- validation_step(batch, batch_idx)
Perform a validation step and log the validation loss.
- test_step(batch, batch_idx)
Perform a test step and log the test loss.
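The three step methods differ mainly in which loss they log. A common Lightning pattern is to route all of them through one shared helper that runs the network, computes the loss, and logs it under a stage-specific name. The sketch below assumes that pattern; `StepSketch`, `_shared_step`, the `(inputs, targets)` batch layout, and the `train_loss`/`val_loss`/`test_loss` names are all hypothetical, and `log` only records values as a stand-in for `LightningModule.log`.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # assumed (inputs, targets) layout


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    # Mean squared error over aligned predictions and targets.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class StepSketch:
    """Hypothetical stand-in for the wrapper's step methods (not the real class)."""

    def __init__(
        self,
        network: Callable[[Sequence[float]], List[float]],
        loss_fn: Callable[[Sequence[float], Sequence[float]], float] = mse,
    ) -> None:
        self.network = network
        self.loss_fn = loss_fn
        self.logged: Dict[str, List[float]] = {}

    def log(self, name: str, value: float) -> None:
        # Stand-in for LightningModule.log: just record the value by name.
        self.logged.setdefault(name, []).append(value)

    def _shared_step(self, batch: Batch, stage: str) -> float:
        inputs, targets = batch
        preds = self.network(inputs)
        loss = self.loss_fn(preds, targets)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch: Batch, batch_idx: int) -> float:
        # In Lightning, the loss returned here is what gets backpropagated.
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Batch, batch_idx: int) -> float:
        return self._shared_step(batch, "val")

    def test_step(self, batch: Batch, batch_idx: int) -> float:
        return self._shared_step(batch, "test")
```

With a network that doubles its inputs, the batch `([1.0, 2.0], [2.0, 5.0])` yields predictions `[2.0, 4.0]` and an MSE of 0.5, recorded under `train_loss` when passed to `training_step`. Sharing one helper keeps the three stages from drifting apart in how they compute the loss.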
- on_train_epoch_end()
Log the best training loss and the logits over the training set.
- on_validation_epoch_end()
Log the best validation loss and the logits over the validation set.
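Tracking a "best" loss at the end of each epoch amounts to keeping a running minimum, since lower is better for both MAE and MSE. The sketch below assumes per-batch losses are collected during the epoch and averaged at epoch end. `BestLossTracker` is a hypothetical helper, not part of the wrapper, and it covers only the best-loss bookkeeping, not the logging of logits.

```python
import math
from typing import List


class BestLossTracker:
    """Hypothetical helper: collect per-batch losses and keep the best epoch mean."""

    def __init__(self) -> None:
        self.batch_losses: List[float] = []
        self.best = math.inf  # no epoch seen yet

    def update(self, loss: float) -> None:
        # Call once per batch, e.g. from a training or validation step.
        self.batch_losses.append(loss)

    def epoch_end(self) -> float:
        # Unweighted mean of the batch losses; batches of unequal size
        # would need a size-weighted mean to match the dataset-level loss.
        if self.batch_losses:
            epoch_loss = sum(self.batch_losses) / len(self.batch_losses)
            # Lower is better for both MAE and MSE, so "best" is the minimum.
            self.best = min(self.best, epoch_loss)
        self.batch_losses.clear()  # start the next epoch fresh
        return self.best
```

For three epochs with batch losses `[1.0, 0.5]`, `[1.0, 1.0]`, and `[0.25, 0.25]`, `epoch_end` returns 0.75, 0.75, and 0.25: the worse second epoch leaves the best value unchanged.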