ClassificationWrapper

class ClassificationWrapper(*args, **kwargs)

Bases: LightningWrapperBase

Lightning wrapper for classification tasks.

Loss modes (loss parameter):

  • "cross_entropy": standard CrossEntropyLoss (hard labels).

  • "soft_target_ce": SoftTargetCrossEntropy, computing -sum(target * log_softmax(logits)). Classes compete via softmax. Use for fine-tuning with Mixup/CutMix (DeiT III recipe).

  • "bce": BCEWithLogitsLoss with binarized multi-hot targets. Each class is an independent sigmoid. This matches the ViT-5 / DeiT III pretraining recipe (--bce-loss).
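The three loss modes can be sketched for a single sample in framework-free Python. This is a minimal illustration of the math, not the wrapper's implementation: the function names are hypothetical, and the binarization rule used for "bce" (any target entry above a threshold of 0 counts as positive) is an assumption about how this wrapper binarizes Mixup/CutMix targets.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def cross_entropy(logits: Sequence[float], label: int) -> float:
    # "cross_entropy": negative log-probability of the hard label.
    return -log_softmax(logits)[label]


def soft_target_ce(logits: Sequence[float], target: Sequence[float]) -> float:
    # "soft_target_ce": -sum(target * log_softmax(logits)).
    # Classes compete because log_softmax normalizes across all of them.
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))


def bce_multi_hot(logits: Sequence[float], target: Sequence[float],
                  threshold: float = 0.0) -> float:
    # "bce": each class is an independent sigmoid.
    # ASSUMPTION: targets are binarized with `t > threshold` (hypothetical rule).
    total = 0.0
    for z, t in zip(logits, target):
        y = 1.0 if t > threshold else 0.0
        # Stable form of -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))].
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    # Average over classes, mirroring BCEWithLogitsLoss's default "mean" reduction.
    return total / len(logits)
```

With a one-hot target, `soft_target_ce` reduces exactly to `cross_entropy`. With a Mixup target such as `[0.7, 0.3]`, `bce_multi_hot` treats both classes as positives, whereas `soft_target_ce` weights them 0.7 and 0.3.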

__init__(network, cfg, loss='cross_entropy')

Initialize classification wrapper with loss mode and metrics.
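Because the `loss` string selects both the criterion and the target format the batches must carry, a constructor like this one typically validates it up front. The sketch below is hypothetical: `TARGET_FORMAT` and `expected_target_format` are illustrative names, and rejecting unknown modes eagerly is an assumption, not documented behavior.

```python
from typing import Dict

# Hypothetical table: target format each documented loss mode expects.
TARGET_FORMAT: Dict[str, str] = {
    "cross_entropy": "integer class index (hard label)",
    "soft_target_ce": "probability vector over classes (e.g. Mixup/CutMix output)",
    "bce": "multi-hot vector (soft targets binarized per class)",
}


def expected_target_format(loss: str = "cross_entropy") -> str:
    # ASSUMPTION: an unknown mode is rejected eagerly, before any training step.
    try:
        return TARGET_FORMAT[loss]
    except KeyError:
        valid = ", ".join(sorted(TARGET_FORMAT))
        raise ValueError(f"unknown loss mode {loss!r}; expected one of: {valid}") from None
```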

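Parameters:

network: the network to wrap.

cfg: configuration object.

loss (str): loss mode, one of "cross_entropy" (default), "soft_target_ce", or "bce".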
training_step(batch, batch_idx)

Perform a training step and log the training loss and accuracy.

validation_step(batch, batch_idx)

Perform a validation step and log the validation loss and accuracy.

test_step(batch, batch_idx)

Perform a test step and log the test loss and accuracy.
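All three step methods log an accuracy next to the loss. Under the soft-target modes the batch carries per-class vectors rather than integer labels, so accuracy needs a hard label from somewhere. The sketch below assumes such targets are reduced with argmax; that rule, and the names `argmax` and `batch_accuracy`, are hypothetical.

```python
from typing import Sequence, Union

# A target is either a hard label or a per-class vector (soft or multi-hot).
Target = Union[int, Sequence[float]]


def argmax(values: Sequence[float]) -> int:
    # Index of the largest value; max() returns the first maximum on ties.
    return max(range(len(values)), key=lambda i: values[i])


def batch_accuracy(logits_batch: Sequence[Sequence[float]],
                   targets: Sequence[Target]) -> float:
    correct = 0
    for logits, target in zip(logits_batch, targets):
        # ASSUMPTION: vector targets are reduced to a hard label via argmax.
        label = target if isinstance(target, int) else argmax(target)
        correct += int(argmax(logits) == label)
    return correct / len(logits_batch)
```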

on_train_epoch_end()

Log the best training accuracy and loss, along with the logits over the training set.

on_validation_epoch_end()

Log the best validation accuracy and loss, along with the logits over the validation set.
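The "best" values logged at epoch end imply state carried across epochs. A minimal sketch of that tracking, assuming accuracy and loss are tracked independently (highest accuracy, lowest loss) rather than as a pair taken from a single epoch; the class and field names are hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass
class BestMetrics:
    # Hypothetical field names; the wrapper's real attribute names are not documented.
    best_acc: float = 0.0
    best_loss: float = math.inf

    def update(self, acc: float, loss: float) -> None:
        # ASSUMPTION: each metric keeps its own best independently.
        self.best_acc = max(self.best_acc, acc)
        self.best_loss = min(self.best_loss, loss)
```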

on_save_checkpoint(checkpoint)

Persist best-metric tracking values so they survive resuming from a checkpoint.

Parameters:

checkpoint (dict)

Return type:

None

on_load_checkpoint(checkpoint)

Restore best-metric tracking values and delegate checkpoint key remapping to the base class.

Parameters:

checkpoint (dict)

Return type:

None
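In Lightning, both hooks receive the checkpoint dict itself: `on_save_checkpoint` adds entries to it in place, and `on_load_checkpoint` reads them back before training resumes. The round trip below is a sketch of that pattern with plain functions; the key `"best_metrics"` and the fallback defaults for older checkpoints are hypothetical.

```python
import math
from typing import Any, Dict, Tuple

BEST_KEY = "best_metrics"  # hypothetical checkpoint key


def save_best(checkpoint: Dict[str, Any], best_acc: float, best_loss: float) -> None:
    # Mirrors on_save_checkpoint: mutate the checkpoint dict in place.
    checkpoint[BEST_KEY] = {"best_acc": best_acc, "best_loss": best_loss}


def load_best(checkpoint: Dict[str, Any]) -> Tuple[float, float]:
    # Mirrors on_load_checkpoint: read the values back, falling back to neutral
    # defaults for checkpoints written before tracking existed (an assumption).
    saved = checkpoint.get(BEST_KEY, {})
    return saved.get("best_acc", 0.0), saved.get("best_loss", math.inf)
```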

static multiclass_prediction(logits)

Predict the class with the highest logit for multi-class classification.
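Since softmax is monotonic, the class with the highest logit is also the class with the highest probability, so no normalization is needed. A framework-free sketch over a batch of rows; the list-of-lists input is an assumption standing in for a tensor of shape (batch, classes).

```python
from typing import List, Sequence


def multiclass_prediction(logits: Sequence[Sequence[float]]) -> List[int]:
    # For each row of per-class logits, pick the index of the largest one
    # (ties go to the lowest index, since max() returns the first maximum).
    return [max(range(len(row)), key=lambda i: row[i]) for row in logits]
```

With PyTorch tensors the same operation is `logits.argmax(dim=-1)`.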

static binary_prediction(logits)

Predict the class with the highest logit for binary classification.
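The sketch below assumes the binary head emits two logits per sample; that layout, and both function names, are hypothetical. Picking the higher of the two logits is the same as thresholding the softmax probability of class 1 at 0.5, because softmax([z0, z1])[1] equals sigmoid(z1 - z0). If the head instead emits a single logit, the equivalent rule is to predict class 1 when that logit is above 0.

```python
import math
from typing import List, Sequence, Tuple


def binary_prediction(logits: Sequence[Tuple[float, float]]) -> List[int]:
    # Two logits per sample: predict class 1 only if its logit is strictly higher.
    return [1 if z1 > z0 else 0 for z0, z1 in logits]


def positive_probability(z0: float, z1: float) -> float:
    # softmax([z0, z1])[1] == sigmoid(z1 - z0), so p > 0.5 exactly when z1 > z0.
    return 1.0 / (1.0 + math.exp(-(z1 - z0)))
```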