From db79dab2e03596095c4b103308a3f64b1c77c69b Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 28 May 2025 12:35:41 -0700
Subject: [PATCH 01/10] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo_automodel/loggers/log_utils.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py
index a26ab2a6d5..7a054dee96 100644
--- a/nemo_automodel/loggers/log_utils.py
+++ b/nemo_automodel/loggers/log_utils.py
@@ -30,10 +30,7 @@ def warning_filter(record: LogRecord) -> bool:
     Returns:
         False if the record level is WARNING, True otherwise.
     """
-    if record.levelno == logging.WARNING:
-        return False
-
-    return True
+    return record.levelno != logging.WARNING
 
 
 def module_filter(record: LogRecord, modules_to_filter: list[str]) -> bool:

From 2530144065a888c7f4095061ad7ec3191ce32ac4 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 28 May 2025 12:37:48 -0700
Subject: [PATCH 02/10] add RankFilter

Signed-off-by: Alexandros Koumparoulis
---
 nemo_automodel/loggers/log_utils.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py
index 7a054dee96..91d60dfdf1 100644
--- a/nemo_automodel/loggers/log_utils.py
+++ b/nemo_automodel/loggers/log_utils.py
@@ -20,6 +20,17 @@
 
 logger = logging.getLogger(__name__)
 
+class RankFilter(logging.Filter):
+    def __init__(self, rank):
+        super().__init__()
+        self.rank = rank
+
+    def filter(self, record):
+        # If the log record explicitly requests to bypass the rank check, allow it.
+        if getattr(record, 'all_ranks', False):
+            return True
+        # Otherwise, only allow logs from rank 0.
+        return self.rank == 0
 
 def warning_filter(record: LogRecord) -> bool:
     """Logging filter to exclude WARNING level messages.

From 5a1cc9954e1e54ce7a9388e84eda265bcc717ada Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Wed, 28 May 2025 12:43:24 -0700
Subject: [PATCH 03/10] add docstring

Signed-off-by: Alexandros Koumparoulis
---
 nemo_automodel/loggers/log_utils.py | 31 +++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py
index 91d60dfdf1..e6648831d7 100644
--- a/nemo_automodel/loggers/log_utils.py
+++ b/nemo_automodel/loggers/log_utils.py
@@ -21,11 +21,42 @@
 logger = logging.getLogger(__name__)
 
 class RankFilter(logging.Filter):
+    """
+    A logging filter that controls log output based on the process rank.
+
+    This filter allows log messages only for rank 0 by default.
+    If a log record has an attribute 'all_ranks' set to True,
+    the log message will always be output regardless of the process rank.
+    """
     def __init__(self, rank):
+        """
+        Decide whether to log the provided record.
+
+        If the log record has an attribute 'bypass_rank_filter' set to True,
+        the record is allowed. Otherwise, only messages from rank 0 are allowed.
+
+        Args:
+            record (logging.LogRecord): The log record to be evaluated.
+
+        Returns:
+            bool: True if the log record should be logged, False otherwise.
+        """
         super().__init__()
         self.rank = rank
 
     def filter(self, record):
+        """
+        Decide whether to log the provided record.
+
+        If the log record has an attribute 'bypass_rank_filter' set to True,
+        the record is allowed. Otherwise, only messages from rank 0 are allowed.
+
+        Args:
+            record (logging.LogRecord): The log record to be evaluated.
+
+        Returns:
+            bool: True if the log record should be logged, False otherwise.
+        """
         # If the log record explicitly requests to bypass the rank check, allow it.
         if getattr(record, 'all_ranks', False):
             return True
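A usage sketch for the filter as it stands after PATCH 03 (not part of the series; the RANK environment variable and the basicConfig call are assumptions for illustration). The `extra` dict turns into attributes on the LogRecord, which is what the `all_ranks` escape hatch keys off:

    import logging
    import os

    from nemo_automodel.loggers.log_utils import RankFilter

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("nemo")
    logger.addFilter(RankFilter(rank=int(os.environ.get("RANK", "0"))))

    logger.info("visible on rank 0 only")
    logger.info("visible on every rank", extra={"all_ranks": True})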
+ """ # If the log record explicitly requests to bypass the rank check, allow it. if getattr(record, 'all_ranks', False): return True From 346170269f9738115b60e1d44368a96a0101aa99 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 28 May 2025 13:01:35 -0700 Subject: [PATCH 04/10] rename env var Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/loggers/log_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py index e6648831d7..9c1f24baa1 100644 --- a/nemo_automodel/loggers/log_utils.py +++ b/nemo_automodel/loggers/log_utils.py @@ -121,7 +121,7 @@ def setup_logging( from specific modules. Logging Level Precedence: - 1. Env var `NEMOLM_LOGGING_LEVEL` + 1. Env var `LOGGING_LEVEL` 2. `logging_level` argument 3. Default: `logging.INFO` @@ -133,7 +133,7 @@ def setup_logging( loggers. If False (default), only sets the level for the root logger and loggers starting with 'nemo'. """ - env_logging_level = os.getenv("NEMOLM_LOGGING_LEVEL", None) + env_logging_level = os.getenv("LOGGING_LEVEL", None) if env_logging_level is not None: logging_level = int(env_logging_level) From 8556080f300da2fe8938b71c3cbc7c4f4a4773b6 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 6 Jun 2025 14:54:59 -0700 Subject: [PATCH 05/10] fix Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/loggers/log_utils.py | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py index 9c1f24baa1..6a451ded40 100644 --- a/nemo_automodel/loggers/log_utils.py +++ b/nemo_automodel/loggers/log_utils.py @@ -17,6 +17,7 @@ from functools import partial from logging import Filter, LogRecord from typing import Callable, Optional, Union +import os logger = logging.getLogger(__name__) @@ -28,27 +29,11 @@ class RankFilter(logging.Filter): If a log record has an attribute 'all_ranks' set to True, the log message will always be output regardless of the process rank. """ - def __init__(self, rank): - """ - Decide whether to log the provided record. - - If the log record has an attribute 'bypass_rank_filter' set to True, - the record is allowed. Otherwise, only messages from rank 0 are allowed. - - Args: - record (logging.LogRecord): The log record to be evaluated. - - Returns: - bool: True if the log record should be logged, False otherwise. - """ - super().__init__() - self.rank = rank - def filter(self, record): """ Decide whether to log the provided record. - If the log record has an attribute 'bypass_rank_filter' set to True, + If the log record has an attribute 'all_ranks' set to True, the record is allowed. Otherwise, only messages from rank 0 are allowed. Args: @@ -60,8 +45,8 @@ def filter(self, record): # If the log record explicitly requests to bypass the rank check, allow it. if getattr(record, 'all_ranks', False): return True - # Otherwise, only allow logs from rank 0. - return self.rank == 0 + # TODO(@akoumparouli): make this PP aware. + return int(os.environ.get('LOCAL_RANK', '0')) != 0 def warning_filter(record: LogRecord) -> bool: """Logging filter to exclude WARNING level messages. 
From 0bb4209d4be0c29fa0a58a778cfd7784ab17ef5d Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Fri, 6 Jun 2025 14:55:14 -0700
Subject: [PATCH 06/10] switch from print to logging

Signed-off-by: Alexandros Koumparoulis
---
 recipes/llm/finetune.py | 25 ++++++++++++++-----------
 1 file changed, 14 insertions(+), 11 deletions(-)

diff --git a/recipes/llm/finetune.py b/recipes/llm/finetune.py
index 2229c302f2..c2a954cac4 100644
--- a/recipes/llm/finetune.py
+++ b/recipes/llm/finetune.py
@@ -21,6 +21,11 @@
 from nemo_automodel.training.step_scheduler import StepScheduler
 from nemo_automodel.utils.dist_utils import reduce_loss, get_sync_ctx, rescale_gradients, clip_gradients
 from transformers import AutoTokenizer
+from nemo_automodel.loggers.log_utils import setup_logging
+
+import logging
+logger = logging.getLogger(__name__)
+
 # ---------------------------
 # Stateless helper functions
 # ---------------------------
@@ -183,6 +188,8 @@ def setup(self):
             NotImplemented: Raises if it tries to restore a checkpoint; will be removed.
         """
         self.dist_env = build_distributed(self.cfg.get("dist_env", {}))
+        # sets up logging and adds the RankFilter to logging
+        setup_logging()
 
         self.device_mesh = None
         self.model_wrapper = None
@@ -213,7 +220,7 @@ def setup(self):
             config=self.cfg,
             settings=Settings(silent=True),
         )
-        print("🚀 View run at {}".format(run.url))
+        logging.info("🚀 View run at {}".format(run.url))
 
         # Build components
         self.model = build_model(self.dist_env.device, self.cfg.model, self.cfg.get('peft', None), self.model_wrapper)
@@ -395,13 +402,10 @@ def _run_train_step(self, batch, is_optim_step, clip_norm=1.0):
 
         # log
         reporting_loss = self.log_train_metrics(grad_norm)
-        if self.dist_env.is_main:
-            print(
-                f"step {self.step_scheduler.step} | "
-                f"epoch {self.step_scheduler.epoch} | "
-                f"loss {reporting_loss:.6f} | "
-                f"grad_norm {grad_norm:.6f}"
+        logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f}".format(
+            self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm
             )
+        )
 
 
     @torch.no_grad()
@@ -454,11 +458,10 @@ def _run_validation_epoch(self) -> float:
                 "epoch": self.step_scheduler.epoch
             }
         )
-        print(
-            f"[val] step {self.step_scheduler.step} | "
-            f"epoch {self.step_scheduler.epoch} | "
-            f"loss {val_loss:.4f}",
+        logging.info("[val] step {} | epoch {} | loss {:.6f}".format(
+            self.step_scheduler.step, self.step_scheduler.epoch, val_loss
         )
+        )
 
     def log_train_metrics(self, grad_norm):
         """

From b8f39cd66660750f7550793ab365f9d1f8d05500 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Fri, 6 Jun 2025 14:57:24 -0700
Subject: [PATCH 07/10] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo_automodel/utils/dist_utils.py | 28 +++------------------------
 1 file changed, 3 insertions(+), 25 deletions(-)

diff --git a/nemo_automodel/utils/dist_utils.py b/nemo_automodel/utils/dist_utils.py
index 35d692acf5..b85dc13e38 100644
--- a/nemo_automodel/utils/dist_utils.py
+++ b/nemo_automodel/utils/dist_utils.py
@@ -24,9 +24,9 @@
 from contextlib import ContextDecorator, nullcontext
 import torch.distributed as dist
 import yaml
-
 from nemo_automodel.utils.yaml_utils import safe_yaml_representers
 
+logger = logging.getLogger(__name__)
 
 class FirstRankPerNode(ContextDecorator):
@@ -136,16 +136,6 @@ def get_local_rank_preinit() -> int:
     return int(os.getenv("LOCAL_RANK", "0"))
 
 
-def print_rank_0(message: str) -> None:
-    """Print a message only on global rank 0.
-
-    Args:
-        message: The message string to print.
-    """
-    rank = get_rank_safe()
-    if rank == 0:
-        print(message, flush=True)
-
 
 def is_last_rank() -> bool:
     """Check if the current rank is the last rank in the default process group.
@@ -150,18 +140,6 @@ def is_last_rank() -> bool:
     """
     return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)
 
 
-def print_rank_last(message: str) -> None:
-    """Print a message only on the last rank of the default process group.
-
-    Args:
-        message: The message string to print.
-    """
-    if torch.distributed.is_initialized():
-        if is_last_rank():
-            print(message, flush=True)
-    else:
-        print(message, flush=True)
-
 
 def append_to_progress_log(save_dir: str, string: str, barrier: bool = True) -> None:
     """Append a formatted string to the progress log file (rank 0 only).
@@ -203,7 +181,7 @@ def barrier_and_log(string: str) -> None:
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
     time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    print_rank_0(f"[{string}] datetime: {time_str} ")
+    logger.info(f"[{}] datetime: {} ".format(string, time_str))
 
 
 def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
@@ -333,4 +311,4 @@ def clip_gradients(model, clip_norm, foreach=True):
     if isinstance(grad_norm, torch.distributed.tensor.DTensor):
         grad_norm = grad_norm.full_tensor()
     torch.nn.utils.clip_grads_with_norm_([p for p in model.parameters()], clip_norm, grad_norm, foreach=foreach)
-    return grad_norm
\ No newline at end of file
+    return grad_norm
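One stdlib subtlety behind these two patches (an observation, not something the diffs state): a Logger's filters run only for records logged directly on that logger; records from child loggers propagate to the root logger's handlers without re-running the root logger's own filters. That is consistent with the recipe calling the module-level `logging.info(...)` (which logs on the root logger, where PATCH 05 installed `RankFilter`) instead of its own `logger`:

    import logging

    # Logged on the root logger: passes through the root logger's filters,
    # including a RankFilter installed there.
    logging.info("rank-filtered")

    # Logged on a child logger: propagates to the root logger's *handlers*
    # without re-running the root *logger's* filters.
    logging.getLogger("recipes.llm.finetune").info("not rank-filtered")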
From 5a59195609c2e65eeec71afd571ec046ddded3e9 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Fri, 6 Jun 2025 15:01:09 -0700
Subject: [PATCH 08/10] fix

Signed-off-by: Alexandros Koumparoulis
---
 nemo_automodel/utils/dist_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo_automodel/utils/dist_utils.py b/nemo_automodel/utils/dist_utils.py
index b85dc13e38..ea031a30b6 100644
--- a/nemo_automodel/utils/dist_utils.py
+++ b/nemo_automodel/utils/dist_utils.py
@@ -181,7 +181,7 @@ def barrier_and_log(string: str) -> None:
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
     time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    logger.info(f"[{}] datetime: {} ".format(string, time_str))
+    logger.info("[{}] datetime: {} ".format(string, time_str))

From 6295aecb47d635d9aa450a5abe2bd9bef4cb551 Mon Sep 17 00:00:00 2001
From: Alexandros Koumparoulis
Date: Fri, 6 Jun 2025 15:14:23 -0700
Subject: [PATCH 09/10] add peak memory

Signed-off-by: Alexandros Koumparoulis
---
 recipes/llm/finetune.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/recipes/llm/finetune.py b/recipes/llm/finetune.py
index c2a954cac4..fb45733455 100644
--- a/recipes/llm/finetune.py
+++ b/recipes/llm/finetune.py
@@ -187,6 +187,7 @@ def setup(self):
         Raises:
             NotImplemented: Raises if it tries to restore a checkpoint; will be removed.
         """
+        torch.cuda.reset_peak_memory_stats()
         self.dist_env = build_distributed(self.cfg.get("dist_env", {}))
         # sets up logging and adds the RankFilter to logging
         setup_logging()
@@ -402,10 +403,12 @@ def _run_train_step(self, batch, is_optim_step, clip_norm=1.0):
 
         # log
         reporting_loss = self.log_train_metrics(grad_norm)
-        logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f}".format(
-            self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm
+        logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f} | mem: {:.2f} GiB".format(
+            self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm,
+            torch.cuda.max_memory_allocated() / 1024 ** 3
             )
         )
+        torch.cuda.reset_peak_memory_stats()
 
 
     @torch.no_grad()
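The memory figure added above comes from CUDA's high-water-mark counter; reading it and then resetting makes each training step report its own peak rather than the run-wide maximum. A standalone sketch of that pattern (device 0 and GiB units assumed, as in the patch):

    import logging
    import torch

    if torch.cuda.is_available():
        # Peak bytes allocated since the last reset, converted to GiB.
        peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
        logging.info("peak allocated since last reset: %.2f GiB", peak_gib)
        # Start a fresh measurement window for the next step.
        torch.cuda.reset_peak_memory_stats()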
""" + torch.cuda.reset_peak_memory_stats() self.dist_env = build_distributed(self.cfg.get("dist_env", {})) # setups logging and adds the rankfilter to logging setup_logging() @@ -402,10 +403,12 @@ def _run_train_step(self, batch, is_optim_step, clip_norm=1.0): # log reporting_loss = self.log_train_metrics(grad_norm) - logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f}".format( - self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm + logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f} | mem: {:.2f} GiB".format( + self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm, + torch.cuda.max_memory_allocated() / 1024 ** 3 ) ) + torch.cuda.reset_peak_memory_stats() @torch.no_grad() From 118a9def7fb36fa22f4dc6935d8a68c6ddff9a5e Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Sat, 7 Jun 2025 11:26:01 -0700 Subject: [PATCH 10/10] disable logging in ranks > 0 Signed-off-by: Alexandros Koumparoulis --- nemo_automodel/loggers/log_utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/nemo_automodel/loggers/log_utils.py b/nemo_automodel/loggers/log_utils.py index 6a451ded40..4bf7af4c76 100644 --- a/nemo_automodel/loggers/log_utils.py +++ b/nemo_automodel/loggers/log_utils.py @@ -26,27 +26,24 @@ class RankFilter(logging.Filter): A logging filter that controls log output based on the process rank. This filter allows log messages only for rank 0 by default. - If a log record has an attribute 'all_ranks' set to True, - the log message will always be output regardless of the process rank. """ def filter(self, record): """ Decide whether to log the provided record. - If the log record has an attribute 'all_ranks' set to True, - the record is allowed. Otherwise, only messages from rank 0 are allowed. - Args: record (logging.LogRecord): The log record to be evaluated. Returns: bool: True if the log record should be logged, False otherwise. """ - # If the log record explicitly requests to bypass the rank check, allow it. - if getattr(record, 'all_ranks', False): - return True # TODO(@akoumparouli): make this PP aware. - return int(os.environ.get('LOCAL_RANK', '0')) != 0 + if 'LOCAL_RANK' in os.environ: + rank = int(os.environ.get('LOCAL_RANK')) + # permantly disable logging for rank != 0 + if rank > 0: + logging.disable(logging.CRITICAL) + return True def warning_filter(record: LogRecord) -> bool: """Logging filter to exclude WARNING level messages.