Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/design-docs/logger.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ logging_config = {
"project": "grpo-dev",
"name": "grpo-dev-logging",
},
"swanlab": {
"project": "nemo-rl",
"name": "grpo-dev-logging",
},
"tensorboard": {
"log_dir": "logs",
},
Expand All @@ -64,6 +68,7 @@ The logger supports pretty-formatted logging of validation samples to help visua
```python
logger:
wandb_enabled: false
swanlab_enabled: false
tensorboard_enabled: false
num_val_samples_to_print: 10
```
Expand Down Expand Up @@ -91,6 +96,7 @@ This feature is enabled with the `monitor_gpus` configuration parameter. The fre
```python
logger:
wandb_enabled: false
swanlab_enabled: false
tensorboard_enabled: false
monitor_gpus: true
gpu_monitoring:
Expand All @@ -107,4 +113,4 @@ While it is feasible to monitor using remote workers, the implementation require
* Workers that spawn other workers accurately report the total resource usage of any grandchild workers.

Due to these complexities, we opted for a simpler approach: collecting metrics exposed by the Ray metrics server from the driver.
:::
:::
85 changes: 85 additions & 0 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ class WandbConfig(TypedDict):
project: NotRequired[str]
name: NotRequired[str]

class SwanlabConfig(TypedDict):
    """Configuration for the SwanLab logger backend; forwarded verbatim to ``swanlab.init``."""

    project: NotRequired[str]  # SwanLab project name
    name: NotRequired[str]  # run/experiment display name

class TensorboardConfig(TypedDict):
log_dir: NotRequired[str]
Expand All @@ -62,8 +65,10 @@ class GPUMonitoringConfig(TypedDict):
class LoggerConfig(TypedDict):
log_dir: str
wandb_enabled: bool
swanlab_enabled: bool
tensorboard_enabled: bool
wandb: WandbConfig
swanlab: SwanlabConfig
tensorboard: TensorboardConfig
monitor_gpus: bool
gpu_monitoring: GPUMonitoringConfig
Expand Down Expand Up @@ -313,6 +318,74 @@ def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
"""
self.run.log({name: figure}, step=step)

class SwanlabLogger(LoggerInterface):
    """SwanLab logger backend (mirrors the WandbLogger interface)."""

    def __init__(self, cfg: SwanlabConfig, log_dir: Optional[str] = None):
        """Initialize a SwanLab run.

        Args:
            cfg: SwanLab settings (e.g. ``project``, ``name``) forwarded to
                ``swanlab.init``.
            log_dir: Optional directory for SwanLab's local run files.
        """
        # Lazy import so swanlab is only required when this backend is enabled.
        import swanlab

        # Route SwanLab's local files into the backend-specific log dir; the
        # original accepted ``log_dir`` but never used it. An explicit
        # ``logdir`` key in ``cfg`` still takes precedence.
        init_kwargs = dict(cfg)
        if log_dir is not None:
            init_kwargs.setdefault("logdir", log_dir)
        self.run = swanlab.init(**init_kwargs)
        print(
            f"Initialized SwanlabLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}"
        )

    def define_metric(
        self,
        name: str,
        step_metric: Optional[str] = None,
    ) -> None:
        """Define a metric with custom step metric.

        Args:
            name: Name of the metric or pattern (e.g. 'ray/*')
            step_metric: Optional name of the step metric to use
        """
        # NOTE(review): ``define_metric`` is a wandb-style API — confirm the
        # pinned swanlab version exposes it on the run object.
        self.run.define_metric(name, step_metric=step_metric)

    def log_metrics(
        self,
        metrics: dict[str, Any],
        step: int,
        prefix: Optional[str] = "",
        step_metric: Optional[str] = None,
    ) -> None:
        """Log metrics to swanlab.

        Args:
            metrics: Dict of metrics to log
            step: Global step value
            prefix: Optional prefix for metric names
            step_metric: Optional name of a field in metrics to use as step instead
                of the provided step value
        """
        # Prefix every key except the step metric itself, so the custom step
        # field keeps the exact name registered via define_metric().
        if prefix:
            metrics = {
                f"{prefix}/{k}" if k != step_metric else k: v
                for k, v in metrics.items()
            }

        # If step_metric is provided, use the corresponding value from metrics as step
        if step_metric and step_metric in metrics:
            # commit=False so the step does not get incremented
            # NOTE(review): ``commit`` is a wandb-ism — verify swanlab's
            # ``run.log`` accepts it in the pinned swanlab version.
            self.run.log(metrics, commit=False)
        else:
            self.run.log(metrics, step=step)

    def log_hyperparams(self, params: Mapping[str, Any]) -> None:
        """Log hyperparameters to swanlab.

        Args:
            params: Dict of hyperparameters to log
        """
        self.run.config.update(params)

    def log_plot(self, figure: plt.Figure, step: int, name: str) -> None:
        """Log a matplotlib figure to swanlab.

        Args:
            figure: Matplotlib figure to log
            step: Global step value
            name: Key under which the figure is logged
        """
        self.run.log({name: figure}, step=step)

class GpuMetricSnapshot(TypedDict):
step: int
Expand Down Expand Up @@ -629,6 +702,7 @@ def __init__(self, cfg: LoggerConfig):
"""
self.loggers: list[LoggerInterface] = []
self.wandb_logger = None
self.swanlab_logger = None

self.base_log_dir = cfg["log_dir"]
os.makedirs(self.base_log_dir, exist_ok=True)
Expand All @@ -639,6 +713,12 @@ def __init__(self, cfg: LoggerConfig):
self.wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir)
self.loggers.append(self.wandb_logger)

if cfg["swanlab_enabled"]:
swanlab_log_dir = os.path.join(self.base_log_dir, "swanlab")
os.makedirs(swanlab_log_dir, exist_ok=True)
self.swanlab_logger = SwanlabLogger(cfg["swanlab"], log_dir=swanlab_log_dir)
self.loggers.append(self.swanlab_logger)

if cfg["tensorboard_enabled"]:
tensorboard_log_dir = os.path.join(self.base_log_dir, "tensorboard")
os.makedirs(tensorboard_log_dir, exist_ok=True)
Expand All @@ -657,6 +737,11 @@ def __init__(self, cfg: LoggerConfig):
f"{metric_prefix}/*", step_metric=step_metric
)

if cfg["swanlab_enabled"] and self.swanlab_logger:
self.swanlab_logger.define_metric(
f"{metric_prefix}/*", step_metric=step_metric
)

self.gpu_monitor = RayGpuMonitorLogger(
collection_interval=cfg["gpu_monitoring"]["collection_interval"],
flush_interval=cfg["gpu_monitoring"]["flush_interval"],
Expand Down