Merged
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -4,7 +4,7 @@
 
 import datasets
 from datasets import load_dataset
-from nemo_lm.automodel.datasets.utils import SFTSingleTurnPreprocessor
+from automodel.datasets.utils import SFTSingleTurnPreprocessor
 
 class HellaSwag:
     def __init__(self, path_or_dataset, tokenizer, split):
@@ -9,7 +9,7 @@
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 
-from nemo_lm.automodel.utils.common_utils import log_single_rank
+from automodel.utils.common_utils import log_single_rank
 
 logger = logging.getLogger(__name__)
 
File renamed without changes.
@@ -15,7 +15,7 @@
 from pathlib import Path
 from typing import Any, Optional
 
-from nemo_lm.automodel.utils.common_utils import print_rank_last
+from automodel.utils.common_utils import print_rank_last
 
 
 def on_save_checkpoint_success(
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -62,7 +62,7 @@
 
 import torch
 
-from nemo_lm.automodel.utils.import_utils import safe_import_from
+from automodel.utils.import_utils import safe_import_from
 
 linear_cross_entropy, HAVE_LINEAR_LOSS_CE = safe_import_from(
     "cut_cross_entropy",
File renamed without changes.
@@ -20,11 +20,11 @@
 import torch.distributed as dist
 from transformers import AutoModelForCausalLM, BitsAndBytesConfig
 
-from nemo_lm.automodel.utils.dist_utils import FirstRankPerNode
-from nemo_lm.automodel.loss import masked_cross_entropy
-from nemo_lm.automodel.loss.linear_ce import HAVE_LINEAR_LOSS_CE, fused_linear_cross_entropy
+from automodel.utils.dist_utils import FirstRankPerNode
+from automodel.loss import masked_cross_entropy
+from automodel.loss.linear_ce import HAVE_LINEAR_LOSS_CE, fused_linear_cross_entropy
 # from nemo.utils import logging
-from nemo_lm.automodel.utils.import_utils import safe_import
+from automodel.utils.import_utils import safe_import
 
 
 @torch.no_grad()
@@ -8,7 +8,7 @@
 
 from torch.optim.optimizer import Optimizer
 
-from nemo_lm.automodel.utils.common_utils import log_single_rank
+from automodel.utils.common_utils import log_single_rank
 
 logger = logging.getLogger(__name__)
 
File renamed without changes.
@@ -24,20 +24,20 @@
 import torch
 from torch.nn import Module
 
-# from nemo_lm.automodel.components.state import GlobalState, TrainState
-# from nemo_lm.automodel.config import ConfigContainer
-# from nemo_lm.automodel.utils.model_utils import unwrap_model
-# from nemo_lm.automodel.utils import wandb_utils
-from nemo_lm.automodel.training.checkpoint_utils import (
+# from automodel.components.state import GlobalState, TrainState
+# from automodel.config import ConfigContainer
+# from automodel.utils.model_utils import unwrap_model
+# from automodel.utils import wandb_utils
+from automodel.training.checkpoint_utils import (
     TRACKER_PREFIX,
     checkpoint_exists,
     get_checkpoint_run_config_filename,
     get_checkpoint_train_state_filename,
     read_run_config,
     read_train_state,
 )
-# from nemo_lm.automodel.utils.checkpoint_utils import TRAIN_STATE_FILE
-from nemo_lm.automodel.utils.dist_utils import (
+# from automodel.utils.checkpoint_utils import TRAIN_STATE_FILE
+from automodel.utils.dist_utils import (
     get_local_rank_preinit,
     get_rank_safe,
     get_world_size_safe
@@ -7,7 +7,7 @@
 import torch
 import yaml
 
-from nemo_lm.automodel.utils.dist_utils import (
+from automodel.utils.dist_utils import (
     get_local_rank_preinit,
     get_rank_safe,
     get_world_size_safe
@@ -22,21 +22,21 @@
 from torch.nn.parallel import DistributedDataParallel
 from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
 
-from nemo_lm.automodel.components.data.hf_dataset import HFDatasetBuilder
-from nemo_lm.automodel.components.loss.linear_ce import HAVE_LINEAR_LOSS_CE
-from nemo_lm.automodel.components.loss.masked_ce import masked_cross_entropy
-from nemo_lm.automodel.components.scheduler import OptimizerParamScheduler
-from nemo_lm.automodel.utils.model_utils import JitConfig, TEConfig, jit_compile_model, te_accelerate
+from automodel.components.data.hf_dataset import HFDatasetBuilder
+from automodel.components.loss.linear_ce import HAVE_LINEAR_LOSS_CE
+from automodel.components.loss.masked_ce import masked_cross_entropy
+from automodel.components.scheduler import OptimizerParamScheduler
+from automodel.utils.model_utils import JitConfig, TEConfig, jit_compile_model, te_accelerate
 from nemo_lm.config.common import (
     DistributedInitConfig,
     LoggerConfig,
     ProfilingConfig,
     RNGConfig,
     TrainingConfig,
 )
-from nemo_lm.automodel.utils.common_utils import get_rank_safe, get_world_size_safe
-from nemo_lm.automodel.utils.config_utils import ConfigContainer as Container
-from nemo_lm.automodel.utils.import_utils import safe_import
+from automodel.utils.common_utils import get_rank_safe, get_world_size_safe
+from automodel.utils.config_utils import ConfigContainer as Container
+from automodel.utils.import_utils import safe_import
 
 logger = logging.getLogger(__name__)
 
@@ -23,29 +23,29 @@
 from torch.nn.parallel import DistributedDataParallel
 from transformers import AutoTokenizer
 
-from nemo_lm.automodel.checkpointing import (
+from automodel.checkpointing import (
     checkpoint_and_decide_exit,
     checkpoint_exists,
     load_checkpoint,
     save_checkpoint_and_time,
 )
-from nemo_lm.automodel.components.state import GlobalState
-from nemo_lm.automodel.config import ConfigContainer
-from nemo_lm.automodel.utils.distributed_utils import initialize_automodel
-from nemo_lm.automodel.utils.train_utils import (
+from automodel.components.state import GlobalState
+from automodel.config import ConfigContainer
+from automodel.utils.distributed_utils import initialize_automodel
+from automodel.utils.train_utils import (
     eval_log,
     reduce_loss,
     training_log,
 )
 from nemo_lm.config.common import ProfilingConfig
-from nemo_lm.automodel.utils.common_utils import (
+from automodel.utils.common_utils import (
     append_to_progress_log,
     barrier_and_log,
     get_rank_safe,
     get_world_size_safe,
     print_rank_0,
 )
-from nemo_lm.automodel.utils.log_utils import setup_logging
+from automodel.utils.log_utils import setup_logging
 
 logger = logging.getLogger(__name__)
 
@@ -18,8 +18,9 @@
 
 import torch
 import torch.distributed
+from dataclasses import dataclass
 
-from nemo_lm.automodel.utils.dist_utils import (
+from automodel.utils.dist_utils import (
     get_local_rank_preinit,
     get_rank_safe,
     get_world_size_safe
@@ -20,7 +20,7 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from nemo_lm.automodel.utils.import_utils import safe_import_from
+from automodel.utils.import_utils import safe_import_from
 
 te, HAVE_TE = safe_import_from("transformer_engine", "pytorch")
 
@@ -16,7 +16,7 @@
 import numpy as np
 import torch
 
-from nemo_lm.automodel.utils.dist_utils import get_rank_safe
+from automodel.utils.dist_utils import get_rank_safe
 
 class StatefulRNG:
     def __init__(self, seed: int, ranked: bool = False):
@@ -22,10 +22,10 @@
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.utils.tensorboard.writer import SummaryWriter
 
-from nemo_lm.automodel.components.timers import Timers
-from nemo_lm.automodel.config import ConfigContainer
-from nemo_lm.automodel.utils.common_utils import dump_dataclass_to_yaml, get_rank_safe, get_world_size_safe
-from nemo_lm.automodel.utils.sig_utils import DistributedSignalHandler
+from automodel.components.timers import Timers
+from automodel.config import ConfigContainer
+from automodel.utils.common_utils import dump_dataclass_to_yaml, get_rank_safe, get_world_size_safe
+from automodel.utils.sig_utils import DistributedSignalHandler
 
 
 @dataclass
@@ -8,7 +8,7 @@
 
 import torch
 
-from nemo_lm.automodel.utils.import_utils import is_torch_min_version
+from automodel.utils.import_utils import is_torch_min_version
 
 if is_torch_min_version("1.13.0"):
     dist_all_gather_func = torch.distributed.all_gather_into_tensor
@@ -18,9 +18,9 @@
 
 import torch
 
-from nemo_lm.automodel.components.state import GlobalState
-from nemo_lm.automodel.config import ConfigContainer
-from nemo_lm.automodel.utils.common_utils import (
+from automodel.components.state import GlobalState
+from automodel.config import ConfigContainer
+from automodel.utils.common_utils import (
     get_world_size_safe,
     is_last_rank,
     print_rank_last,
File renamed without changes.
@@ -21,8 +21,8 @@
 import yaml
 from omegaconf import OmegaConf
 
-from nemo_lm.automodel.utils.instantiate_utils import InstantiationMode, instantiate
-from nemo_lm.automodel.utils.yaml_utils import safe_yaml_representers
+from automodel.utils.instantiate_utils import InstantiationMode, instantiate
+from automodel.utils.yaml_utils import safe_yaml_representers
 
 T = TypeVar("T", bound="ConfigContainer")
 
@@ -109,7 +109,7 @@ def _try_bootstrap_pg(self) -> bool:
 import torch.distributed
 import yaml
 
-from nemo_lm.automodel.utils.yaml_utils import safe_yaml_representers
+from automodel.utils.yaml_utils import safe_yaml_representers
 
 
 def get_rank_safe() -> int:
@@ -18,7 +18,7 @@
 import torch
 import torch.distributed
 
-from nemo_lm.automodel.utils.common_utils import get_world_size_safe, print_rank_0
+from automodel.utils.common_utils import get_world_size_safe, print_rank_0
 
 
 def get_device(local_rank: Optional[int] = None) -> torch.device:
File renamed without changes.
Empty file.
6 changes: 3 additions & 3 deletions recipes/automodel_finetune.py → recipes/finetune.py
@@ -6,9 +6,9 @@
 import torch.nn as nn
 from torch.utils.data import DataLoader
 
-from nemo_lm.automodel.config.loader import load_yaml_config
-from nemo_lm.automodel.training.init_utils import initialize_distributed
-from nemo_lm.automodel.base_recipe import BaseRecipe
+from automodel.config.loader import load_yaml_config
+from automodel.training.init_utils import initialize_distributed
+from automodel.base_recipe import BaseRecipe
 
 
 # ---------------------------
18 changes: 9 additions & 9 deletions recipes/llama_3_2_1b_hellaswag.yaml
@@ -1,5 +1,5 @@
 training:
-  _target_: nemo_lm.config.common.TrainingConfig
+  _target_: config.common.TrainingConfig
   train_iters: 250
   eval_interval: 1000
   eval_iters: 4
@@ -10,18 +10,18 @@ distributed:
   timeout_minutes: 1
 
 rng:
-  _target_: nemo_lm.automodel.training.rng.StatefulRNG
+  _target_: automodel.training.rng.StatefulRNG
   seed: 1111
   ranked: true
 
 model:
   _target_: transformers.AutoModelForCausalLM.from_pretrained
   pretrained_model_name_or_path: meta-llama/Llama-3.2-1B
 
-loss_fn: nemo_lm.automodel.loss.masked_ce.masked_cross_entropy
+loss_fn: automodel.loss.masked_ce.masked_cross_entropy
 
 dataset:
-  _target_: nemo_lm.automodel.datasets.hellaswag.HellaSwag
+  _target_: automodel.datasets.hellaswag.HellaSwag
   path_or_dataset: rowan/hellaswag
   split: train
   tokenizer:
@@ -30,11 +30,11 @@ dataset:
 
 dataloader:
   _target_: torchdata.stateful_dataloader.StatefulDataLoader
-  collate_fn: nemo_lm.automodel.datasets.utils.default_collater
+  collate_fn: automodel.datasets.utils.default_collater
   batch_size: 1
 
 validation_dataset:
-  _target_: nemo_lm.automodel.datasets.hellaswag.hellaswag
+  _target_: automodel.datasets.hellaswag.hellaswag
   path_or_dataset: rowan/hellaswag
   split: train
   tokenizer:
@@ -53,7 +53,7 @@ optimizer:
   # min_lr: 1.0e-5
 
 scheduler:
-  _target_: nemo_lm.automodel.config.SchedulerConfig
+  _target_: automodel.config.SchedulerConfig
   start_weight_decay: 0
   end_weight_decay: 0
   weight_decay_incr_style: constant
@@ -64,7 +64,7 @@ scheduler:
   override_opt_param_scheduler: true
 
 logger:
-  _target_: nemo_lm.config.common.LoggerConfig
+  _target_: config.common.LoggerConfig
   wandb_project: nemo_automodel_sft_loop
   wandb_entity: nvidia
   wandb_exp_name: nemolm_automodel_Rowan_hellaswag_meta-llama_Llama-3.2-1B_gbs_256_seq_len_1024_lr_1.0e-5
@@ -81,7 +81,7 @@ logger:
   - nemo.collections.llm.gpt.data.utils
 
 checkpointer:
-  _target_: nemo_lm.automodel.training.checkpoint.TorchCheckpointer
+  _target_: automodel.training.checkpoint.TorchCheckpointer
   # save_interval: 10000
   # save: /tmp/nemo_run/checkpoints/automodel/Rowan_hellaswag_meta-llama_Llama-3.2-1B_gbs_256_seq_len_1024
   # load: /tmp/nemo_run/checkpoints/automodel/Rowan_hellaswag_meta-llama_Llama-3.2-1B_gbs_256_seq_len_1024
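
Taken together, the hunks above drop the `nemo_lm.` prefix so that `automodel` (and, in the YAML recipes, `config.common`) resolve as top-level packages. As a rough illustration of how a renamed recipe config would be consumed after this change, here is a minimal sketch; it assumes `load_yaml_config` takes a file path and returns an attribute-accessible config, and that `instantiate` resolves `_target_` entries Hydra-style, neither of which is confirmed by this diff:

from automodel.config.loader import load_yaml_config
from automodel.utils.instantiate_utils import instantiate

# Assumption: load_yaml_config accepts a path to a recipe YAML like the one above.
cfg = load_yaml_config("recipes/llama_3_2_1b_hellaswag.yaml")

# Assumption: instantiate() imports the dotted path in `_target_` and calls it
# with the sibling keys as keyword arguments, so cfg.model would resolve to
# transformers.AutoModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="meta-llama/Llama-3.2-1B")
model = instantiate(cfg.model)
dataset = instantiate(cfg.dataset)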