MaskMonitorCallback

class MaskMonitorCallback(*args, **kwargs)

Bases: Callback

Logs a single min/max chart per block for every mask module.

Parameters:

log_every_n_steps (int) – How often to log (in global steps).
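The per-block min/max reduction can be sketched without any framework. The sketch assumes, purely for illustration, that each mask module's values are available as a flat list of floats per block. The helper name `summarize_masks` and the `<mask>/block_<i>/min|max` key layout are hypothetical and are not the callback's actual metric names.

```python
from typing import Dict, List


def summarize_masks(masks: Dict[str, List[List[float]]]) -> Dict[str, float]:
    """Hypothetical helper: reduce mask values to one min/max pair per block.

    ``masks`` maps a mask-module name to a list of blocks, where each block
    is a flat list of mask values (an assumed layout, for illustration only).
    """
    metrics: Dict[str, float] = {}
    for name, blocks in masks.items():
        for i, values in enumerate(blocks):
            if not values:
                continue  # skip empty blocks instead of failing in min()/max()
            # One min/max pair per (mask module, block) would feed a single chart.
            metrics[f"{name}/block_{i}/min"] = min(values)
            metrics[f"{name}/block_{i}/max"] = max(values)
    return metrics


print(summarize_masks({"attn_mask": [[0.1, 0.9, 0.4], [0.0, 1.0]]}))
```

Keeping min and max as separate keys that share a prefix lets a logger group them into one chart per block.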

__init__(log_every_n_steps=50)
Parameters:

log_every_n_steps (int)

on_fit_start(trainer, pl_module)
Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
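The following is a minimal, framework-free sketch of how a hook with this signature might apply `log_every_n_steps`. It assumes that the trainer exposes a `global_step` attribute, which PyTorch Lightning's `Trainer` does, and that logging fires whenever that step is a multiple of the interval. The class `StepGatedMonitor` and its `logged_steps` list are hypothetical stand-ins and do not come from this callback's implementation.

```python
from types import SimpleNamespace
from typing import Any, List, Optional


class StepGatedMonitor:
    """Hypothetical duck-typed callback that only demonstrates step gating."""

    def __init__(self, log_every_n_steps: int = 50) -> None:
        if log_every_n_steps < 1:
            raise ValueError("log_every_n_steps must be >= 1")
        self.log_every_n_steps = log_every_n_steps
        self.logged_steps: List[int] = []
        self._last_logged: Optional[int] = None

    def on_train_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        step = trainer.global_step
        # Only act on multiples of the logging interval.
        if step % self.log_every_n_steps != 0:
            return
        # With gradient accumulation, several batches share one global step;
        # remembering the last logged step avoids logging it more than once.
        if step == self._last_logged:
            return
        self._last_logged = step
        self.logged_steps.append(step)  # a real callback would log its charts here


monitor = StepGatedMonitor(log_every_n_steps=2)
for global_step in [0, 1, 2, 2, 3, 4]:
    trainer = SimpleNamespace(global_step=global_step)  # stand-in for a Trainer
    monitor.on_train_batch_end(trainer, None, None, None, batch_idx=0)
print(monitor.logged_steps)
```

Gating on the global step rather than `batch_idx` keeps the logging cadence tied to optimizer updates instead of raw batches.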
property state_key: str
state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None
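PyTorch Lightning stores each callback's `state_dict()` in the checkpoint under that callback's `state_key`. When training resumes, Lightning passes the stored dictionary back to `load_state_dict()`. The stdlib-only sketch below imitates that round trip with a hypothetical `MonitorState` class. Putting `log_every_n_steps` into the key follows the format Lightning's built-in callbacks use (class name followed by the repr of their distinguishing arguments). Whether this callback does the same, and what state it actually keeps, is an assumption; the reference above does not say.

```python
from typing import Any, Dict, Optional


class MonitorState:
    """Hypothetical stand-in showing the state_key/state_dict/load_state_dict trio."""

    def __init__(self, log_every_n_steps: int = 50) -> None:
        self.log_every_n_steps = log_every_n_steps
        self.last_logged_step: Optional[int] = None  # illustrative piece of state

    @property
    def state_key(self) -> str:
        # Class name plus the repr of distinguishing arguments, so two instances
        # with different intervals do not overwrite each other's saved state.
        kwargs = {"log_every_n_steps": self.log_every_n_steps}
        return f"{type(self).__qualname__}{repr(kwargs)}"

    def state_dict(self) -> Dict[str, Any]:
        return {"last_logged_step": self.last_logged_step}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.last_logged_step = state_dict.get("last_logged_step")


# Simulate saving to and restoring from a checkpoint's "callbacks" section.
saved = MonitorState(log_every_n_steps=100)
saved.last_logged_step = 400
checkpoint = {"callbacks": {saved.state_key: saved.state_dict()}}

restored = MonitorState(log_every_n_steps=100)
restored.load_state_dict(checkpoint["callbacks"][restored.state_key])
print(restored.state_key, restored.last_logged_step)
```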