AutoregressiveWrapper

class AutoregressiveWrapper(*args, **kwargs)

Bases: LightningWrapperBase

Lightning wrapper for autoregressive (next-token prediction) tasks.

Todo

Resume support (see ClassificationWrapper for the reference pattern): add on_save_checkpoint / on_load_checkpoint hooks that persist best_train_loss and best_val_loss across job resumes. Without them, best-metric tracking silently resets to inf after every SLURM preemption or manual resume. Add corresponding tests in tests/test_checkpoint_resume.py (see TestBestMetricsPersistence for the classification pattern).
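The hooks described in the Todo could look roughly like the sketch below. It is not the ClassificationWrapper implementation, which is not shown here. The mixin name and the "best_metrics" checkpoint key are hypothetical; only the hook names and the two attribute names come from the note above. The hooks follow the Lightning convention of receiving the checkpoint dict as their single argument.

```python
import math


class BestLossCheckpointMixin:
    """Hypothetical mixin: persists best-loss trackers via Lightning-style hooks."""

    def __init__(self):
        # Attribute names follow the Todo note; the inf defaults are assumed.
        self.best_train_loss = math.inf
        self.best_val_loss = math.inf

    def on_save_checkpoint(self, checkpoint: dict) -> None:
        # Called with the checkpoint dict that is about to be written.
        checkpoint["best_metrics"] = {  # "best_metrics" is a made-up key
            "best_train_loss": self.best_train_loss,
            "best_val_loss": self.best_val_loss,
        }

    def on_load_checkpoint(self, checkpoint: dict) -> None:
        # Older checkpoints may lack the key, so keep the current values then.
        saved = checkpoint.get("best_metrics", {})
        self.best_train_loss = saved.get("best_train_loss", self.best_train_loss)
        self.best_val_loss = saved.get("best_val_loss", self.best_val_loss)


# Simulate a save followed by a resume in a fresh process.
before = BestLossCheckpointMixin()
before.best_train_loss, before.best_val_loss = 0.42, 0.57
ckpt = {}
before.on_save_checkpoint(ckpt)

after = BestLossCheckpointMixin()
after.on_load_checkpoint(ckpt)
```

After the simulated resume, `after` carries the saved values instead of inf, which is the behavior the planned tests would need to check.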

Parameters:
  • network (Module) – Network to wrap. Should output logits of shape [B, L, vocab_size] in discrete mode or values of shape [B, L, C] in continuous mode.

  • cfg (ExperimentConfig) – Experiment configuration.

  • mode (Literal['discrete', 'continuous']) – “discrete” for token prediction (cross-entropy), “continuous” for value prediction (MSE/MAE).

  • vocab_size (int | None) – Vocabulary size (required for discrete mode).

  • loss_type (Literal['mse', 'mae']) – Loss type for continuous mode (“mse” or “mae”). Ignored for discrete.

  • ignore_index (int) – Index to ignore in loss computation (e.g., padding token). Default -100 (PyTorch convention).
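To make the roles of mode, loss_type and ignore_index concrete, here is a minimal plain-Python sketch of the two loss families named above. It is an illustration under assumptions, not the wrapper's code: the function names are hypothetical, nested lists stand in for tensors, and the wrapper presumably delegates to PyTorch (for example, torch.nn.functional.cross_entropy accepts an ignore_index argument that defaults to -100).

```python
import math


def cross_entropy_ignore(logits, targets, ignore_index=-100):
    """Mean cross-entropy over positions whose target != ignore_index.

    logits: nested list [B][L][vocab_size]; targets: nested list [B][L].
    """
    total, count = 0.0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        for row, target in zip(seq_logits, seq_targets):
            if target == ignore_index:
                continue  # e.g. padding: contributes nothing to the loss
            peak = max(row)  # subtract the max for numerical stability
            log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
            total += log_norm - row[target]
            count += 1
    # Returning 0.0 when every position is ignored is a choice of this sketch.
    return total / count if count else 0.0


def regression_loss(preds, targets, loss_type="mse"):
    """Mean squared or mean absolute error over nested lists [B][L][C]."""
    if loss_type not in ("mse", "mae"):
        raise ValueError(f"unknown loss_type: {loss_type!r}")
    errors = [
        (p - t) ** 2 if loss_type == "mse" else abs(p - t)
        for seq_p, seq_t in zip(preds, targets)
        for step_p, step_t in zip(seq_p, seq_t)
        for p, t in zip(step_p, step_t)
    ]
    return sum(errors) / len(errors)


# B=1, L=2, vocab_size=2; the second position is padding and is skipped.
logits = [[[0.0, 0.0], [5.0, -5.0]]]
discrete_loss = cross_entropy_ignore(logits, [[1, -100]])

# B=1, L=2, C=1
mse = regression_loss([[[1.0], [3.0]]], [[[0.0], [1.0]]], "mse")
mae = regression_loss([[[1.0], [3.0]]], [[[0.0], [1.0]]], "mae")
```

Because the padded position is skipped, discrete_loss depends only on the first position, where uniform logits over two tokens give a loss of ln 2.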

__init__(
network,
cfg,
mode='discrete',
vocab_size=None,
loss_type='mse',
ignore_index=-100,
)

Initialize the AutoregressiveWrapper.

Parameters: same as the class-level parameters listed above.

training_step(batch, batch_idx)

Perform a training step.

validation_step(batch, batch_idx)

Perform a validation step.

test_step(batch, batch_idx)

Perform a test step.

on_train_epoch_end()

Log metrics at end of training epoch.

on_validation_epoch_end()

Log metrics at end of validation epoch.

generate(
prompt,
max_new_tokens,
temperature=1.0,
top_k=None,
top_p=None,
condition=None,
)

Generate tokens autoregressively.

Parameters:
  • prompt (Tensor) – Initial sequence, shape [B, L] for discrete or [B, L, C] for continuous.

  • max_new_tokens (int) – Maximum number of new tokens to generate.

  • temperature (float) – Sampling temperature (1.0 = no change, <1.0 = more deterministic, >1.0 = more random).

  • top_k (int | None) – If set, only sample from top-k most likely tokens.

  • top_p (float | None) – If set, use nucleus sampling: sample only from the smallest set of most likely tokens whose cumulative probability reaches this value.

  • condition (Tensor | None) – Optional conditioning tensor.

Returns:

Generated sequence including prompt, shape [B, L + max_new_tokens, …].

Return type:

Tensor
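The interaction of temperature, top_k and top_p can be sketched for a single position as follows. This is an illustrative plain-Python version, not the wrapper's implementation: the function names are hypothetical, and applying top_k before top_p is an assumption of the sketch. It also assumes temperature > 0.

```python
import math
import random


def filter_probs(logits, temperature=1.0, top_k=None, top_p=None):
    """Turn one position's logits into a filtered, renormalized distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids ordered from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = order[:top_k] if top_k is not None else order

    if top_p is not None:
        nucleus, mass = [], 0.0
        for i in keep:
            nucleus.append(i)
            mass += probs[i]
            if mass >= top_p:
                break  # smallest prefix whose mass reaches top_p
        keep = nucleus

    kept = set(keep)
    kept_mass = sum(probs[i] for i in keep)
    return [p / kept_mass if i in kept else 0.0 for i, p in enumerate(probs)]


def sample_token(logits, rng, **kwargs):
    """Draw one token id from the filtered distribution."""
    probs = filter_probs(logits, **kwargs)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0, -1.0]
top2 = filter_probs(logits, top_k=2)           # only tokens 0 and 1 survive
nucleus = filter_probs(logits, top_p=0.6)      # token 0 alone exceeds 0.6
sharp = filter_probs(logits, temperature=0.5)  # more mass on token 0
token = sample_token(logits, random.Random(0), top_k=2)
```

In this toy distribution the most likely token already holds more than 0.6 of the probability mass, so top_p=0.6 reduces sampling to that single token, which is why small top_p values behave almost like greedy decoding.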

generate_greedy(
prompt,
max_new_tokens,
condition=None,
)

Generate tokens using greedy decoding (argmax).

Parameters:
  • prompt (Tensor) – Initial sequence.

  • max_new_tokens (int) – Maximum number of new tokens to generate.

  • condition (Tensor | None) – Optional conditioning tensor.

Returns:

Generated sequence including prompt.

Return type:

Tensor
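Greedy decoding reduces to a short loop, sketched here for a single unbatched sequence of token ids. The next_logits callable and the toy model are hypothetical stand-ins for the wrapped network; the real method works on tensors and also accepts a condition argument, which this sketch leaves out.

```python
def greedy_decode(next_logits, prompt, max_new_tokens):
    """Append argmax tokens one at a time.

    next_logits: hypothetical callable mapping a token list to the logits
    for the next position (stands in for the wrapped network).
    """
    sequence = list(prompt)
    for _ in range(max_new_tokens):
        logits = next_logits(sequence)
        best = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        sequence.append(best)
    return sequence


def toy_model(sequence, vocab_size=5):
    # Toy stand-in: always prefers (last token + 1) mod vocab_size.
    target = (sequence[-1] + 1) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


generated = greedy_decode(toy_model, prompt=[3], max_new_tokens=4)
# generated == [3, 4, 0, 1, 2]: the prompt followed by 4 new tokens
```

As in the documented return value, the result includes the prompt, so its length is the prompt length plus max_new_tokens.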