diff --git a/docs/design-docs/logger.md b/docs/design-docs/logger.md index b13436423b..c98d199613 100644 --- a/docs/design-docs/logger.md +++ b/docs/design-docs/logger.md @@ -43,6 +43,10 @@ logging_config = { "project": "grpo-dev", "name": "grpo-dev-logging", }, + "swanlab": { + "project": "nemo-rl", + "name": "grpo-dev-logging", + }, "tensorboard": { "log_dir": "logs", }, @@ -64,6 +68,7 @@ The logger supports pretty-formatted logging of validation samples to help visua ```python logger: wandb_enabled: false + swanlab_enabled: false tensorboard_enabled: false num_val_samples_to_print: 10 ``` @@ -91,6 +96,7 @@ This feature is enabled with the `monitor_gpus` configuration parameter. The fre ```python logger: wandb_enabled: false + swanlab_enabled: false tensorboard_enabled: false monitor_gpus: true gpu_monitoring: @@ -107,4 +113,4 @@ While it is feasible to monitor using remote workers, the implementation require * Workers that spawn other workers accurately report the total resource usage of any grandchild workers. Due to these complexities, we opted for a simpler approach: collecting metrics exposed by the Ray metrics server from the driver. 
-::: \ No newline at end of file +::: diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index c039ecd939..db7a8570a1 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -49,6 +49,9 @@ class WandbConfig(TypedDict): project: NotRequired[str] name: NotRequired[str] +class SwanlabConfig(TypedDict): + project: NotRequired[str] + name: NotRequired[str] class TensorboardConfig(TypedDict): log_dir: NotRequired[str] @@ -62,8 +65,10 @@ class GPUMonitoringConfig(TypedDict): class LoggerConfig(TypedDict): log_dir: str wandb_enabled: bool + swanlab_enabled: bool tensorboard_enabled: bool wandb: WandbConfig + swanlab: SwanlabConfig tensorboard: TensorboardConfig monitor_gpus: bool gpu_monitoring: GPUMonitoringConfig @@ -313,6 +318,74 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: """ self.run.log({name: figure}, step=step) +class SwanlabLogger(LoggerInterface): + """SwanLab logger backend.""" + + def __init__(self, cfg: SwanlabConfig, log_dir: Optional[str] = None): + import swanlab + self.run = swanlab.init(**cfg) + print( + f"Initialized SwanlabLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}" + ) + + def define_metric( + self, + name: str, + step_metric: Optional[str] = None, + ) -> None: + """Define a metric with custom step metric. + + Args: + name: Name of the metric or pattern (e.g. 'ray/*') + step_metric: Optional name of the step metric to use + """ + self.run.define_metric(name, step_metric=step_metric) + + def log_metrics( + self, + metrics: dict[str, Any], + step: int, + prefix: Optional[str] = "", + step_metric: Optional[str] = None, + ) -> None: + """Log metrics to swanlab. 
+ + Args: + metrics: Dict of metrics to log + step: Global step value + prefix: Optional prefix for metric names + step_metric: Optional name of a field in metrics to use as step instead + of the provided step value + """ + if prefix: + metrics = { + f"{prefix}/{k}" if k != step_metric else k: v + for k, v in metrics.items() + } + + # If step_metric is provided, use the corresponding value from metrics as step + if step_metric and step_metric in metrics: + # commit=False so the step does not get incremented + self.run.log(metrics, commit=False) + else: + self.run.log(metrics, step=step) + + def log_hyperparams(self, params: Mapping[str, Any]) -> None: + """Log hyperparameters to swanlab. + + Args: + params: Dict of hyperparameters to log + """ + self.run.config.update(params) + + def log_plot(self, figure: plt.Figure, step: int, name: str) -> None: + """Log a plot to swanlab. + + Args: + figure: Matplotlib figure to log + step: Global step value + """ + self.run.log({name: figure}, step=step) class GpuMetricSnapshot(TypedDict): step: int @@ -629,6 +702,7 @@ def __init__(self, cfg: LoggerConfig): """ self.loggers: list[LoggerInterface] = [] self.wandb_logger = None + self.swanlab_logger = None self.base_log_dir = cfg["log_dir"] os.makedirs(self.base_log_dir, exist_ok=True) @@ -639,6 +713,12 @@ def __init__(self, cfg: LoggerConfig): self.wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir) self.loggers.append(self.wandb_logger) + if cfg["swanlab_enabled"]: + swanlab_log_dir = os.path.join(self.base_log_dir, "swanlab") + os.makedirs(swanlab_log_dir, exist_ok=True) + self.swanlab_logger = SwanlabLogger(cfg["swanlab"], log_dir=swanlab_log_dir) + self.loggers.append(self.swanlab_logger) + if cfg["tensorboard_enabled"]: tensorboard_log_dir = os.path.join(self.base_log_dir, "tensorboard") os.makedirs(tensorboard_log_dir, exist_ok=True) @@ -657,6 +737,11 @@ def __init__(self, cfg: LoggerConfig): f"{metric_prefix}/*", step_metric=step_metric ) + if 
cfg["swanlab_enabled"] and self.swanlab_logger: + self.swanlab_logger.define_metric( + f"{metric_prefix}/*", step_metric=step_metric + ) + self.gpu_monitor = RayGpuMonitorLogger( collection_interval=cfg["gpu_monitoring"]["collection_interval"], flush_interval=cfg["gpu_monitoring"]["flush_interval"],