WELLRegressionWrapper
- class WELLRegressionWrapper(*args, **kwargs)
Bases: RegressionWrapper
Lightning wrapper for WELL benchmark regression tasks.
This wrapper adapts the WELL benchmark’s data format and metrics to work with the nvSubquadratic training infrastructure.
Training uses one-step (next-frame) prediction. Validation uses both short and long autoregressive rollouts, scored with WELL benchmark metrics.
WELL data format:
- Batch format: dict with ‘input_fields’, ‘constant_fields’ (optional), etc.
- Fields are in channels-last format: [B, T, H, W, C]
Model expects:
- Input: [B, H, W, C_in] where C_in = n_steps_input * n_fields + n_constant_fields
- Output: [B, H, W, C_out] where C_out = n_fields
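The channel packing implied by these shapes can be sketched as follows. This is a minimal illustration, not the wrapper's actual code: it uses NumPy instead of torch, assumes ‘constant_fields’ arrives as [B, H, W, n_constant_fields], and assumes a time-major channel order (all fields of step 0, then all fields of step 1, and so on). `pack_inputs` is a hypothetical helper name.

```python
import numpy as np


def pack_inputs(input_fields, constant_fields=None):
    """Hypothetical helper: [B, T, H, W, C] -> [B, H, W, T * C (+ n_constant_fields)]."""
    b, t, h, w, c = input_fields.shape
    # Move the time axis next to the channel axis: [B, H, W, T, C]
    x = np.transpose(input_fields, (0, 2, 3, 1, 4))
    # Merge time and channels (time-major order is an assumption)
    x = x.reshape(b, h, w, t * c)
    if constant_fields is not None:
        # Assumed layout for constants: [B, H, W, n_constant_fields]
        x = np.concatenate([x, constant_fields], axis=-1)
    return x


# B=2, n_steps_input=4, H=W=8, n_fields=3, n_constant_fields=1
batch = {
    "input_fields": np.zeros((2, 4, 8, 8, 3)),
    "constant_fields": np.zeros((2, 8, 8, 1)),
}
x = pack_inputs(batch["input_fields"], batch["constant_fields"])
print(x.shape)  # (2, 8, 8, 13): C_in = 4 * 3 + 1
```

The last axis has length n_steps_input * n_fields + n_constant_fields, which is the C_in stated above.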
- Parameters:
network (Module) – Network to wrap
cfg (ExperimentConfig) – Experiment configuration
metadata – WELL dataset metadata
n_steps_input (int) – Number of input timesteps
n_steps_output (int) – Number of output timesteps (for training, usually 1)
max_rollout_steps (int) – Maximum rollout steps for validation
metric (Literal['MAE', 'MSE']) – Metric used for the training loss: ‘MSE’ (default) or ‘MAE’
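The `metric` option presumably selects the training loss. A minimal sketch of that mapping, assuming a plain mean over all elements (the wrapper may reduce differently); `training_loss` is a hypothetical name and NumPy stands in for torch:

```python
from typing import Literal

import numpy as np


def training_loss(pred, target, metric: Literal["MAE", "MSE"] = "MSE") -> float:
    """Hypothetical loss selector mirroring the `metric` parameter."""
    err = np.asarray(pred) - np.asarray(target)
    if metric == "MSE":
        return float(np.mean(err ** 2))  # mean squared error
    if metric == "MAE":
        return float(np.mean(np.abs(err)))  # mean absolute error
    raise ValueError(f"Unsupported metric: {metric!r}")


pred = np.array([1.0, 2.0, 3.0])
target = np.array([1.0, 0.0, 6.0])
print(training_loss(pred, target, "MSE"))  # (0 + 4 + 9) / 3
print(training_loss(pred, target, "MAE"))  # (0 + 2 + 3) / 3
```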
- __init__(network, cfg, metadata, n_steps_input=4, n_steps_output=1, max_rollout_steps=32, metric='MSE', normalization=None)
Initialize the WELL regression wrapper with dataset metadata and rollout settings.
- training_step(batch, batch_idx)
Training uses 1-step prediction.
- Parameters:
batch – Dict containing ‘input_fields’ and the target fields for the following timestep
batch_idx – Index of the current batch
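One-step training can be sketched as: feed n_steps_input frames, predict the next frame, and compare it with the target. This sketch assumes the target arrives as a separate [B, 1, H, W, C] array (the actual batch key is not specified here), omits constant fields, uses MSE, and replaces the network with a hypothetical `persistence_model` that simply repeats the last input frame.

```python
import numpy as np

N_STEPS_INPUT = 4  # mirrors n_steps_input
N_FIELDS = 3


def persistence_model(x):
    """Hypothetical stand-in network: returns the last input frame's fields.

    x: [B, H, W, n_steps_input * n_fields] -> [B, H, W, n_fields]
    """
    return x[..., -N_FIELDS:]


def one_step_loss(input_fields, target_fields, model):
    b, t, h, w, c = input_fields.shape
    # Fold time into channels: [B, T, H, W, C] -> [B, H, W, T * C]
    x = np.transpose(input_fields, (0, 2, 3, 1, 4)).reshape(b, h, w, t * c)
    pred = model(x)               # [B, H, W, C_out], C_out = n_fields
    target = target_fields[:, 0]  # [B, 1, H, W, C] -> [B, H, W, C]
    return float(np.mean((pred - target) ** 2))


rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, N_STEPS_INPUT, 8, 8, N_FIELDS))
# A target equal to the last input frame gives zero loss for the persistence model
target = inputs[:, -1:]
print(one_step_loss(inputs, target, persistence_model))  # 0.0
```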
- validation_step(batch, batch_idx)
Validation uses autoregressive rollout and WELL benchmark metrics.
- Parameters:
batch – Batch dict from validation dataloader
batch_idx – Index of the current batch
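A sketch of an autoregressive rollout, assuming a sliding window: each prediction is appended to the input window and the oldest frame is dropped, for up to max_rollout_steps steps. Constant fields are omitted, and `rollout` and the toy `add_one_model` are hypothetical names, not part of the wrapper's API.

```python
import numpy as np


def rollout(model, init_window, n_steps):
    """Autoregressive rollout with a sliding window of frames.

    init_window: [B, T, H, W, C] with T = n_steps_input.
    Returns predictions stacked as [B, n_steps, H, W, C].
    """
    window = init_window
    preds = []
    for _ in range(n_steps):
        b, t, h, w, c = window.shape
        # Pack the current window into model input: [B, H, W, T * C]
        x = np.transpose(window, (0, 2, 3, 1, 4)).reshape(b, h, w, t * c)
        y = model(x)  # next frame, [B, H, W, C]
        preds.append(y)
        # Drop the oldest frame and append the prediction
        window = np.concatenate([window[:, 1:], y[:, None]], axis=1)
    return np.stack(preds, axis=1)


def add_one_model(x, n_fields=1):
    """Toy model: last frame + 1, so step k predicts last_input + k."""
    return x[..., -n_fields:] + 1.0


init = np.zeros((1, 4, 2, 2, 1))  # n_steps_input=4, one field
max_rollout_steps = 32  # mirrors the documented default
preds = rollout(add_one_model, init, max_rollout_steps)
print(preds.shape)  # (1, 32, 2, 2, 1)
print(preds[0, :3, 0, 0, 0])  # [1. 2. 3.]
```

Under this reading, short and long rollouts would just be different values of `n_steps`, capped at max_rollout_steps.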
- test_step(batch, batch_idx)
Test uses autoregressive rollout and WELL benchmark metrics.
- Parameters:
batch – Batch dict from test dataloader
batch_idx – Index of the current batch
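The WELL benchmark reports errors such as VRMSE, a root-mean-squared error normalized by the spatial variance of the target field, often tracked per rollout step. The sketch below is an assumption about how such a metric could be computed here: reduction over H and W, per field, with ε = 1e-7 and a batch average. `vrmse_per_step` is a hypothetical name; the wrapper's actual metric set and reductions may differ.

```python
import numpy as np


def vrmse_per_step(preds, targets, eps=1e-7):
    """VRMSE-style error for rollouts.

    preds, targets: [B, S, H, W, C] (S = rollout steps).
    Returns [S, C]: error per step and field, averaged over the batch.
    """
    spatial = (2, 3)  # H and W axes
    mse = np.mean((preds - targets) ** 2, axis=spatial)    # [B, S, C]
    mean_t = np.mean(targets, axis=spatial, keepdims=True)
    var = np.mean((targets - mean_t) ** 2, axis=spatial)   # [B, S, C]
    return np.mean(np.sqrt(mse / (var + eps)), axis=0)     # [S, C]


rng = np.random.default_rng(0)
targets = rng.normal(size=(2, 5, 16, 16, 3))
# Predicting each target's spatial mean yields VRMSE close to 1
preds = np.broadcast_to(targets.mean(axis=(2, 3), keepdims=True), targets.shape)
scores = vrmse_per_step(preds, targets)
print(scores.shape)  # (5, 3)
```

The normalization makes a score of 1 correspond to predicting the spatial mean, so values below 1 indicate the model beats that trivial baseline.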
- on_validation_epoch_end()
Logs the best validation loss so far, on rank 0 only to avoid duplicate wandb logs.
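The bookkeeping behind this hook can be sketched without Lightning: every rank keeps the running minimum so state stays consistent, but only global rank 0 emits the log entry. `BestValLossTracker`, the `log_fn` callback, and the "val/best_loss" key are hypothetical; inside a LightningModule the rank check would typically come from `self.trainer.is_global_zero`.

```python
import math
from typing import Callable, Optional


class BestValLossTracker:
    """Hypothetical helper tracking the lowest validation loss seen so far."""

    def __init__(self) -> None:
        self.best = math.inf

    def on_epoch_end(self, val_loss: float, global_rank: int,
                     log_fn: Callable[[str, float], None]) -> Optional[float]:
        # Every rank updates its copy so state stays in sync
        self.best = min(self.best, val_loss)
        if global_rank != 0:
            return None  # non-zero ranks skip logging to avoid duplicate wandb entries
        log_fn("val/best_loss", self.best)
        return self.best


logged = []
tracker = BestValLossTracker()
for loss in [0.9, 0.5, 0.7]:
    tracker.on_epoch_end(loss, global_rank=0, log_fn=lambda k, v: logged.append((k, v)))
print(logged)  # [('val/best_loss', 0.9), ('val/best_loss', 0.5), ('val/best_loss', 0.5)]
```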