ClassificationWrapper
- class ClassificationWrapper(*args, **kwargs)
Bases: LightningWrapperBase
Lightning wrapper for classification tasks.
Loss modes (loss parameter):
"cross_entropy": standard CrossEntropyLoss (hard labels).
"soft_target_ce": SoftTargetCrossEntropy, computing -sum(target * log_softmax(logits)). Classes compete via softmax. Use for finetuning with Mixup/CutMix (DeiT III recipe).
"bce": BCEWithLogitsLoss with binarized multi-hot targets. Each class is an independent sigmoid. Matches the ViT-5 / DeiT III pretraining recipe (--bce-loss).
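The three loss modes can be compared on a single Mixup-style target with a minimal PyTorch sketch. The helper names and the binarization threshold are assumptions for illustration, not this class's internals. The sketch also checks that the soft-target formula reduces to standard cross-entropy when the target is one-hot.

```python
import torch
import torch.nn.functional as F


def soft_target_ce(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Per-sample -sum(target * log_softmax(logits)), averaged over the batch
    return torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1).mean()


def bce_binarized(logits: torch.Tensor, target: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Hypothetical binarization: any class whose target mass exceeds `threshold` becomes 1
    multi_hot = (target > threshold).to(logits.dtype)
    # Each class is scored by an independent sigmoid, so classes do not compete
    return F.binary_cross_entropy_with_logits(logits, multi_hot)


logits = torch.tensor([[2.0, 0.5, 1.0]])
mixed = torch.tensor([[0.7, 0.0, 0.3]])  # Mixup of classes 0 and 2
labels = torch.tensor([0])
one_hot = F.one_hot(labels, num_classes=3).float()

ce = F.cross_entropy(logits, labels)  # "cross_entropy"
soft = soft_target_ce(logits, mixed)  # "soft_target_ce"
bce = bce_binarized(logits, mixed)    # "bce": target becomes [1, 0, 1]
```

With the 0.7/0.3 target, soft_target_ce weights the two classes by their mixing coefficients, whereas the binarized bce target treats both as fully positive, so changing the mixing weight leaves the bce loss unchanged.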
- Parameters:
network (Module)
cfg (ExperimentConfig)
loss (str)
- __init__(network, cfg, loss='cross_entropy')
Initialize classification wrapper with loss mode and metrics.
- Parameters:
network (Module)
cfg (ExperimentConfig)
loss (str)
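One plausible way for __init__ to turn the loss string into a criterion is a small lookup that fails fast on unknown modes. This is a hypothetical sketch: build_criterion and CRITERIA are not part of the documented API, and SoftTargetCrossEntropy is defined locally (timm ships a class with the same name) so the block runs with PyTorch alone.

```python
from typing import Callable, Dict

import torch
from torch import nn


class SoftTargetCrossEntropy(nn.Module):
    # Local stand-in: batch mean of -sum(target * log_softmax(logits))
    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.sum(-target * torch.log_softmax(logits, dim=-1), dim=-1).mean()


# Hypothetical registry keyed by the documented `loss` values
CRITERIA: Dict[str, Callable[[], nn.Module]] = {
    "cross_entropy": nn.CrossEntropyLoss,
    "soft_target_ce": SoftTargetCrossEntropy,
    "bce": nn.BCEWithLogitsLoss,
}


def build_criterion(loss: str = "cross_entropy") -> nn.Module:
    # Instantiate a fresh criterion, or raise a readable error for unknown modes
    try:
        return CRITERIA[loss]()
    except KeyError:
        raise ValueError(f"Unknown loss mode {loss!r}; expected one of {sorted(CRITERIA)}") from None
```

Failing in the constructor surfaces a typo in the config before any training time is spent.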
- training_step(batch, batch_idx)
Perform a training step and log the training loss & accuracy.
- validation_step(batch, batch_idx)
Perform a validation step and log the validation loss & accuracy.
- test_step(batch, batch_idx)
Perform a test step and log the test loss & accuracy.
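The training, validation, and test steps share the same core computation. The sketch below assumes hard integer labels and uses step_metrics as a hypothetical helper name; inside a LightningModule the returned values would typically be passed to self.log under a stage-specific key.

```python
from typing import Tuple

import torch
import torch.nn.functional as F


def step_metrics(logits: torch.Tensor, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Loss on the raw logits; accuracy on the argmax prediction
    loss = F.cross_entropy(logits, labels)
    preds = logits.argmax(dim=-1)
    acc = (preds == labels).float().mean()
    return loss, acc


logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.0]])
labels = torch.tensor([0, 1, 1])
loss, acc = step_metrics(logits, labels)  # preds are [0, 1, 0], so acc is 2/3
```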
- on_train_epoch_end()
Log the best training accuracy and loss so far, and the logits over the training set.
- on_validation_epoch_end()
Log the best validation accuracy and loss so far, and the logits over the validation set.
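Tracking the best values reduces to a running max for accuracy and a running min for loss. A minimal sketch with a hypothetical BestMetrics container; the starting values (0.0 and infinity) are assumptions chosen so the first epoch always counts as the best.

```python
import math
from dataclasses import dataclass


@dataclass
class BestMetrics:
    # Hypothetical tracker for per-epoch aggregates
    best_acc: float = 0.0
    best_loss: float = math.inf

    def update(self, acc: float, loss: float) -> None:
        # Accuracy improves upward, loss improves downward
        self.best_acc = max(self.best_acc, acc)
        self.best_loss = min(self.best_loss, loss)


tracker = BestMetrics()
for acc, loss in [(0.50, 1.20), (0.70, 0.95), (0.65, 0.90)]:
    tracker.update(acc, loss)
```

Because each metric is tracked independently, the best accuracy (epoch 2 here) and the best loss (epoch 3) can come from different epochs.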
- on_save_checkpoint(checkpoint)
Persist best-metric tracking values so they survive resume.
- Parameters:
checkpoint (dict)
- Return type:
None
- on_load_checkpoint(checkpoint)
Restore best-metric tracking values and delegate key remapping to base.
- Parameters:
checkpoint (dict)
- Return type:
None
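Lightning calls on_save_checkpoint with the checkpoint dict before writing it, so extra state can be stored alongside the weights and read back in on_load_checkpoint. A minimal sketch of that round trip; the class name and the "best_metrics" key are hypothetical, and the base-class key remapping is not shown.

```python
import math
from typing import Any, Dict


class BestMetricState:
    # Hypothetical checkpoint key; the real wrapper may use a different one
    KEY = "best_metrics"

    def __init__(self) -> None:
        self.best_val_acc = 0.0
        self.best_val_loss = math.inf

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # The dict is mutable; anything added here is saved with the weights
        checkpoint[self.KEY] = {
            "best_val_acc": self.best_val_acc,
            "best_val_loss": self.best_val_loss,
        }

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        saved = checkpoint.get(self.KEY)
        if saved is None:
            # Checkpoint predates best-metric tracking: keep the defaults
            return
        self.best_val_acc = saved["best_val_acc"]
        self.best_val_loss = saved["best_val_loss"]


before = BestMetricState()
before.best_val_acc, before.best_val_loss = 0.81, 0.42
ckpt: Dict[str, Any] = {}
before.on_save_checkpoint(ckpt)

after = BestMetricState()
after.on_load_checkpoint(ckpt)
```

Treating a missing key as "keep the defaults" lets older checkpoints still load.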
- static multiclass_prediction(logits)
Predict the class with the highest logit for multi-class classification.
- static binary_prediction(logits)
Predict the class with the highest logit for binary classification.
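Both predictions reduce to an argmax over the class dimension. For a two-column binary output, argmax is equivalent to thresholding the logit margin at zero, i.e. asking whether the softmax probability of the positive class exceeds 0.5. The sketch below checks that equivalence; the [negative, positive] column layout and the helper name binary_from_margin are assumptions.

```python
import torch


def multiclass_prediction(logits: torch.Tensor) -> torch.Tensor:
    # Index of the largest logit along the class dimension
    return logits.argmax(dim=-1)


def binary_from_margin(logits: torch.Tensor) -> torch.Tensor:
    # Hypothetical alternative for two columns [negative, positive]:
    # positive iff its logit exceeds the negative one (ties aside)
    return (logits[..., 1] - logits[..., 0] > 0).long()


logits = torch.tensor([[0.2, 1.3], [2.0, -1.0], [0.5, 0.6]])
preds = multiclass_prediction(logits)
```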