diff --git a/src/art/local/backend.py b/src/art/local/backend.py index eefde0d8f..bfa9b143c 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -658,12 +658,7 @@ def _log_metrics( # If we have a W&B run, log the data there if run := self._get_wandb_run(model): - # Mark the step metric itself as hidden so W&B doesn't create an automatic chart for it - wandb.define_metric("training_step", hidden=True) - - # Enabling the following line will cause W&B to use the training_step metric as the x-axis for all metrics - # wandb.define_metric(f"{split}/*", step_metric="training_step") - run.log({"training_step": step, **metrics}, step=step) + run.log({"training_step": step, **metrics}) def _get_wandb_run(self, model: Model) -> Run | None: if "WANDB_API_KEY" not in os.environ: @@ -688,6 +683,12 @@ def _get_wandb_run(self, model: Model) -> Run | None: ), ) self._wandb_runs[model.name] = run + + # Define training_step as the x-axis for all metrics. + # This allows out-of-order logging (e.g., async validation for previous steps). + wandb.define_metric("training_step") + wandb.define_metric("train/*", step_metric="training_step") + wandb.define_metric("val/*", step_metric="training_step") os.environ["WEAVE_PRINT_CALL_LINK"] = os.getenv( "WEAVE_PRINT_CALL_LINK", "False" )