Sequence1DVisualizationCallback

class Sequence1DVisualizationCallback(*args, **kwargs)

Bases: Callback

1D sequence visualization callback for spatial recall tasks.

Visualizes:
  • Input canvas as a 1D line plot

  • Prediction and label as 2D images (reshaped from the segment)

Can be triggered every N epochs or every N training iterations.
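The two trigger modes can be illustrated with a small sketch. The helper `should_visualize` is hypothetical and not part of this class; it also assumes 0-based epoch and step indices and that a trigger fires once every N completed epochs or steps, which the callback itself may handle differently.

```python
from typing import Optional


def should_visualize(index: int, every_n: Optional[int]) -> bool:
    """Return True when a 0-based epoch or step index hits the interval.

    Hypothetical helper: None (or a non-positive interval) disables the trigger.
    """
    if every_n is None or every_n <= 0:
        return False
    # Assumption: fire after every `every_n` completed epochs/steps,
    # so the 0-based index is shifted by one before the modulo check.
    return (index + 1) % every_n == 0


# Epoch-based trigger with every_n_epochs=2 over epochs 0..5.
epoch_hits = [e for e in range(6) if should_visualize(e, 2)]

# Step-based trigger disabled with every_n_train_steps=None.
step_hits = [s for s in range(6) if should_visualize(s, None)]
```

Under these assumptions, setting one interval to None leaves only the other trigger active.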

Parameters:
  • num_samples (int) – Number of samples to visualize.

  • target_size (int) – Original 2D image size (segment_length = target_size²).

  • every_n_epochs (int | None) – How often to visualize (in epochs). Set to None to disable.

  • every_n_train_steps (int | None) – How often to visualize (in training steps). Set to None to disable.

  • key (str) – Key to use for the visualization in the logger.

  • show_input (bool) – Whether to show the input canvas alongside prediction and label.

  • show_mask_separately (bool) – If True and input has 2 channels, display the canvas and mask as separate line plots. Grid becomes: [canvas, mask, prediction, label] per row.

  • denormalize (bool) – Whether to denormalize the images with mean and std before they are displayed.

  • mean (float) – Mean of the dataset (for denormalization).

  • std (float) – Standard deviation of the dataset (for denormalization).

  • readout_value (float) – Value used for the readout region (shown as a reference line in the visualization).
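How target_size, mean, and std relate to a single segment can be sketched with NumPy. This is only an illustration, not the callback's own code: the helper `segment_to_image` is hypothetical, and it assumes row-major reshaping, the usual `(x - mean) / std` normalization, and clipping to [0, 1] for display.

```python
import numpy as np


def segment_to_image(segment, target_size, mean=0.1307, std=0.3081, denormalize=True):
    """Reshape a flat segment into a square image (hypothetical helper)."""
    segment = np.asarray(segment, dtype=np.float32)
    # segment_length must equal target_size squared.
    if segment.size != target_size ** 2:
        raise ValueError("segment length must be target_size ** 2")
    # Assumption: pixels are stored in row-major order.
    image = segment.reshape(target_size, target_size)
    if denormalize:
        # Undo (x - mean) / std, then clip to a displayable range.
        image = np.clip(image * std + mean, 0.0, 1.0)
    return image


target_size = 16
segment_length = target_size ** 2

# Stand-in for a normalized prediction over one segment.
rng = np.random.default_rng(0)
prediction = rng.standard_normal(segment_length)

image = segment_to_image(prediction, target_size)
```

With the default target_size of 16, each prediction or label segment holds 256 values and is shown as a 16 × 16 image.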

__init__(
num_samples=4,
target_size=16,
every_n_epochs=1,
every_n_train_steps=None,
key='val/sequence_1d_grid',
show_input=True,
show_mask_separately=False,
denormalize=True,
mean=0.1307,
std=0.3081,
readout_value=0.0,
)

Initialize the callback.

Parameters:
  • num_samples (int)

  • target_size (int)

  • every_n_epochs (int | None)

  • every_n_train_steps (int | None)

  • key (str)

  • show_input (bool)

  • show_mask_separately (bool)

  • denormalize (bool)

  • mean (float)

  • std (float)

  • readout_value (float)

Return type:

None
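As a configuration sketch, the documented defaults can be collected as keyword arguments and then adjusted, for example to switch from epoch-based to step-based triggering. The import path `your_package.callbacks` and the interval of 500 steps are placeholders chosen for illustration; this page does not state the module the class lives in, so the instantiation is left commented out.

```python
# Keyword arguments mirroring the defaults in the signature above.
callback_kwargs = dict(
    num_samples=4,
    target_size=16,
    every_n_epochs=1,
    every_n_train_steps=None,
    key="val/sequence_1d_grid",
    show_input=True,
    show_mask_separately=False,
    denormalize=True,
    mean=0.1307,
    std=0.3081,
    readout_value=0.0,
)

# Disable the epoch trigger and visualize every 500 training steps instead.
step_based_kwargs = {
    **callback_kwargs,
    "every_n_epochs": None,
    "every_n_train_steps": 500,
}

# Hypothetical usage; the import path below is an assumption:
# import pytorch_lightning as pl
# from your_package.callbacks import Sequence1DVisualizationCallback
# callback = Sequence1DVisualizationCallback(**step_based_kwargs)
# trainer = pl.Trainer(callbacks=[callback])
```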

on_train_batch_end(
trainer,
pl_module,
outputs,
batch,
batch_idx,
)

Visualize sequences during training every N steps.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

  • batch_idx (int)

Return type:

None

on_validation_epoch_end(
trainer,
pl_module,
)

Visualize at the end of the validation epoch.

Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None