# nemo_rl/utils/logger.py (excerpt: imports and the WandbLogger git-snapshot logic)
import logging
import os
import re
import subprocess
import threading
import time
from abc import ABC, abstractmethod

# NOTE(review): `wandb`, `WandbConfig`, `Optional` and `LoggerInterface` are
# used below but are imported/defined in parts of the module that are outside
# this excerpt.


class WandbLogger(LoggerInterface):
    """Logger that initializes a wandb run and snapshots the git state into it.

    On construction the logger uploads two artifacts to the run: the
    git-tracked source files and the current git diffs. Both uploads are
    best-effort: failures are printed and never propagate to the caller.
    """

    def __init__(self, cfg: WandbConfig, log_dir: Optional[str] = None):
        """Start the wandb run and log the code/diff artifacts.

        Args:
            cfg: Keyword arguments forwarded verbatim to ``wandb.init``.
            log_dir: Directory passed to ``wandb.init`` as ``dir``.
        """
        self.run = wandb.init(**cfg, dir=log_dir)
        self._log_code()
        self._log_diffs()
        print(
            f"Initialized WandbLogger for project {cfg.get('project')}, run {cfg.get('name')} at {log_dir}"
        )

    @staticmethod
    def _run_git(*args: str) -> str:
        """Run ``git <args>`` in the current working directory and return stdout.

        Output is decoded as UTF-8 with undecodable bytes replaced, so a diff
        that touches a non-UTF-8 file cannot abort logging with a decode error.

        Raises:
            subprocess.CalledProcessError: If git exits with a non-zero status.
        """
        completed = subprocess.run(
            ["git", *args],
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            check=True,
        )
        return completed.stdout

    @staticmethod
    def _add_diff_file(artifact, out_dir: str, file_name: str, diff_text: str) -> None:
        """Write ``diff_text`` to ``out_dir/file_name`` and attach it to ``artifact``."""
        diff_path = os.path.join(out_dir, file_name)
        with open(diff_path, "w", encoding="utf-8") as f:
            f.write(diff_text)
        artifact.add_file(diff_path, name=file_name)

    def _log_diffs(self) -> None:
        """Log git diffs to wandb.

        This function captures and logs two types of diffs:
        1. Uncommitted changes (working tree diff against HEAD)
        2. All changes (including uncommitted) against the main branch

        Each non-empty diff is saved as a text file in a single wandb
        artifact. The artifact is only uploaded when at least one diff file
        was produced. Errors are printed, never raised.
        """
        try:
            current_branch = self._run_git(
                "rev-parse", "--abbrev-ref", "HEAD"
            ).strip()

            diff_artifact = wandb.Artifact(
                name=f"git-diffs-{self.run.project}-{self.run.id}", type="git-diffs"
            )
            # Diff files are written next to the other run files when a run is
            # active, otherwise into the current directory.
            out_dir = wandb.run.dir if wandb.run else "."
            files_added = 0

            # 1. Log uncommitted changes (working tree diff)
            uncommitted_diff = self._run_git("diff", "HEAD")
            if uncommitted_diff:
                self._add_diff_file(
                    diff_artifact,
                    out_dir,
                    "uncommitted_changes_diff.txt",
                    uncommitted_diff,
                )
                files_added += 1
                print("Logged uncommitted changes diff to wandb")
            else:
                print("No uncommitted changes found")

            # 2. Log diff against main branch (if current branch is not main)
            if current_branch != "main":
                try:
                    # Diff between main and working tree (includes uncommitted changes)
                    working_diff = self._run_git("diff", "main")
                except subprocess.CalledProcessError as e:
                    # A missing local `main` ref must not discard the
                    # uncommitted-changes diff collected above.
                    print(f"Could not diff against main branch: {e}")
                else:
                    if working_diff:
                        self._add_diff_file(
                            diff_artifact, out_dir, "main_diff.txt", working_diff
                        )
                        files_added += 1
                        print("Logged diff against main branch")
                    else:
                        print("No differences found between main and working tree")

            if files_added:
                self.run.log_artifact(diff_artifact)

        except subprocess.CalledProcessError as e:
            print(f"Error during git operations: {e}")
        except Exception as e:
            # Best-effort boundary: diff logging must never break run setup.
            print(f"Unexpected error during git diff logging: {e}")

    def _log_code(self) -> None:
        """Log code that is tracked by git to wandb.

        Lists the files reported by ``git ls-files`` (run in the current
        working directory) and uploads those that exist on disk to the
        current wandb run as a ``code`` artifact. Errors are printed, never
        raised.
        """
        try:
            # `-z` gives NUL-separated, unquoted paths, so file names with
            # spaces or non-ASCII characters survive intact. Dropping empty
            # entries makes the "nothing tracked" check below effective.
            listing = self._run_git("ls-files", "-z")
            tracked_files = [path for path in listing.split("\0") if path]

            if not tracked_files:
                print(
                    "Warning: No git repository found. Wandb logs will not track code changes for reproducibility."
                )
                return

            code_artifact = wandb.Artifact(
                name=f"source-code-{self.run.project}", type="code"
            )

            added = 0
            for file_path in tracked_files:
                # Tracked entries that are not regular files on disk are skipped.
                if not os.path.isfile(file_path):
                    continue
                try:
                    code_artifact.add_file(file_path, name=file_path)
                except Exception as e:
                    print(f"Error adding file {file_path}: {e}")
                else:
                    added += 1

            self.run.log_artifact(code_artifact)
            # Report what was actually attached, not what git listed.
            print(f"Logged {added} git-tracked files to wandb")

        except subprocess.CalledProcessError as e:
            print(f"Error getting git-tracked files: {e}")
        except Exception as e:
            # Best-effort boundary: code logging must never break run setup.
            print(f"Unexpected error during git code logging: {e}")

    # NOTE: `define_metric(self, name: str, ...)` follows here in the original
    # file; its definition is truncated in this excerpt and is left untouched.