Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions nemo_automodel/loggers/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,33 @@
from functools import partial
from logging import Filter, LogRecord
from typing import Callable, Optional, Union
import os

logger = logging.getLogger(__name__)

class RankFilter(logging.Filter):
    """Logging filter that silences every process except local rank 0.

    When the ``LOCAL_RANK`` environment variable is present and greater
    than zero, the whole logging machinery is switched off via
    ``logging.disable``; rank 0 (or a non-distributed run) logs normally.
    """

    def filter(self, record):
        """Evaluate *record* and apply the rank-based muting policy.

        Args:
            record (logging.LogRecord): The log record under consideration.

        Returns:
            bool: Always True; muting happens globally through
            ``logging.disable`` rather than per-record rejection.
        """
        # TODO(@akoumparouli): make this PP aware.
        local_rank = os.environ.get('LOCAL_RANK')
        if local_rank is not None:
            # Permanently disable logging for every non-zero local rank.
            if int(local_rank) > 0:
                logging.disable(logging.CRITICAL)
        return True

def warning_filter(record: LogRecord) -> bool:
    """Logging filter to exclude WARNING level messages.

    Args:
        record: The log record to check.

    Returns:
        False if the record level is WARNING, True otherwise.
    """
    # Single comparison replaces the old if/return pair; `!=` already
    # yields the bool the filter protocol expects.
    return record.levelno != logging.WARNING


def module_filter(record: LogRecord, modules_to_filter: list[str]) -> bool:
Expand Down Expand Up @@ -82,7 +103,7 @@ def setup_logging(
from specific modules.

Logging Level Precedence:
1. Env var `NEMOLM_LOGGING_LEVEL`
1. Env var `LOGGING_LEVEL`
2. `logging_level` argument
3. Default: `logging.INFO`

Expand All @@ -94,7 +115,7 @@ def setup_logging(
loggers. If False (default), only sets the level
for the root logger and loggers starting with 'nemo'.
"""
env_logging_level = os.getenv("NEMOLM_LOGGING_LEVEL", None)
env_logging_level = os.getenv("LOGGING_LEVEL", None)
if env_logging_level is not None:
logging_level = int(env_logging_level)

Expand All @@ -108,5 +129,6 @@ def setup_logging(

if filter_warning:
add_filter_to_all_loggers(warning_filter)
logging.getLogger().addFilter(RankFilter())
if modules_to_filter:
add_filter_to_all_loggers(partial(module_filter, modules_to_filter=modules_to_filter))
28 changes: 3 additions & 25 deletions nemo_automodel/utils/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
from contextlib import ContextDecorator, nullcontext
import torch.distributed as dist
import yaml

from nemo_automodel.utils.yaml_utils import safe_yaml_representers

logger = logging.getLogger(__name__)


class FirstRankPerNode(ContextDecorator):
Expand Down Expand Up @@ -136,16 +136,6 @@ def get_local_rank_preinit() -> int:
return int(os.getenv("LOCAL_RANK", "0"))


def print_rank_0(message: str) -> None:
    """Print *message* on global rank 0 and stay silent on every other rank.

    Args:
        message: Text to emit to stdout (flushed immediately).
    """
    if get_rank_safe() == 0:
        print(message, flush=True)


def is_last_rank() -> bool:
    """Check if the current rank is the last rank in the default process group.

    Returns:
        True when this process's global rank equals ``world_size - 1``.
        NOTE(review): assumes ``torch.distributed`` is already initialized;
        calling it earlier raises — confirm against callers.
    """
    return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1)


def print_rank_last(message: str) -> None:
    """Print *message* only on the last rank of the default process group.

    Falls back to printing unconditionally when ``torch.distributed`` has
    not been initialized (e.g. single-process runs).

    Args:
        message: Text to emit to stdout (flushed immediately).
    """
    should_print = (not torch.distributed.is_initialized()) or is_last_rank()
    if should_print:
        print(message, flush=True)


def append_to_progress_log(save_dir: str, string: str, barrier: bool = True) -> None:
"""Append a formatted string to the progress log file (rank 0 only).
Expand Down Expand Up @@ -203,7 +181,7 @@ def barrier_and_log(string: str) -> None:
if torch.distributed.is_initialized():
torch.distributed.barrier()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print_rank_0(f"[{string}] datetime: {time_str} ")
logger.info("[{}] datetime: {} ".format(string, time_str))


def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
Expand Down Expand Up @@ -333,4 +311,4 @@ def clip_gradients(model, clip_norm, foreach=True):
if isinstance(grad_norm, torch.distributed.tensor.DTensor):
grad_norm = grad_norm.full_tensor()
torch.nn.utils.clip_grads_with_norm_([p for p in model.parameters()], clip_norm, grad_norm, foreach=foreach)
return grad_norm
return grad_norm
28 changes: 17 additions & 11 deletions recipes/llm/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from nemo_automodel.training.step_scheduler import StepScheduler
from nemo_automodel.utils.dist_utils import reduce_loss, get_sync_ctx, rescale_gradients, clip_gradients
from transformers import AutoTokenizer
from nemo_automodel.loggers.log_utils import setup_logging

import logging
logger = logging.getLogger(__name__)

# ---------------------------
# Stateless helper functions
# ---------------------------
Expand Down Expand Up @@ -182,7 +187,10 @@ def setup(self):
Raises:
NotImplemented: Raises if it tries to restore a checkpoint; will be removed.
"""
torch.cuda.reset_peak_memory_stats()
self.dist_env = build_distributed(self.cfg.get("dist_env", {}))
        # Sets up logging and adds the RankFilter to the root logger.
setup_logging()

self.device_mesh = None
self.model_wrapper = None
Expand Down Expand Up @@ -213,7 +221,7 @@ def setup(self):
config=self.cfg,
settings=Settings(silent=True),
)
print("🚀 View run at {}".format(run.url))
logging.info("🚀 View run at {}".format(run.url))

# Build components
self.model = build_model(self.dist_env.device, self.cfg.model, self.cfg.get('peft', None), self.model_wrapper)
Expand Down Expand Up @@ -395,13 +403,12 @@ def _run_train_step(self, batch, is_optim_step, clip_norm=1.0):

# log
reporting_loss = self.log_train_metrics(grad_norm)
if self.dist_env.is_main:
print(
f"step {self.step_scheduler.step} | "
f"epoch {self.step_scheduler.epoch} | "
f"loss {reporting_loss:.6f} | "
f"grad_norm {grad_norm:.6f}"
logging.info("step {} | epoch {} | loss {:.6f} | grad_norm {:.6f} | mem: {:.2f} GiB".format(
self.step_scheduler.step, self.step_scheduler.epoch, reporting_loss, grad_norm,
torch.cuda.max_memory_allocated() / 1024 ** 3
)
)
torch.cuda.reset_peak_memory_stats()


@torch.no_grad()
Expand Down Expand Up @@ -454,11 +461,10 @@ def _run_validation_epoch(self) -> float:
"epoch": self.step_scheduler.epoch
}
)
print(
f"[val] step {self.step_scheduler.step} | "
f"epoch {self.step_scheduler.epoch} | "
f"loss {val_loss:.4f}",
logging.info("[val] step {} | epoch {} | loss {:.6f}".format(
self.step_scheduler.step, self.step_scheduler.epoch, val_loss
)
)

def log_train_metrics(self, grad_norm):
"""
Expand Down