RegressionWrapper

class RegressionWrapper(*args, **kwargs)

Bases: LightningWrapperBase

Lightning wrapper for regression tasks.

__init__(network, cfg, metric)

Initialize the RegressionWrapper.

Parameters:
  • network (Module) – Network to wrap.

  • cfg (ExperimentConfig) – Configuration.

  • metric (Literal['MAE', 'MSE']) – Metric to use. Must be either ‘MAE’ or ‘MSE’.
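The reference states only that metric must be 'MAE' or 'MSE'. It does not say how the wrapper uses it. The plain-Python sketch below assumes the metric name selects the loss function and that any other value is rejected. The helpers `mae`, `mse` and `select_loss` are hypothetical. A PyTorch-based wrapper would more likely work on tensors, for example with `torch.nn.L1Loss` and `torch.nn.MSELoss`.

```python
from typing import Callable, Literal, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mae(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean absolute error: average of |pred - target|."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    """Mean squared error: average of (pred - target) ** 2."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


def select_loss(metric: Literal["MAE", "MSE"]) -> LossFn:
    """Hypothetical helper: map the metric name to a loss function."""
    losses = {"MAE": mae, "MSE": mse}
    if metric not in losses:
        # Mirrors the documented constraint: only 'MAE' or 'MSE' are valid.
        raise ValueError(f"metric must be 'MAE' or 'MSE', got {metric!r}")
    return losses[metric]


print(select_loss("MAE")([1.0, 2.0], [2.0, 4.0]))  # 1.5
print(select_loss("MSE")([1.0, 2.0], [2.0, 4.0]))  # 2.5
```

MSE squares each error, so it penalizes large residuals more heavily than MAE. That is the usual trade-off when choosing between the two.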

training_step(batch, batch_idx)

Perform a training step and log the training loss.

validation_step(batch, batch_idx)

Perform a validation step and log the validation loss.

test_step(batch, batch_idx)

Perform a test step and log the test loss.
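The training, validation and test steps above differ mainly in which loss they log. One common way to structure this is a single shared step parameterized by stage. The sketch below is plain Python and is not the wrapper's actual code. It assumes each batch is an `(inputs, targets)` pair. The names `StepSketch`, `_shared_step` and the log keys `train_loss`, `val_loss` and `test_loss` are all hypothetical. The `log` method only mimics `LightningModule.log` by recording values in a dict.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Batch = Tuple[Sequence[float], Sequence[float]]  # assumed (inputs, targets) layout
LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(preds: Sequence[float], targets: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class StepSketch:
    """Hypothetical stand-in: one shared step reused by train/val/test."""

    def __init__(self, network: Callable[[float], float], loss_fn: LossFn = mse) -> None:
        self.network = network
        self.loss_fn = loss_fn
        self.logged: Dict[str, List[float]] = {}

    def log(self, name: str, value: float) -> None:
        # Records every logged value under its key (stand-in for Lightning logging).
        self.logged.setdefault(name, []).append(value)

    def _shared_step(self, batch: Batch, stage: str) -> float:
        inputs, targets = batch
        preds = [self.network(x) for x in inputs]
        loss = self.loss_fn(preds, targets)
        self.log(f"{stage}_loss", loss)
        return loss

    def training_step(self, batch: Batch, batch_idx: int) -> float:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Batch, batch_idx: int) -> float:
        return self._shared_step(batch, "val")

    def test_step(self, batch: Batch, batch_idx: int) -> float:
        return self._shared_step(batch, "test")
```

In Lightning, the value returned from `training_step` is the loss the trainer backpropagates. The validation and test steps only need to log.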

on_train_epoch_end()

Log the best training loss and the logits over the training set.

on_validation_epoch_end()

Log the best validation loss and the logits over the validation set.
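In Lightning 2.x, the `on_train_epoch_end` and `on_validation_epoch_end` hooks receive no batch outputs. A module that reports epoch-level values therefore usually buffers them during the steps. The sketch below shows one way the two hooks above could track a "best" loss. It assumes lower is better, which holds for both MAE and MSE. It also assumes the epoch loss is the unweighted mean of per-batch losses and that the collected outputs are simply returned. `EpochEndSketch`, `record_batch` and the `best_<stage>_loss` keys are hypothetical, and the reference does not say how the logits are actually logged.

```python
from typing import Dict, List, Optional


class EpochEndSketch:
    """Hypothetical tracker for the best epoch loss and the epoch's outputs."""

    def __init__(self) -> None:
        self.best_loss: Dict[str, Optional[float]] = {"train": None, "val": None}
        self._batch_losses: Dict[str, List[float]] = {"train": [], "val": []}
        self._outputs: Dict[str, List[float]] = {"train": [], "val": []}
        self.logged: Dict[str, float] = {}

    def record_batch(self, stage: str, loss: float, outputs: List[float]) -> None:
        # Called from the step methods: stash the per-batch loss and raw outputs.
        self._batch_losses[stage].append(loss)
        self._outputs[stage].extend(outputs)

    def _epoch_end(self, stage: str) -> List[float]:
        losses = self._batch_losses[stage]
        epoch_loss = sum(losses) / len(losses)  # unweighted mean over batches
        best = self.best_loss[stage]
        if best is None or epoch_loss < best:  # lower MAE/MSE is better
            best = epoch_loss
            self.best_loss[stage] = best
        self.logged[f"best_{stage}_loss"] = best
        outputs = self._outputs[stage]
        # Reset the buffers so the next epoch starts clean.
        self._batch_losses[stage] = []
        self._outputs[stage] = []
        return outputs

    def on_train_epoch_end(self) -> List[float]:
        return self._epoch_end("train")

    def on_validation_epoch_end(self) -> List[float]:
        return self._epoch_end("val")
```

Resetting the buffers at the end of each epoch keeps memory bounded and stops outputs from one epoch leaking into the next.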