Experiments

The experiments package wires nvSubquadratic modules into reproducible training pipelines built on PyTorch Lightning. Each experiment is a LazyConfig of dataclasses (ExperimentConfig) plus a wrapper subclass that defines the training step.

Entry points

parse_args()

Parse command-line arguments for the experiment.

main()

Main function to run the experiment.

construct_trainer(cfg, wandb_logger, run_name)

Construct a trainer and the checkpoint callback from a configuration.

Configuration dataclasses

The dataclasses below describe the experiment surface. They are instantiated via LazyConfig from the per-experiment config files in examples/.

ExperimentConfig([device, debug, ...])

Top-level configuration for a single nvSubquadratic training run.

TrainConfig([do, precision, iterations, ...])

Training configuration.

TrainerConfig([samples_per_epoch, ...])

Lightning Trainer configuration overrides.

SchedulerConfig([name, ...])

Scheduler configuration.

WandbConfig([project, entity, job_group, ...])

Weights & Biases (wandb) configuration.

AutoResumeConfig([enabled, alias, run_name])

Auto-resume configuration via Weights & Biases run name.

StartFromCheckpointConfig([load, alias, ...])

Configuration to start training from weights of a previously saved checkpoint (weights only, no optimizer/scheduler state).
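
To make the nesting concrete, the sketch below models a small subset of these configurations as plain dataclasses. The field names device, debug, do, precision, iterations, enabled, alias and run_name are taken from the signatures listed above; every type, every default value, and the train / autoresume attribute names on ExperimentConfig are assumptions made for illustration, and the sketch does not use LazyConfig.

```python
from dataclasses import dataclass, field
from typing import Optional


# Hypothetical stand-ins: field names come from the signatures listed above,
# but all types and default values are assumptions for illustration only.
@dataclass
class AutoResumeConfig:
    enabled: bool = False
    alias: str = "latest"
    run_name: Optional[str] = None


@dataclass
class TrainConfig:
    do: bool = True
    precision: str = "bf16-mixed"
    iterations: int = 1000


@dataclass
class ExperimentConfig:
    device: str = "cuda"
    debug: bool = False
    # The nested attribute names below are assumed, not taken from the package.
    train: TrainConfig = field(default_factory=TrainConfig)
    autoresume: AutoResumeConfig = field(default_factory=AutoResumeConfig)


# Override only what differs from the defaults; the rest is inherited.
cfg = ExperimentConfig(debug=True, train=TrainConfig(iterations=10))
```

Using default_factory for the nested configs gives each ExperimentConfig its own sub-config instances, so overriding one run's settings cannot leak into another.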

Lightning wrappers

Task-specific wrappers around a common base. Each wrapper defines training_step / validation_step / metrics for one task family.

LightningWrapperBase(*args, **kwargs)

Base PyTorch Lightning module shared by all nvSubquadratic task wrappers.

ClassificationWrapper(*args, **kwargs)

Lightning wrapper for classification tasks.

SoftTargetCrossEntropy(*args, **kwargs)

Cross-entropy loss with soft targets (from DeiT III / timm).

RegressionWrapper(*args, **kwargs)

Lightning wrapper for regression tasks.

WELLRegressionWrapper(*args, **kwargs)

Lightning wrapper for WELL benchmark regression tasks.

AutoregressiveWrapper(*args, **kwargs)

Lightning wrapper for autoregressive (next-token prediction) tasks.

construct_optimizer(model, optimizer_cfg)

Construct an optimizer for a model from a configuration.

construct_scheduler(optimizer, scheduler_cfg)

Create a learning rate scheduler for an optimizer from a configuration.
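
For readers unfamiliar with soft-target cross-entropy, the following is a minimal dependency-free sketch of the quantity the timm loss of the same name computes: the batch mean of sum(-target * log_softmax(logits)). It operates on plain Python lists rather than tensors, and the function names are hypothetical; whether the SoftTargetCrossEntropy class listed above matches this exactly is an assumption.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def soft_target_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
) -> float:
    """Mean over the batch of -sum_c target[c] * log_softmax(logits)[c].

    Each target row is a probability distribution (for example produced by
    label smoothing or mixup) instead of a single class index.
    """
    per_sample = []
    for row_logits, row_targets in zip(logits, targets):
        log_probs = log_softmax(row_logits)
        per_sample.append(-sum(t * lp for t, lp in zip(row_targets, log_probs)))
    return sum(per_sample) / len(per_sample)
```

With a one-hot target row this reduces to the ordinary cross-entropy for that class.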

Callbacks

FiLMMonitorCallback(*args, **kwargs)

Logs a compact FiLM diagnostic text report to wandb.

ValidationImageGridCallback(*args, **kwargs)

Validation image grid callback for PyTorch Lightning.

ValidationVolumeGridCallback(*args, **kwargs)

Validation volume grid callback for 3D spatial recall tasks.

IterationSpeedCallback(*args, **kwargs)

Logs iteration throughput, fwd/bwd breakdown, and GPU memory to wandb.

MaskMonitorCallback(*args, **kwargs)

Logs a single min/max chart per block for every mask module.

LabeledEMAWeightAveraging(*args, **kwargs)

EMAWeightAveraging that labels validation metrics with an _ema suffix.

OmegaScaleMonitorCallback(*args, **kwargs)

Logs a single chart per block tracking the effective per-row ω₀.

Sequence1DVisualizationCallback(*args, **kwargs)

1D sequence visualization callback for spatial recall tasks.

WalltimeCheckpointer(*args, **kwargs)

Checkpoints and stops training when a walltime limit is reached.

WandbCacheCleanupCallback(*args, **kwargs)

Periodically run wandb artifact cache cleanup to cap local cache size.
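
As an illustration of what an EMA weight-averaging callback with suffixed metrics does, here is a minimal sketch of the standard exponential-moving-average update and of relabeling metrics with an _ema suffix. It uses plain dictionaries of floats instead of model parameters; the function names and the decay value are assumptions, not the API of LabeledEMAWeightAveraging.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], weights: Dict[str, float], decay: float = 0.999) -> None:
    """In-place EMA step: ema <- decay * ema + (1 - decay) * weights."""
    for name, value in weights.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


def label_metrics(metrics: Dict[str, float], suffix: str = "_ema") -> Dict[str, float]:
    """Return a copy of the metrics with the suffix appended to every key."""
    return {f"{name}{suffix}": value for name, value in metrics.items()}


# Averaged weights start as a copy of the live weights, then trail behind them.
ema_weights = {"w": 1.0}
ema_update(ema_weights, {"w": 0.0}, decay=0.9)
```

Suffixing the keys keeps metrics computed with the averaged weights separate from those of the live model in the same logger.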

Data modules

PyTorch Lightning LightningDataModule subclasses, one for each dataset the experiments target.

mnist

MNIST / EMNIST datamodule for PyTorch Lightning.

emnist

EMNIST datamodule for PyTorch Lightning.

tinyimagenet

TinyImageNet / ImageNet datamodule backed by Hugging Face Datasets.

spatial_recall_dataset

Spatial Recall Dataset and DataModule for PyTorch Lightning.

dali_imagenet_fused

Fully fused DALI ImageNet DataModule: all augmentations run inside DALI.

well

Lightning DataModule for WELL benchmark datasets.

Utilities

get_deterministic_run_name(config_path[, ...])

Generate a deterministic run name based on the config file name, current timestamp, and any overrides.

load_config_from_file(config_path)

Load a configuration from a Python file.

download_checkpoint(run_path[, alias])

Download the checkpoint files from the Weights & Biases artifact marked with a given alias (default: "best").

load_checkpoint_state_dict(ckpt_path)

Load a .ckpt file and return a flat state_dict-like mapping.
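
Loading a configuration from a Python file is commonly done with the standard library's importlib machinery. The sketch below shows that general technique under stated assumptions: the function name load_python_config and the convention that the file exposes a module-level variable named config are hypothetical, and load_config_from_file may work differently.

```python
import importlib.util
from pathlib import Path
from typing import Any


def load_python_config(config_path: str, attr: str = "config") -> Any:
    """Execute a Python file as a module and return one of its attributes.

    `attr` names the module-level variable holding the configuration; the
    default "config" is an assumed convention, not one taken from the package.
    """
    path = Path(config_path)
    spec = importlib.util.spec_from_file_location(path.stem, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load config from {path}")
    module = importlib.util.module_from_spec(spec)
    # Runs the file's top-level code, so only load trusted config files.
    spec.loader.exec_module(module)
    return getattr(module, attr)
```

Because the file is executed rather than parsed, configs can compute values or import helpers, at the cost of requiring that the file be trusted.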