ValidationImageGridCallback

class ValidationImageGridCallback(*args, **kwargs)

Bases: Callback

Validation image grid callback for PyTorch Lightning.

Visualizes input, prediction, and label images in a grid during validation. The grid can be logged every N epochs or every N training steps.

Parameters:
  • num_samples (int) – Number of samples to visualize.

  • every_n_epochs (int | None) – How often to visualize (in epochs). Set to None to disable.

  • every_n_train_steps (int | None) – How often to visualize (in training steps). Set to None to disable.

  • key (str) – Key to use for the visualization in the logger.

  • show_input (bool) – Whether to show the input image alongside prediction and label.

  • show_mask_separately (bool) – If True and input has 2 channels, display the grayscale canvas and mask as separate side-by-side images in the grid. Grid becomes: [canvas, mask, prediction, label] per row.

  • denormalize (bool) – Whether to denormalize the images using mean and std.

  • mean (float) – Mean of the dataset (for denormalization).

  • std (float) – Standard deviation of the dataset (for denormalization).

  • flattened_image_shape (tuple | None) – Optional (H, W) used to reshape flattened tensors of shape [B, H*W, C] into images. If not provided, the callback tries to infer a square shape.
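The reshaping described for flattened_image_shape can be illustrated with a small sketch. This is not the callback's actual code: the helper names infer_image_shape and unflatten are hypothetical, nested Python lists stand in for tensors, and the assumption is that the square fallback only succeeds when H*W is a perfect square.

```python
import math


def infer_image_shape(num_positions, flattened_image_shape=None):
    """Return (H, W) for a flattened axis of length num_positions."""
    if flattened_image_shape is not None:
        height, width = flattened_image_shape
        if height * width != num_positions:
            raise ValueError(
                f"shape {flattened_image_shape} does not match {num_positions} positions"
            )
        return height, width
    # Assumed fallback: guess a square image, valid only for perfect squares.
    side = math.isqrt(num_positions)
    if side * side != num_positions:
        raise ValueError(f"cannot infer a square shape from {num_positions} positions")
    return side, side


def unflatten(batch, flattened_image_shape=None):
    """Reshape nested lists of shape [B, H*W, C] into [B, H, W, C]."""
    images = []
    for sample in batch:
        height, width = infer_image_shape(len(sample), flattened_image_shape)
        # Slice the flat sequence into `height` rows of `width` pixels each.
        rows = [sample[r * width:(r + 1) * width] for r in range(height)]
        images.append(rows)
    return images
```

A non-square image such as 2 x 3 cannot be recovered by the square fallback, which is the case where passing flattened_image_shape explicitly matters. The sketch leaves the channel axis last; a real implementation would likely also move it to the layout its grid utility expects.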

__init__(
num_samples=4,
every_n_epochs=1,
every_n_train_steps=None,
key='val/image_grid',
show_input=True,
show_mask_separately=False,
denormalize=True,
mean=0.1307,
std=0.3081,
flattened_image_shape=None,
)

Initialize the callback.

Parameters:
  • num_samples (int)

  • every_n_epochs (int | None)

  • every_n_train_steps (int | None)

  • key (str)

  • show_input (bool)

  • show_mask_separately (bool)

  • denormalize (bool)

  • mean (float)

  • std (float)

  • flattened_image_shape (tuple | None)

Return type:

None
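The default mean=0.1307 and std=0.3081 coincide with the statistics commonly used to normalize MNIST. Assuming the images were normalized as (x - mean) / std, denormalization would invert that step. The sketch below is a guess at the arithmetic, not the callback's implementation; the function name denormalize and the clamp to [0, 1] are assumptions.

```python
def denormalize(pixels, mean=0.1307, std=0.3081):
    """Invert (x - mean) / std for a flat list of pixel values.

    The clamp to [0, 1] is an assumption made so the result is displayable.
    """
    return [min(1.0, max(0.0, p * std + mean)) for p in pixels]


# Round trip for a mid-gray pixel: normalize, then denormalize.
normalized = (0.5 - 0.1307) / 0.3081
restored = denormalize([normalized])[0]
```

For datasets with other statistics, the same mean and std used in the training transform would be passed to the callback; with denormalize=False this step would be skipped.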

on_validation_epoch_end(
trainer,
pl_module,
)

Visualize the validation images at the end of the epoch.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None
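The row layout implied by show_input and show_mask_separately can be sketched as a small function. The name grid_columns is hypothetical, and the sketch assumes that show_mask_separately has an effect only when show_input is True and the input has exactly 2 channels; the documentation above does not state how the two flags interact.

```python
def grid_columns(show_input=True, show_mask_separately=False, input_channels=1):
    """Return the column names for one grid row, left to right."""
    columns = []
    if show_input:
        if show_mask_separately and input_channels == 2:
            # Two-channel input split into its grayscale canvas and its mask.
            columns += ["canvas", "mask"]
        else:
            columns.append("input")
    # Prediction and label are always shown.
    columns += ["prediction", "label"]
    return columns
```

Under these assumptions, each of the num_samples rows has two, three, or four images depending on the flags.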

on_train_batch_end(
trainer,
pl_module,
outputs,
batch,
batch_idx,
)

Visualize images during training every N steps.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

  • batch_idx (int)

Return type:

None
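Both hooks gate on a frequency setting (every_n_epochs and every_n_train_steps), where None disables the trigger. A minimal sketch of such a check follows; the function name should_visualize and the 1-based counting convention are assumptions, and the actual callback may count differently.

```python
def should_visualize(counter, every_n):
    """Return True when a 1-based counter hits a multiple of every_n.

    every_n=None (or a non-positive value) disables the trigger.
    """
    if every_n is None or every_n <= 0:
        return False
    return counter % every_n == 0


# With every_n=2, steps 2 and 4 out of 1..5 would trigger a visualization.
triggered = [step for step in range(1, 6) if should_visualize(step, 2)]
```

In a hook, the counter might be derived from trainer.current_epoch + 1 for the epoch trigger or from trainer.global_step for the step trigger.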