AutoregressiveWrapper
- class AutoregressiveWrapper(*args, **kwargs)
Bases: LightningWrapperBase
Lightning wrapper for autoregressive (next-token prediction) tasks.
Todo
Resume support (see ClassificationWrapper for the reference pattern). Add on_save_checkpoint / on_load_checkpoint hooks that persist best_train_loss and best_val_loss across job resumes. Without them, best-metric tracking silently resets to inf after every SLURM preemption or manual resume. Add corresponding tests in tests/test_checkpoint_resume.py (see TestBestMetricsPersistence for the classification pattern).
- Parameters:
network (Module) – Network to wrap. Should output logits of shape [B, L, vocab_size] for discrete tokens or [B, L, C] for continuous values.
cfg (ExperimentConfig) – Experiment configuration.
mode (Literal['discrete', 'continuous']) – “discrete” for token prediction (cross-entropy), “continuous” for value prediction (MSE/MAE).
vocab_size (int | None) – Vocabulary size (required for discrete mode).
loss_type (Literal['mse', 'mae']) – Loss type for continuous mode (“mse” or “mae”). Ignored for discrete.
ignore_index (int) – Index to ignore in loss computation (e.g., padding token). Default -100 (PyTorch convention).
- __init__(
- network,
- cfg,
- mode='discrete',
- vocab_size=None,
- loss_type='mse',
- ignore_index=-100,
- )
Initialize the AutoregressiveWrapper.
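The mode, loss_type, and ignore_index arguments suggest the following loss selection. This is a minimal sketch, not the wrapper's actual code: the function name autoregressive_loss is hypothetical, and it assumes the targets are already aligned with the network outputs (position t of the target is the value that follows position t of the input).

```python
import torch
import torch.nn.functional as F


def autoregressive_loss(outputs, targets, mode="discrete", loss_type="mse", ignore_index=-100):
    """Illustrative loss selection (hypothetical helper, not the wrapper's code)."""
    if mode == "discrete":
        vocab_size = outputs.size(-1)
        # Flatten [B, L, vocab_size] -> [B*L, vocab_size] and [B, L] -> [B*L].
        # Positions whose target equals ignore_index do not contribute to the loss.
        return F.cross_entropy(
            outputs.reshape(-1, vocab_size),
            targets.reshape(-1),
            ignore_index=ignore_index,
        )
    if loss_type == "mse":
        return F.mse_loss(outputs, targets)
    if loss_type == "mae":
        return F.l1_loss(outputs, targets)
    raise ValueError(f"unknown loss_type: {loss_type!r}")


# Discrete: uniform logits over 5 tokens; padded positions (-100) are skipped.
logits = torch.zeros(2, 3, 5)
tokens = torch.tensor([[1, 2, -100], [0, -100, -100]])
discrete_loss = autoregressive_loss(logits, tokens, mode="discrete")

# Continuous: predictions of 0 against targets of 2, shape [B, L, C].
preds = torch.zeros(2, 3, 4)
values = torch.full((2, 3, 4), 2.0)
mse = autoregressive_loss(preds, values, mode="continuous", loss_type="mse")
mae = autoregressive_loss(preds, values, mode="continuous", loss_type="mae")
```

With uniform logits the cross-entropy equals the log of the vocabulary size however many positions are padded, because ignored positions are excluded from the mean.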
- training_step(batch, batch_idx)
Perform a training step.
- validation_step(batch, batch_idx)
Perform a validation step.
- test_step(batch, batch_idx)
Perform a test step.
- on_train_epoch_end()
Log metrics at the end of each training epoch.
- on_validation_epoch_end()
Log metrics at the end of each validation epoch.
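The Todo at the top of this page asks for checkpoint hooks that persist best_train_loss and best_val_loss. One possible shape for them is sketched below. It is not necessarily the ClassificationWrapper pattern the Todo refers to: the mixin name and the "best_metrics" checkpoint key are hypothetical. The hook names and their single checkpoint-dict argument follow the LightningModule checkpoint hooks, and the class is written as a plain mixin so it runs without Lightning installed.

```python
import math


class BestMetricsCheckpointMixin:
    """Hypothetical mixin, meant to be combined with a LightningModule subclass."""

    # Hypothetical key under which the metrics are stored in the checkpoint dict.
    _BEST_METRICS_KEY = "best_metrics"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_train_loss = math.inf
        self.best_val_loss = math.inf

    def on_save_checkpoint(self, checkpoint):
        # Store plain floats so the checkpoint stays trivially serializable.
        checkpoint[self._BEST_METRICS_KEY] = {
            "best_train_loss": float(self.best_train_loss),
            "best_val_loss": float(self.best_val_loss),
        }

    def on_load_checkpoint(self, checkpoint):
        # Older checkpoints without the key fall back to inf.
        saved = checkpoint.get(self._BEST_METRICS_KEY, {})
        self.best_train_loss = saved.get("best_train_loss", math.inf)
        self.best_val_loss = saved.get("best_val_loss", math.inf)


# Round trip through a checkpoint dict, as would happen across a job resume.
before = BestMetricsCheckpointMixin()
before.best_val_loss = 0.25
ckpt = {}
before.on_save_checkpoint(ckpt)

after = BestMetricsCheckpointMixin()
after.on_load_checkpoint(ckpt)
```

Falling back to inf when the key is missing keeps checkpoints written before the hooks existed loadable.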
- generate(
- prompt,
- max_new_tokens,
- temperature=1.0,
- top_k=None,
- top_p=None,
- condition=None,
- )
Generate tokens autoregressively.
- Parameters:
prompt (Tensor) – Initial sequence, shape [B, L] for discrete or [B, L, C] for continuous.
max_new_tokens (int) – Maximum number of new tokens to generate.
temperature (float) – Sampling temperature (1.0 = no change, <1.0 = more deterministic, >1.0 = more random).
top_k (int | None) – If set, sample only from the top-k most likely tokens.
top_p (float | None) – If set, use nucleus sampling: sample only from the smallest set of most likely tokens whose cumulative probability reaches this value.
condition (Tensor | None) – Optional conditioning tensor.
- Returns:
Generated sequence including prompt, shape [B, L + max_new_tokens, …].
- Return type:
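How temperature, top_k, and top_p commonly interact in discrete sampling can be sketched as follows. This illustrates the standard technique and is not the code behind generate(): the helper names filter_logits and sample_next are hypothetical, the input is assumed to be the logits for the last position only (shape [B, vocab_size]), and temperature is assumed to be strictly positive.

```python
import torch


def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Apply temperature, then top-k, then top-p filtering to [B, vocab_size] logits."""
    logits = logits / temperature
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        # Value of the k-th largest logit per row, shape [B, 1].
        kth = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        cumulative = torch.cumsum(probs, dim=-1)
        # Drop a token once the mass of the tokens ranked before it reaches top_p.
        # The most likely token has zero preceding mass, so it is always kept.
        remove = (cumulative - probs) >= top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        # Undo the sort so each logit returns to its original vocabulary index.
        logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
    return logits


def sample_next(logits, **kwargs):
    """Sample one token id per row from the filtered distribution, shape [B, 1]."""
    probs = torch.softmax(filter_logits(logits, **kwargs), dim=-1)
    return torch.multinomial(probs, num_samples=1)


last_logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])
kept_top_k = torch.isfinite(filter_logits(last_logits, top_k=2)).sum().item()
kept_top_p = torch.isfinite(filter_logits(last_logits, top_p=0.8)).sum().item()
next_token = sample_next(last_logits, top_k=1)  # only one candidate remains
```

A full generation loop would call sample_next on the last-position logits, append the result to the sequence, and repeat max_new_tokens times.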
- generate_greedy(
- prompt,
- max_new_tokens,
- condition=None,
- )
Generate tokens using greedy decoding (argmax).
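Greedy decoding for the discrete case can be sketched as a short loop. This is a sketch under stated assumptions rather than the method's implementation: greedy_decode and toy_model are hypothetical names, the model is assumed to map token ids of shape [B, L] to logits of shape [B, L, vocab_size], and conditioning is omitted.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def greedy_decode(model, prompt, max_new_tokens):
    """Append the argmax token max_new_tokens times; returns [B, L + max_new_tokens]."""
    seq = prompt
    for _ in range(max_new_tokens):
        logits = model(seq)  # [B, L, vocab_size]
        # Only the last position predicts the next token.
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [B, 1]
        seq = torch.cat([seq, next_token], dim=1)
    return seq


def toy_model(tokens, vocab_size=5):
    """Hypothetical stand-in network: always predicts (token + 1) mod vocab_size."""
    return F.one_hot((tokens + 1) % vocab_size, vocab_size).float()


prompt = torch.tensor([[0, 1]])
generated = greedy_decode(toy_model, prompt, max_new_tokens=3)
```

The returned sequence includes the prompt, which matches the return shape documented for generate().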