WELLRegressionWrapper

class WELLRegressionWrapper(*args, **kwargs)

Bases: RegressionWrapper

Lightning wrapper for WELL benchmark regression tasks.

This wrapper adapts the WELL benchmark’s data format and metrics to work with the nvSubquadratic training infrastructure.

Training uses one-step prediction. Validation runs both short and long autoregressive rollouts and scores them with WELL benchmark metrics.

WELL data format:

  • Batch format: dict with ‘input_fields’, ‘constant_fields’ (optional), etc.

  • Fields are in channels-last format: [B, T, H, W, C]

Model expects:

  • Input: [B, H, W, C_in], where C_in = n_steps_input * n_fields + n_constant_fields

  • Output: [B, H, W, C_out], where C_out = n_fields
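To make the channel arithmetic concrete, here is a minimal numpy sketch of how a [B, T, H, W, C] input history could be packed into the model's [B, H, W, C_in] input. The name `pack_model_input` is hypothetical. The channel ordering (time-major: all fields of the first input step, then all fields of the next) and the constant-field shape [B, H, W, C_const] are assumptions; the wrapper may order channels differently.

```python
import numpy as np


def pack_model_input(input_fields, constant_fields=None):
    """Hypothetical helper: flatten [B, T, H, W, C] into [B, H, W, T*C (+ C_const)].

    Assumption: input steps are concatenated time-major along the channel axis,
    i.e. channels are ordered (t0 f0, t0 f1, ..., t1 f0, t1 f1, ...).
    """
    b, t, h, w, c = input_fields.shape
    # Move the time axis next to the channel axis -> [B, H, W, T, C],
    # then merge those two axes into one channel axis of size T*C.
    x = np.moveaxis(input_fields, 1, 3).reshape(b, h, w, t * c)
    if constant_fields is not None:
        # Assumption: constant fields carry no time axis, shape [B, H, W, C_const].
        x = np.concatenate([x, constant_fields], axis=-1)
    return x


# n_steps_input = 4, n_fields = 3, n_constant_fields = 1 -> C_in = 4 * 3 + 1 = 13
x = np.zeros((2, 4, 8, 8, 3))
const = np.zeros((2, 8, 8, 1))
print(pack_model_input(x, const).shape)  # (2, 8, 8, 13)
```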

Parameters:
  • network (Module) – Network to wrap

  • cfg (ExperimentConfig) – Experiment configuration

  • metadata – WELL dataset metadata

  • n_steps_input (int) – Number of input timesteps

  • n_steps_output (int) – Number of output timesteps (for training, usually 1)

  • max_rollout_steps (int) – Maximum rollout steps for validation

  • metric (Literal['MAE', 'MSE']) – Training metric (‘MSE’ or ‘MAE’)

__init__(
network,
cfg,
metadata,
n_steps_input=4,
n_steps_output=1,
max_rollout_steps=32,
metric='MSE',
normalization=None,
)

Initialize the WELL regression wrapper with dataset metadata and rollout settings.

training_step(batch, batch_idx)

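Training uses one-step prediction: the model predicts the next timestep from n_steps_input input timesteps, and the loss is the configured metric (MSE or MAE).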

Parameters:
  • batch – Batch dict with ‘input_fields’ and the target fields at the subsequent timestep

  • batch_idx – Index of the current batch
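A one-step training loss under the shapes above can be sketched as follows. This is a hypothetical illustration, not the wrapper's code: `one_step_loss` is an invented name, `model` stands for any callable from [B, H, W, T_in*C] to [B, H, W, C], constant fields are omitted for brevity, and time-major channel packing is assumed.

```python
import numpy as np


def one_step_loss(model, input_fields, target, metric="MSE"):
    """Hypothetical sketch of the 1-step training loss.

    input_fields: [B, T_in, H, W, C] history; target: [B, H, W, C] next frame.
    """
    b, t, h, w, c = input_fields.shape
    # Pack the history into channels (assumed time-major ordering).
    x = np.moveaxis(input_fields, 1, 3).reshape(b, h, w, t * c)
    pred = model(x)  # [B, H, W, C]
    err = pred - target
    if metric == "MSE":
        return float(np.mean(err ** 2))
    if metric == "MAE":
        return float(np.mean(np.abs(err)))
    raise ValueError(f"unsupported metric: {metric}")


# A "persistence" baseline that repeats the last input frame (C = 3 here).
persistence = lambda x: x[..., -3:]
hist = np.ones((1, 2, 4, 4, 3))
tgt = np.zeros((1, 4, 4, 3))
print(one_step_loss(persistence, hist, tgt, "MSE"))  # 1.0
```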

validation_step(batch, batch_idx)

Validation runs an autoregressive rollout (up to max_rollout_steps steps) and evaluates it with WELL benchmark metrics.

Parameters:
  • batch – Batch dict from validation dataloader

  • batch_idx – Index of the current batch
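An autoregressive rollout feeds each prediction back in as input. The sketch below assumes a sliding window that drops the oldest input step after every prediction, with the same time-major channel packing; `autoregressive_rollout` is a hypothetical name, and in the wrapper the number of steps would be bounded by max_rollout_steps.

```python
import numpy as np


def autoregressive_rollout(model, history, n_steps):
    """Hypothetical sketch: roll a one-step model forward n_steps times.

    history: [B, T_in, H, W, C] initial window of ground-truth frames.
    Returns predictions stacked as [B, n_steps, H, W, C].
    """
    window = history
    preds = []
    for _ in range(n_steps):
        b, t, h, w, c = window.shape
        x = np.moveaxis(window, 1, 3).reshape(b, h, w, t * c)
        y = model(x)  # next frame, [B, H, W, C]
        preds.append(y)
        # Slide the window: drop the oldest frame, append the new prediction.
        window = np.concatenate([window[:, 1:], y[:, None]], axis=1)
    return np.stack(preds, axis=1)


# Toy model with one field: next frame = last frame + 1.
step = lambda x: x[..., -1:] + 1
out = autoregressive_rollout(step, np.zeros((1, 2, 2, 2, 1)), n_steps=3)
print(out[0, :, 0, 0, 0])  # [1. 2. 3.]
```

Because errors compound as predictions are fed back, longer rollouts are a harder test than the one-step objective used in training.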

test_step(batch, batch_idx)

Testing uses an autoregressive rollout and WELL benchmark metrics.

Parameters:
  • batch – Batch dict from test dataloader

  • batch_idx – Index of the current batch
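This page does not list which WELL benchmark metrics are logged. One metric reported in the WELL benchmark is a variance-scaled RMSE (VRMSE): the RMSE normalized by the spatial standard deviation of the target. The numpy sketch below is an illustration under that assumption; the function name, the per-field reduction over the spatial axes, and the `eps` value are assumptions rather than the wrapper's implementation.

```python
import numpy as np


def vrmse(pred, target, eps=1e-7):
    """Variance-scaled RMSE per sample and field for [B, H, W, C] arrays.

    vrmse = sqrt( mean_space((pred - target)^2) / (var_space(target) + eps) )
    eps (an assumed value) guards against constant target fields.
    """
    spatial = (1, 2)
    mse = np.mean((pred - target) ** 2, axis=spatial)
    var = np.var(target, axis=spatial)
    return np.sqrt(mse / (var + eps))  # shape [B, C]


# Predicting the spatial mean of the target gives a VRMSE of about 1.
target = np.array([0.0, 2.0]).reshape(1, 2, 1, 1)
print(vrmse(np.ones_like(target), target))
```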

on_validation_epoch_end()

Logs the best validation loss seen so far. Logging happens on rank 0 only, to avoid duplicate wandb logs.
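The best-loss bookkeeping can be sketched framework-free as below. `BestValLossTracker`, the `"val/best_loss"` key, and the explicit `rank` argument are hypothetical; in Lightning the rank-0 check would typically come from the trainer (for example `trainer.is_global_zero`). Every rank updates its state so the ranks stay consistent, and only rank 0 calls the logger.

```python
import math


class BestValLossTracker:
    """Hypothetical helper: track the lowest validation loss seen so far."""

    def __init__(self, log_fn, rank=0):
        self.best = math.inf
        self.log_fn = log_fn  # e.g. a wandb-style logging callable
        self.rank = rank

    def on_epoch_end(self, val_loss):
        # All ranks update the running minimum.
        self.best = min(self.best, val_loss)
        # Only rank 0 logs, so the experiment tracker sees one entry per epoch.
        if self.rank == 0:
            self.log_fn({"val/best_loss": self.best})
        return self.best


logs = []
tracker = BestValLossTracker(logs.append, rank=0)
for loss in (0.5, 0.7, 0.3):
    tracker.on_epoch_end(loss)
print(logs)  # [{'val/best_loss': 0.5}, {'val/best_loss': 0.5}, {'val/best_loss': 0.3}]
```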