OmegaScaleMonitorCallback

class OmegaScaleMonitorCallback(*args, **kwargs)

Bases: Callback

Logs a single chart per block tracking the effective per-row ω₀.

For each Hyena block whose kernel uses a LearnableOmegaSIRENPositionalEmbeddingND, the callback logs omega_eff_min, omega_eff_mean, and omega_eff_max: the running per-block statistics of omega_0 · scale (post-clamp). The raw scale series is intentionally omitted. Per-block ω₀ values differ substantially (typically by ~24×), so plotting raw scale on shared charts would compress the scale axis.
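As a rough illustration of the logged quantities, the sketch below computes the three statistics for one block in plain Python. It assumes that "post-clamp" means scale is clamped before being multiplied by omega_0, and that the statistics are taken over rows. The function name and the clamp bounds scale_min and scale_max are hypothetical and are not taken from the library.

```python
from statistics import fmean


def omega_eff_stats(omega_0, scale, scale_min=0.5, scale_max=4.0):
    """Return min/mean/max of omega_0 * clamp(scale) over the rows of one block.

    scale_min and scale_max are hypothetical clamp bounds chosen for this sketch.
    """
    # Clamp each per-row scale into [scale_min, scale_max] (assumed order of operations).
    clamped = [min(max(s, scale_min), scale_max) for s in scale]
    # Effective per-row frequency: omega_0 times the clamped scale.
    omega_eff = [omega_0 * s for s in clamped]
    return {
        "omega_eff_min": min(omega_eff),
        "omega_eff_mean": fmean(omega_eff),
        "omega_eff_max": max(omega_eff),
    }


# One block with omega_0 = 30 and three rows; the first row is below the lower bound.
stats = omega_eff_stats(omega_0=30.0, scale=[0.25, 1.0, 2.0])
```

Logging the product rather than scale alone keeps blocks with very different ω₀ comparable in physical frequency units, which matches the motivation stated above.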

Parameters:

log_every_n_steps (int) – How often to log (in global steps).
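
A common way to implement a "log every n global steps" cadence is a modulo check on the step counter. The helper below is a hypothetical sketch of that idea. Whether this callback gates on exactly this condition is an assumption; only the default of 50 comes from the signature documented here.

```python
def should_log(global_step: int, log_every_n_steps: int = 50) -> bool:
    """Return True on steps that are a multiple of log_every_n_steps (assumed rule)."""
    return global_step % log_every_n_steps == 0


# Steps in the first 200 global steps at which a log entry would be emitted.
logged = [step for step in range(0, 201) if should_log(step)]
```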

__init__(log_every_n_steps=50)
Parameters:

log_every_n_steps (int)

on_fit_start(trainer, pl_module)
Parameters:
  • trainer (pytorch_lightning.Trainer)

  • pl_module (pytorch_lightning.LightningModule)

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

property state_key: str

state_dict()
Return type:

dict[str, Any]

load_state_dict(state_dict)
Parameters:

state_dict (dict[str, Any])

Return type:

None
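
In PyTorch Lightning, a callback's state_dict() is saved with the checkpoint and passed back to load_state_dict() on resume. This page does not say what this callback persists, so the stand-in class below uses a hypothetical last_logged_step field purely to show the round trip. It does not reproduce the real callback's state.

```python
from typing import Any


class TinyMonitorState:
    """Stand-in (not the real callback) that persists one bookkeeping value."""

    def __init__(self, log_every_n_steps: int = 50) -> None:
        self.log_every_n_steps = log_every_n_steps
        self.last_logged_step = -1  # hypothetical field, for illustration only

    def state_dict(self) -> dict[str, Any]:
        # Return only plain, serializable values so the dict can go into a checkpoint.
        return {"last_logged_step": self.last_logged_step}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore exactly what state_dict() produced.
        self.last_logged_step = state_dict["last_logged_step"]


# Simulate save-and-resume: copy state from one instance into a fresh one.
saved = TinyMonitorState()
saved.last_logged_step = 150
restored = TinyMonitorState()
restored.load_state_dict(saved.state_dict())
```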