Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 78 additions & 11 deletions nemo_reinforcer/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ class LoggerInterface(ABC):

@abstractmethod
def log_metrics(
self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = ""
self,
metrics: Dict[str, Any],
step: int,
prefix: Optional[str] = "",
step_metric: Optional[str] = None,
) -> None:
"""Log a dictionary of metrics."""
pass
Expand All @@ -87,14 +91,19 @@ def __init__(self, cfg: TensorboardConfig, log_dir: Optional[str] = None):
print(f"Initialized TensorboardLogger at {log_dir}")

def log_metrics(
self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = ""
self,
metrics: Dict[str, Any],
step: int,
prefix: Optional[str] = "",
step_metric: Optional[str] = None, # ignored in TensorBoard
) -> None:
"""Log metrics to Tensorboard.

Args:
metrics: Dict of metrics to log
step: Global step value
prefix: Optional prefix for metric names
step_metric: Optional step metric name (ignored in TensorBoard)
"""
for name, value in metrics.items():
if prefix:
Expand All @@ -120,20 +129,47 @@ def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None):
f"Initialized WandbLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}"
)

def define_metric(
self,
name: str,
step_metric: Optional[str] = None,
) -> None:
"""Define a metric with custom step metric.

Args:
name: Name of the metric or pattern (e.g. 'ray/*')
step_metric: Optional name of the step metric to use
"""
self.run.define_metric(name, step_metric=step_metric)

def log_metrics(
self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = ""
self,
metrics: Dict[str, Any],
step: int,
prefix: Optional[str] = "",
step_metric: Optional[str] = None,
) -> None:
"""Log metrics to wandb.

Args:
metrics: Dict of metrics to log
step: Global step value
prefix: Optional prefix for metric names
step_metric: Optional name of a field in metrics to use as step instead
of the provided step value
"""
if prefix:
metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}

self.run.log(metrics, step=step)
metrics = {
f"{prefix}/{k}" if k != step_metric else k: v
for k, v in metrics.items()
}

# If step_metric is provided, use the corresponding value from metrics as step
if step_metric and step_metric in metrics:
# commit=False so the step does not get incremented
self.run.log(metrics, commit=False)
else:
self.run.log(metrics, step=step)

def log_hyperparams(self, params: Dict[str, Any]) -> None:
"""Log hyperparameters to wandb.
Expand All @@ -156,17 +192,22 @@ def __init__(
self,
collection_interval: int | float,
flush_interval: int | float,
metric_prefix: str,
step_metric: str,
parent_logger: Optional["Logger"] = None,
):
"""Initialize the GPU monitor.

Args:
collection_interval: Interval in seconds to collect GPU metrics
flush_interval: Interval in seconds to flush metrics to parent logger
metric_prefix: Prefix passed to the parent logger for the flushed metric names
step_metric: Name of the field to use as the step metric
parent_logger: Logger to receive the collected metrics
"""
self.collection_interval = collection_interval
self.flush_interval = flush_interval
self.metric_prefix = metric_prefix
self.step_metric = step_metric
self.parent_logger = parent_logger
self.metrics_buffer: list[
GpuMetricSnapshot
Expand Down Expand Up @@ -425,7 +466,17 @@ def flush(self):
for entry in self.metrics_buffer:
step = entry["step"]
metrics = entry["metrics"]
self.parent_logger.log_metrics(metrics, step, prefix="ray")

# Add the step metric directly to metrics for use as step_metric
metrics[self.step_metric] = step

# Pass self.step_metric so wandb uses that field's value as the step instead of `step`
self.parent_logger.log_metrics(
metrics,
step=step,
prefix=self.metric_prefix,
step_metric=self.step_metric,
)

# Clear buffer after logging
self.metrics_buffer = []
Expand All @@ -448,15 +499,16 @@ def __init__(self, cfg: LoggerConfig):
- gpu_flush_interval
"""
self.loggers = []
self.wandb_logger = None

self.base_log_dir = cfg["log_dir"]
os.makedirs(self.base_log_dir, exist_ok=True)

if cfg["wandb_enabled"]:
wandb_log_dir = os.path.join(self.base_log_dir, "wandb")
os.makedirs(wandb_log_dir, exist_ok=True)
wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir)
self.loggers.append(wandb_logger)
self.wandb_logger = WandbLogger(cfg["wandb"], log_dir=wandb_log_dir)
self.loggers.append(self.wandb_logger)

if cfg["tensorboard_enabled"]:
tensorboard_log_dir = os.path.join(self.base_log_dir, "tensorboard")
Expand All @@ -469,9 +521,18 @@ def __init__(self, cfg: LoggerConfig):
# Initialize GPU monitoring if requested
self.gpu_monitor = None
if cfg["monitor_gpus"]:
metric_prefix = "ray"
step_metric = f"{metric_prefix}/ray_step"
if cfg["wandb_enabled"] and self.wandb_logger:
self.wandb_logger.define_metric(
f"{metric_prefix}/*", step_metric=step_metric
)

self.gpu_monitor = RayGpuMonitorLogger(
collection_interval=cfg["gpu_monitoring"]["collection_interval"],
flush_interval=cfg["gpu_monitoring"]["flush_interval"],
metric_prefix=metric_prefix,
step_metric=step_metric,
parent_logger=self,
)
self.gpu_monitor.start()
Expand All @@ -480,17 +541,23 @@ def __init__(self, cfg: LoggerConfig):
print("No loggers initialized")

def log_metrics(
self, metrics: Dict[str, Any], step: int, prefix: Optional[str] = ""
self,
metrics: Dict[str, Any],
step: int,
prefix: Optional[str] = "",
step_metric: Optional[str] = None,
) -> None:
"""Log metrics to all enabled backends.

Args:
metrics: Dict of metrics to log
step: Global step value
prefix: Optional prefix for metric names
step_metric: Optional name of a field in metrics to use as step instead
of the provided step value (currently only needed for wandb)
"""
for logger in self.loggers:
logger.log_metrics(metrics, step, prefix)
logger.log_metrics(metrics, step, prefix, step_metric)

def log_hyperparams(self, params: Dict[str, Any]) -> None:
"""Log hyperparameters to all enabled backends.
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def session_data(request, init_ray_cluster):
logger = RayGpuMonitorLogger(
collection_interval=float("inf"),
flush_interval=float("inf"),
metric_prefix="test",
step_metric="test/step",
parent_logger=None,
)
unit_test_data["gpu_types"] = list(set(logger._collect_gpu_sku().values()))
Expand Down Expand Up @@ -209,6 +211,8 @@ def ray_gpu_monitor(init_ray_cluster):
gpu_monitor = RayGpuMonitorLogger(
collection_interval=1,
flush_interval=float("inf"), # Disabling flushing since we will do it manually
metric_prefix="test",
step_metric="test/step",
parent_logger=None,
)
gpu_monitor.start()
Expand Down
Loading
Loading