Experiments
The experiments package wires nvSubquadratic modules into reproducible
training pipelines built on PyTorch Lightning. Each experiment is a
LazyConfig of dataclasses (ExperimentConfig)
plus a wrapper subclass that defines the training step.
Entry points
Parse command line arguments for the experiment.
Main function to run the experiment.
Construct a trainer and the checkpoint callback from a configuration.
Configuration dataclasses
The dataclasses below describe the experiment surface. They are
instantiated via LazyConfig from
the per-experiment config files in examples/.
Top-level configuration for a single nvSubquadratic training run.
Train configuration.
Lightning Trainer configuration overrides.
Scheduler configuration.
Wandb configuration.
Auto-resume configuration via Weights & Biases run name.
Configuration to start training from weights of a previously saved checkpoint (weights only, no optimizer/scheduler state).
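
As a rough illustration of the "lazy configuration of dataclasses" idea described above, the sketch below records which dataclass to build and with which arguments, lets an override be applied, and only then constructs the objects. It is not the nvSubquadratic LazyConfig API: `Lazy`, `ExperimentSketch`, and every field name here are hypothetical stand-ins chosen for the example.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class TrainConfig:
    # Hypothetical fields, chosen only for illustration.
    batch_size: int = 32
    max_steps: int = 1000


@dataclass
class SchedulerConfig:
    # Hypothetical fields, chosen only for illustration.
    name: str = "cosine"
    warmup_steps: int = 100


@dataclass
class ExperimentSketch:
    # Stand-in for a top-level experiment configuration.
    train: TrainConfig = field(default_factory=TrainConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)


@dataclass
class Lazy:
    """Records a callable and its keyword arguments without calling it."""

    target: Callable[..., Any]
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def instantiate(self) -> Any:
        # Nested Lazy values are built first, so any override written into
        # kwargs before this call is reflected in the final object.
        built = {
            key: value.instantiate() if isinstance(value, Lazy) else value
            for key, value in self.kwargs.items()
        }
        return self.target(**built)


lazy_cfg = Lazy(
    ExperimentSketch,
    {
        "train": Lazy(TrainConfig, {"batch_size": 64}),
        "scheduler": Lazy(SchedulerConfig),
    },
)

# Override a nested value before anything has been constructed.
lazy_cfg.kwargs["train"].kwargs["max_steps"] = 500

cfg = lazy_cfg.instantiate()
print(cfg.train.batch_size, cfg.train.max_steps, cfg.scheduler.name)
```

Deferring construction this way is what makes command-line overrides cheap: the override edits plain data, and the dataclasses are validated once, at instantiation time.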
Lightning wrappers
Task-specific wrappers around a common base. Each wrapper defines
training_step / validation_step / metrics for one task family.
Base PyTorch Lightning module shared by all nvSubquadratic task wrappers.
Lightning wrapper for classification tasks.
Cross-entropy loss with soft targets (from DeiT III / timm).
Lightning wrapper for regression tasks.
Lightning wrapper for WELL benchmark regression tasks.
Lightning wrapper for autoregressive (next-token prediction) tasks.
Constructs an optimizer for a model from a configuration.
Creates a learning rate scheduler for an optimizer from a configuration.
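
The soft-target cross-entropy listed above generalizes the usual loss from a single class index to a full target distribution: for logits x and targets t, the per-example loss is -Σᵢ tᵢ · log softmax(x)ᵢ. The sketch below computes that quantity for one example in pure Python, only to make the formula concrete. It is not the package's implementation, which presumably operates on batched tensors; the function name is a placeholder.

```python
import math
from typing import Sequence


def soft_target_cross_entropy(
    logits: Sequence[float], soft_targets: Sequence[float]
) -> float:
    """Return -sum_i t_i * log_softmax(logits)_i for a single example."""
    if len(logits) != len(soft_targets):
        raise ValueError("logits and soft_targets must have the same length")
    # Log-sum-exp with the maximum subtracted for numerical stability.
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return -sum(t * (x - log_norm) for x, t in zip(logits, soft_targets))


# A one-hot target recovers ordinary cross-entropy for that class.
hard = soft_target_cross_entropy([2.0, 0.0], [1.0, 0.0])
# A mixed target (as produced by label smoothing or mixup) weights each class.
mixed = soft_target_cross_entropy([2.0, 0.0], [0.9, 0.1])
print(hard, mixed)
```

With a one-hot target the sum collapses to the familiar negative log-probability of the true class, which is why the soft version can replace the hard one whenever augmentations such as mixup produce fractional labels.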
Callbacks
Logs a compact FiLM diagnostic text report to wandb.
Validation image grid callback for PyTorch Lightning.
Validation volume grid callback for 3D spatial recall tasks.
Logs iteration throughput, fwd/bwd breakdown, and GPU memory to wandb.
Logs a single min/max chart per block for every mask module.
Logs a single chart per block tracking the effective per-row ω₀.
1D sequence visualization callback for spatial recall tasks.
Checkpoints and stops training when a walltime limit is reached.
Periodically runs wandb artifact cache cleanup to cap local cache size.
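
The walltime callback above boils down to a time-budget check made at regular points during training. A minimal sketch of that check follows, assuming a monotonic clock and an optional safety margin reserved for writing the checkpoint. `WalltimeBudget` and its parameters are hypothetical names, not the package's callback.

```python
import time
from typing import Callable


class WalltimeBudget:
    """Hypothetical helper that tracks elapsed time against a limit."""

    def __init__(
        self,
        limit_seconds: float,
        margin_seconds: float = 0.0,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._clock = clock
        self._start = clock()
        # Stop early by the margin so there is time left to save state.
        self._deadline = limit_seconds - margin_seconds

    def elapsed(self) -> float:
        return self._clock() - self._start

    def should_stop(self) -> bool:
        return self.elapsed() >= self._deadline


# Example: a 4-hour job that keeps 10 minutes in reserve for checkpointing.
budget = WalltimeBudget(limit_seconds=4 * 3600, margin_seconds=600)
print(budget.should_stop())

# Inside a Lightning callback hook such as on_train_batch_end, one would
# typically (this is an assumption about usage, not the package's code)
# call trainer.save_checkpoint(path) and set trainer.should_stop = True
# once budget.should_stop() returns True.
```

Injecting the clock keeps the logic testable without sleeping, and a monotonic clock avoids surprises if the system time is adjusted mid-run.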
Data modules
PyTorch Lightning LightningDataModule subclasses, one for each dataset
that the experiments target.
MNIST / EMNIST datamodule for PyTorch Lightning.
EMNIST datamodule for PyTorch Lightning.
TinyImageNet / ImageNet datamodule backed by Hugging Face Datasets.
Spatial Recall Dataset and DataModule for PyTorch Lightning.
Fully fused DALI ImageNet DataModule, with all augmentations inside DALI.
Lightning DataModule for WELL benchmark datasets.
Utilities
Generate a deterministic run name based on the config file name, current timestamp, and any overrides.
Load a configuration from a Python file.
Download the checkpoint files from the Weights & Biases artifact marked with a given alias (default: "best").
Load a .ckpt file and return a flat state_dict-like mapping.
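
To make the run-name utility above concrete, here is one way such a name could be assembled from the three ingredients it mentions: the config file name, a timestamp, and the overrides. The function name, the separator, the timestamp format, and the choice to sort overrides are all assumptions made for this sketch. The package's actual naming scheme may differ.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional, Sequence


def make_run_name(
    config_path: str,
    overrides: Sequence[str] = (),
    now: Optional[datetime] = None,
) -> str:
    """Build '<config stem>_<timestamp>_<overrides...>' (hypothetical format)."""
    stamp = (now or datetime.now()).strftime("%Y%m%d-%H%M%S")
    parts = [Path(config_path).stem, stamp]
    # Sorting makes the name independent of the order overrides were given in.
    parts.extend(sorted(overrides))
    return "_".join(parts)


name = make_run_name(
    "examples/mnist.py",
    overrides=["train.lr=0.1", "seed=3"],
    now=datetime(2024, 1, 2, 3, 4, 5),
)
print(name)  # mnist_20240102-030405_seed=3_train.lr=0.1
```

Passing the timestamp in explicitly is what makes the result reproducible: the same config, time, and overrides always give the same name, which is convenient when a run name doubles as a resume key.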