LabeledEMAWeightAveraging
- class LabeledEMAWeightAveraging(*args, **kwargs)
Bases: EMAWeightAveraging
EMAWeightAveraging that labels validation metrics with an _ema suffix.
Sets pl_module._val_metric_suffix = "_ema" while EMA weights are swapped in for validation, so that any wrapper that honours the suffix (e.g. ClassificationWrapper) logs to val/acc_ema instead of val/acc.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
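The suffix mechanism can be illustrated with a minimal pure-Python sketch that does not depend on Lightning. ToyModule and log_val_metric are hypothetical stand-ins, and treating a missing or empty _val_metric_suffix as "no suffix" is an assumption about how a suffix-honouring wrapper such as ClassificationWrapper might behave, not a description of its actual code.

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule that records logged metrics."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value):
        # Mimics the shape of LightningModule.log(name, value) for illustration only.
        self.logged[name] = value


def log_val_metric(module, name, value):
    """Log a validation metric, appending the module's suffix if one is set.

    Hypothetical helper: the attribute name _val_metric_suffix comes from the
    docs; falling back to "" when it is absent or None is an assumption.
    """
    suffix = getattr(module, "_val_metric_suffix", "") or ""
    module.log(f"val/{name}{suffix}", value)


module = ToyModule()
log_val_metric(module, "acc", 0.91)  # no suffix set, logs to "val/acc"
module._val_metric_suffix = "_ema"
log_val_metric(module, "acc", 0.93)  # suffix set, logs to "val/acc_ema"
print(sorted(module.logged))  # ['val/acc', 'val/acc_ema']
```

Reading the suffix with getattr keeps such a wrapper usable even when no EMA callback is attached, since the metric names then stay unchanged.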
- on_validation_epoch_start(trainer, pl_module)
Swap in EMA weights and set metric suffix before validation.
- Parameters:
trainer (pytorch_lightning.Trainer)
pl_module (pytorch_lightning.LightningModule)
- Return type:
None
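A minimal sketch of what this hook does, assuming the model's parameters can be represented as a plain dict and that an up-to-date EMA copy is already available. ToyModule, ToyLabeledEMA, weights, ema_weights and _backup are hypothetical names; only _val_metric_suffix and the "_ema" value come from the documentation above.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for a LightningModule; weights plays the role of its parameters."""

    def __init__(self, weights):
        self.weights = weights
        self._val_metric_suffix = ""


class ToyLabeledEMA:
    """Hypothetical sketch of the validation-start step; not the library's implementation."""

    def __init__(self, ema_weights):
        self.ema_weights = ema_weights  # assumed to be maintained elsewhere during training
        self._backup = None             # training weights saved while EMA weights are active

    def on_validation_epoch_start(self, trainer, pl_module):
        # Keep a copy of the live training weights so they can be restored afterwards.
        self._backup = copy.deepcopy(pl_module.weights)
        # Swap in the averaged weights for the validation pass.
        pl_module.weights = copy.deepcopy(self.ema_weights)
        # Label metrics logged during validation as EMA metrics.
        pl_module._val_metric_suffix = "_ema"


module = ToyModule({"w": 1.0})
callback = ToyLabeledEMA({"w": 0.8})
callback.on_validation_epoch_start(trainer=None, pl_module=module)
print(module.weights, module._val_metric_suffix)  # {'w': 0.8} _ema
```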
- on_validation_epoch_end(trainer, pl_module)
Restore the training weights after validation; the metric suffix is left in place for module hooks.
- Parameters:
trainer (pytorch_lightning.Trainer)
pl_module (pytorch_lightning.LightningModule)
- Return type:
None
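The sketch below walks through one start/end cycle to show why leaving the suffix in place can matter: a module-level epoch-end hook that runs after this callback can still label its metrics with "_ema". All class and attribute names other than _val_metric_suffix are hypothetical. Calling the callback's hook before the module's hook is an assumption made for illustration, and since the documentation does not say which component eventually clears the suffix, the reset inside the toy module hook is likewise an assumption.

```python
import copy


class ToyModule:
    """Hypothetical stand-in for a LightningModule with an epoch-end logging hook."""

    def __init__(self, weights):
        self.weights = weights
        self._val_metric_suffix = ""
        self.logged = {}

    def on_validation_epoch_end(self):
        # Hypothetical module hook: it still sees the suffix because the callback
        # did not clear it. Resetting it here is an assumption for illustration.
        self.logged[f"val/acc{self._val_metric_suffix}"] = 0.93
        self._val_metric_suffix = ""


class ToyLabeledEMA:
    """Hypothetical sketch of the swap-in / restore cycle; not the library's code."""

    def __init__(self, ema_weights):
        self.ema_weights = ema_weights
        self._backup = None

    def on_validation_epoch_start(self, trainer, pl_module):
        self._backup = copy.deepcopy(pl_module.weights)
        pl_module.weights = copy.deepcopy(self.ema_weights)
        pl_module._val_metric_suffix = "_ema"

    def on_validation_epoch_end(self, trainer, pl_module):
        # Put the training weights back; deliberately leave the suffix untouched.
        pl_module.weights = self._backup
        self._backup = None


module = ToyModule({"w": 1.0})
callback = ToyLabeledEMA({"w": 0.8})
callback.on_validation_epoch_start(None, module)
callback.on_validation_epoch_end(None, module)
print(module.weights, repr(module._val_metric_suffix))  # {'w': 1.0} '_ema'
module.on_validation_epoch_end()
print(module.logged, repr(module._val_metric_suffix))  # {'val/acc_ema': 0.93} ''
```

Splitting the responsibilities this way means the callback only owns the weight swap, while the metric label stays visible to whatever logging code runs later in the same epoch-end phase.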