LabeledEMAWeightAveraging

class LabeledEMAWeightAveraging(*args, **kwargs)

Bases: EMAWeightAveraging

EMAWeightAveraging that labels validation metrics with an _ema suffix.

Sets pl_module._val_metric_suffix = "_ema" while EMA weights are swapped in for validation, so that any wrapper that honours the suffix (e.g. ClassificationWrapper) logs to val/acc_ema instead of val/acc.
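The suffix mechanism can be illustrated without Lightning. The sketch below assumes that a wrapper "honours the suffix" by reading the attribute with getattr and falling back to an empty string. FakeModule and metric_key are hypothetical stand-ins, not ClassificationWrapper's actual code.

```python
class FakeModule:
    """Hypothetical stand-in for a LightningModule; only the attribute matters."""


def metric_key(module: object, name: str, stage: str = "val") -> str:
    # Assumed behaviour: use the suffix if the callback set one,
    # otherwise log under the plain metric name.
    suffix = getattr(module, "_val_metric_suffix", "")
    return f"{stage}/{name}{suffix}"


module = FakeModule()
print(metric_key(module, "acc"))  # val/acc

# This is what LabeledEMAWeightAveraging does while EMA weights are active.
module._val_metric_suffix = "_ema"
print(metric_key(module, "acc"))  # val/acc_ema
```

Because the default is an empty string, modules that never see the callback keep logging to val/acc unchanged.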

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

on_validation_epoch_start(trainer, pl_module)

Swap in EMA weights and set metric suffix before validation.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None
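The start-of-validation step (back up the live weights, swap in the averaged ones, set the suffix) can be sketched with plain dictionaries standing in for state dicts. ToyModule and ToyLabeledEMA are hypothetical; the real callback inherits its weight handling from EMAWeightAveraging, whose internals are not shown here.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical module whose 'weights' are a plain dict of floats."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)
        self._val_metric_suffix = ""


class ToyLabeledEMA:
    """Minimal sketch of the start-of-validation hook; not the real callback."""

    def __init__(self, ema_weights: Dict[str, float]) -> None:
        self.ema_weights = dict(ema_weights)
        self._backup: Dict[str, float] = {}

    def on_validation_epoch_start(self, trainer: Any, pl_module: ToyModule) -> None:
        # Keep a copy of the live training weights so they can be restored later.
        self._backup = dict(pl_module.weights)
        # Swap the averaged (EMA) weights in for evaluation.
        pl_module.weights = dict(self.ema_weights)
        # Tag every metric logged during this validation run.
        pl_module._val_metric_suffix = "_ema"


module = ToyModule({"w": 1.0})
callback = ToyLabeledEMA({"w": 0.8})
callback.on_validation_epoch_start(None, module)
print(module.weights, module._val_metric_suffix)  # {'w': 0.8} _ema
```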

on_validation_epoch_end(trainer, pl_module)

Restore the training weights after validation. The metric suffix is not cleared here; it is left in place for the module's own hooks.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None
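The end-of-validation step can be sketched in the same toy setting: the backed-up training weights are restored, while _val_metric_suffix is deliberately left untouched. ToyModule and on_validation_epoch_end_sketch are hypothetical names; who eventually resets the suffix is not specified by this callback.

```python
from typing import Dict


class ToyModule:
    """Hypothetical module whose 'weights' are a plain dict of floats."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = dict(weights)
        self._val_metric_suffix = ""


def on_validation_epoch_end_sketch(pl_module: ToyModule, backup: Dict[str, float]) -> None:
    # Put the training weights back so optimisation resumes from them.
    pl_module.weights = dict(backup)
    # pl_module._val_metric_suffix is intentionally not modified here;
    # the module's own hooks are expected to read or reset it.


module = ToyModule({"w": 0.8})  # EMA weights currently swapped in
module._val_metric_suffix = "_ema"
on_validation_epoch_end_sketch(module, backup={"w": 1.0})
print(module.weights, module._val_metric_suffix)  # {'w': 1.0} _ema
```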