diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py index a26ab2a6d5..4bf7af4c76 100644 --- a/nemo_automodel/loggers/log_utils.py +++ b/nemo_automodel/loggers/log_utils.py @@ -17,9 +17,33 @@ from functools import partial from logging import Filter, LogRecord from typing import Callable, Optional, Union +import os logger = logging.getLogger(__name__) +class RankFilter(logging.Filter): + """ + A logging filter that controls log output based on the process rank. + + This filter allows log messages only for rank 0 by default. + """ + def filter(self, record): + """ + Decide whether to log the provided record. + + Args: + record (logging.LogRecord): The log record to be evaluated. + + Returns: + bool: True if the log record should be logged, False otherwise. + """ + # TODO(@akoumparouli): make this PP aware. + if 'LOCAL_RANK' in os.environ: + rank = int(os.environ.get('LOCAL_RANK')) + # permanently disable logging for rank != 0 + if rank > 0: + logging.disable(logging.CRITICAL) + return True def warning_filter(record: LogRecord) -> bool: """Logging filter to exclude WARNING level messages. @@ -30,10 +54,7 @@ def warning_filter(record: LogRecord) -> bool: Returns: False if the record level is WARNING, True otherwise. """ - if record.levelno == logging.WARNING: - return False - - return True + return record.levelno != logging.WARNING def module_filter(record: LogRecord, modules_to_filter: list[str]) -> bool: @@ -82,7 +103,7 @@ def setup_logging( from specific modules. Logging Level Precedence: - 1. Env var `NEMOLM_LOGGING_LEVEL` + 1. Env var `LOGGING_LEVEL` 2. `logging_level` argument 3. Default: `logging.INFO` @@ -94,7 +115,7 @@ def setup_logging( loggers. If False (default), only sets the level for the root logger and loggers starting with 'nemo'. 
""" - env_logging_level = os.getenv("NEMOLM_LOGGING_LEVEL", None) + env_logging_level = os.getenv("LOGGING_LEVEL", None) if env_logging_level is not None: logging_level = int(env_logging_level) @@ -108,5 +129,6 @@ def setup_logging( if filter_warning: add_filter_to_all_loggers(warning_filter) + logging.getLogger().addFilter(RankFilter()) if modules_to_filter: add_filter_to_all_loggers(partial(module_filter, modules_to_filter=modules_to_filter)) diff --git a/nemo_automodel/utils/dist_utils.py b/nemo_automodel/utils/dist_utils.py index 35d692acf5..ea031a30b6 100644 --- a/nemo_automodel/utils/dist_utils.py +++ b/nemo_automodel/utils/dist_utils.py @@ -24,9 +24,9 @@ from contextlib import ContextDecorator, nullcontext import torch.distributed as dist import yaml - from nemo_automodel.utils.yaml_utils import safe_yaml_representers +logger = logging.getLogger(__name__) class FirstRankPerNode(ContextDecorator): @@ -136,16 +136,6 @@ def get_local_rank_preinit() -> int: return int(os.getenv("LOCAL_RANK", "0")) -def print_rank_0(message: str) -> None: - """Print a message only on global rank 0. - - Args: - message: The message string to print. - """ - rank = get_rank_safe() - if rank == 0: - print(message, flush=True) - def is_last_rank() -> bool: """Check if the current rank is the last rank in the default process group. @@ -156,18 +146,6 @@ def is_last_rank() -> bool: return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) -def print_rank_last(message: str) -> None: - """Print a message only on the last rank of the default process group. - - Args: - message: The message string to print. - """ - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True) - def append_to_progress_log(save_dir: str, string: str, barrier: bool = True) -> None: """Append a formatted string to the progress log file (rank 0 only). 
@@ -203,7 +181,7 @@ def barrier_and_log(string: str) -> None: if torch.distributed.is_initialized(): torch.distributed.barrier() time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - print_rank_0(f"[{string}] datetime: {time_str} ") + logger.info("[{}] datetime: {} ".format(string, time_str)) def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any): @@ -333,4 +311,4 @@ def clip_gradients(model, clip_norm, foreach=True): if isinstance(grad_norm, torch.distributed.tensor.DTensor): grad_norm = grad_norm.full_tensor() torch.nn.utils.clip_grads_with_norm_([p for p in model.parameters()], clip_norm, grad_norm, foreach=foreach) - return grad_norm \ No newline at end of file + return grad_norm diff --git a/recipes/llm/finetune.py b/recipes/llm/finetune.py index 2229c302f2..fb45733455 100644 --- a/recipes/llm/finetune.py +++ b/recipes/llm/finetune.py @@ -21,6 +21,11 @@ from nemo_automodel.training.step_scheduler import StepScheduler from nemo_automodel.utils.dist_utils import reduce_loss, get_sync_ctx, rescale_gradients, clip_gradients from transformers import AutoTokenizer +from nemo_automodel.loggers.log_utils import setup_logging + +import logging +logger = logging.getLogger(__name__) + # --------------------------- # Stateless helper functions # --------------------------- @@ -182,7 +187,10 @@ def setup(self): Raises: NotImplemented: Raises if it tries to restore a checkpoint; will be removed. 
""" + torch.cuda.reset_peak_memory_stats() self.dist_env = build_distributed(self.cfg.get("dist_env", {})) + # setups logging and adds the rankfilter to logging + setup_logging() self.device_mesh = None self.model_wrapper = None @@ -213,7 +221,7 @@ def setup(self): config=self.cfg, settings=Settings(silent=True), ) - print("🚀 View run at {}".format(run.url)) + logging.info("🚀 View run at {}".format(run.url)) # Build components self.model = build_model(self.dist_env.device, self.cfg.model, self.cfg.get('peft', None), self.model_wrapper) @@ -395,13 +403,12 @@ def _run_train_step(self, batch, is_optim_step, clip_norm=1.0): # log reporting_loss = self.log_train_metrics(grad_norm) - if self.dist_env.is_main: - print( - f"step {self.step_scheduler.step} | " - f"epoch {self.step_scheduler.epoch} | " - f"loss {reporting_loss:.6f} | " - f"grad_norm {grad_norm:.6f}" + logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f} | mem: {:.2f} GiB".format( + self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm, + torch.cuda.max_memory_allocated() / 1024 ** 3 ) + ) + torch.cuda.reset_peak_memory_stats() @torch.no_grad() @@ -454,11 +461,10 @@ def _run_validation_epoch(self) -> float: "epoch": self.step_scheduler.epoch } ) - print( - f"[val] step {self.step_scheduler.step} | " - f"epoch {self.step_scheduler.epoch} | " - f"loss {val_loss:.4f}", + logging.info("[val] step {} | epoch {} | loss {:.6f}".format( + self.step_scheduler.step, self.step_scheduler.epoch, val_loss ) + ) def log_train_metrics(self, grad_norm): """