Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions docs/design_docs/generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,33 +95,31 @@ The {py:class}`UpdatableVllmInternalWorker <nemo_reinforcer.models.generation.vl
To use a generation backend:

```python
from transformers import AutoTokenizer

from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig
from nemo_reinforcer.algorithms.utils import get_tokenizer
from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
from nemo_reinforcer.models.generation.interfaces import configure_generation_config
from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig

# Set up the configuration
tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

config = VllmConfig(
model_name="Qwen/Qwen2.5-1.5B",
max_new_tokens=100,
temperature=0.7,
top_p=1,
top_k=None,
    stop_token_ids=[tokenizer.eos_token_id],
pad_token=tokenizer.pad_token_id,
skip_tokenizer_init=True,
backend="vllm",
vllm_cfg={
"tensor_parallel_size": 1,
"gpu_memory_utilization": 0.8,
"max_model_len": 2048,
}
)

# Configure config with tokenizer
tokenizer = get_tokenizer(config["model_name"])
config = configure_generation_config(config, tokenizer)

# Initialize the cluster and generation backend
cluster = RayVirtualCluster(...)
generator = VllmGeneration(cluster, config)
Expand Down
19 changes: 10 additions & 9 deletions examples/run_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
from transformers import AutoTokenizer

from examples.run_grpo_math import math_data_processor
from nemo_reinforcer.algorithms.utils import get_tokenizer
from nemo_reinforcer.data import MathDataConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.interfaces import TaskDataSpec
from nemo_reinforcer.data.llm_message_utils import remap_dataset_keys
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.evals.eval import MasterConfig, run_env_eval, setup
from nemo_reinforcer.models.generation.interfaces import GenerationConfig
from nemo_reinforcer.models.generation.interfaces import configure_generation_config


def parse_args():
Expand All @@ -50,9 +51,7 @@ def parse_args():
return args, overrides


def setup_data(
data_config: MathDataConfig, generation_config: GenerationConfig, env_configs
):
def setup_data(tokenizer: AutoTokenizer, data_config: MathDataConfig, env_configs):
print("\n▶ Setting up data...")
math_task_spec = TaskDataSpec(
task_name="math",
Expand All @@ -73,10 +72,6 @@ def setup_data(
},
)

tokenizer = AutoTokenizer.from_pretrained(generation_config["model_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

math_env = MathEnvironment.options(
runtime_env={"py_executable": MathEnvironment.DEFAULT_PY_EXECUTABLE}
).remote(env_configs["math"])
Expand Down Expand Up @@ -118,12 +113,18 @@ def main():
# Init ray
init_ray()

# Setup tokenizer
tokenizer = get_tokenizer(config["generation"]["model_name"])
config["generation"] = configure_generation_config(
config["generation"], tokenizer, is_eval=True
)

# Setup data
(
dataset,
math_env,
tokenizer,
) = setup_data(config["data"], config["generation"], config["env"])
) = setup_data(tokenizer, config["data"], config["env"])

# Setup
(
Expand Down
31 changes: 19 additions & 12 deletions examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
from collections import defaultdict
from typing import Any, Dict

from datasets import load_dataset
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from nemo_reinforcer.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_reinforcer.algorithms.utils import get_tokenizer
from nemo_reinforcer.data import DataConfig
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.hf_datasets.openmathinstruct2 import OpenMathInstruct2Dataset
from nemo_reinforcer.data.interfaces import DatumSpec, LLMMessageLogType, TaskDataSpec
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.environments.math_environment import MathEnvironment
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.models.generation.interfaces import configure_generation_config
from nemo_reinforcer.utils.config import load_config, parse_hydra_overrides
from nemo_reinforcer.utils.logger import get_next_experiment_dir

Expand Down Expand Up @@ -172,7 +172,7 @@ def math_data_processor(
return output


def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs):
def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, env_configs):
print("\n▶ Setting up data...")
math_task_spec = TaskDataSpec(
task_name="math",
Expand All @@ -187,10 +187,6 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
else:
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

task_data_processors = defaultdict(
lambda: (math_task_spec, openinstructmath2_data_processor)
)
Expand Down Expand Up @@ -220,7 +216,7 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs

task_to_env = defaultdict(lambda: math_env)
task_to_env["math"] = math_env
return dataset, val_dataset, task_to_env, task_to_env, tokenizer
return dataset, val_dataset, task_to_env, task_to_env


def main():
Expand Down Expand Up @@ -257,10 +253,20 @@ def main():

init_ray()

# setup data
dataset, val_dataset, task_to_env, val_task_to_env, tokenizer = setup_data(
config["data"], config["policy"], config["env"]
# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["model_name"])
config["policy"]["generation"] = configure_generation_config(
config["policy"]["generation"], tokenizer
)

# setup data
(
dataset,
val_dataset,
task_to_env,
val_task_to_env,
) = setup_data(tokenizer, config["data"], config["env"])

(
policy,
policy_generation,
Expand All @@ -273,6 +279,7 @@ def main():
grpo_state,
master_config,
) = setup(config, tokenizer, dataset, val_dataset)

grpo_train(
policy,
policy_generation,
Expand Down
29 changes: 17 additions & 12 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@
from typing import Dict, Any

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from nemo_reinforcer.algorithms.sft import MasterConfig, sft_train, setup
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.utils.config import load_config
from nemo_reinforcer.utils.logger import get_next_experiment_dir
from nemo_reinforcer.algorithms.utils import get_tokenizer
from nemo_reinforcer.data import DataConfig, hf_datasets
from nemo_reinforcer.data.datasets import AllTaskProcessedDataset
from nemo_reinforcer.data.interfaces import TaskDataSpec, DatumSpec
from nemo_reinforcer.data.llm_message_utils import get_formatted_message_log
from transformers import AutoTokenizer
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.distributed.virtual_cluster import init_ray
from nemo_reinforcer.utils.config import load_config
from nemo_reinforcer.utils.logger import get_next_experiment_dir


def parse_args():
Expand Down Expand Up @@ -83,7 +83,7 @@ def sft_preprocessor(
return output


def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
print("\n▶ Setting up data...")
data_cls = data_config["dataset_name"]
if data_cls == "open_assistant":
Expand All @@ -100,8 +100,6 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
val_dataset = data.formatted_ds["validation"]
sft_task_spec = data.task_spec

tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])

train_dataset = AllTaskProcessedDataset(
train_dataset,
tokenizer,
Expand All @@ -118,7 +116,7 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig):
max_seq_length=data_config["max_input_seq_length"],
)

return train_dataset, val_dataset, tokenizer, sft_task_spec
return train_dataset, val_dataset, sft_task_spec


def main():
Expand Down Expand Up @@ -152,10 +150,16 @@ def main():

init_ray()

# setup tokenizer
tokenizer = get_tokenizer(config["policy"]["model_name"])

# setup data
dataset, val_dataset, tokenizer, sft_task_spec = setup_data(
config["data"], config["policy"]
)
(
dataset,
val_dataset,
sft_task_spec,
) = setup_data(tokenizer, config["data"])

(
policy,
cluster,
Expand All @@ -167,6 +171,7 @@ def main():
sft_save_state,
master_config,
) = setup(config, dataset, val_dataset)

sft_train(
policy,
train_dataloader,
Expand Down
5 changes: 0 additions & 5 deletions nemo_reinforcer/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,6 @@ def setup(
# vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
backend = generation_config["backend"]
generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM
generation_config["vllm_cfg"]["skip_tokenizer_init"] = True
# When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below.
generation_config["stop_token_ids"] = [tokenizer.eos_token_id]
generation_config["pad_token"] = tokenizer.pad_token_id
generation_config["vllm_cfg"]["load_format"] = "dummy"

if backend == "hf":
policy_generation = None
Expand Down
9 changes: 9 additions & 0 deletions nemo_reinforcer/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import torch
from torch.masked import as_masked_tensor
from transformers import AutoTokenizer


def calculate_kl_penalty_joschu2020(
Expand Down Expand Up @@ -130,3 +131,11 @@ def set_seed(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def get_tokenizer(model_name: str) -> AutoTokenizer:
    """Load a Hugging Face tokenizer and ensure it has a pad token.

    Args:
        model_name: Model identifier or local path forwarded to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer. If the pretrained tokenizer ships without a
        ``pad_token``, it is set to ``eos_token`` so downstream padding and
        generation code always has a valid pad id.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Reuse EOS as the pad token when none is defined (common for
        # causal-LM tokenizers).
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer
6 changes: 0 additions & 6 deletions nemo_reinforcer/evals/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ def setup(
backend = generation_config["backend"]
assert backend == "vllm", "Only vLLM backend is supported for evaluation"

# set vllm config
generation_config["vllm_cfg"]["load_format"] = "auto"
generation_config["vllm_cfg"]["skip_tokenizer_init"] = False
generation_config["stop_token_ids"] = [tokenizer.eos_token_id]
generation_config["pad_token"] = tokenizer.pad_token_id

# initialize vllm generation
vllm_generation = VllmGeneration(cluster=cluster, config=generation_config)
print(
Expand Down
29 changes: 26 additions & 3 deletions nemo_reinforcer/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from typing import Any, TypedDict, Union, Tuple, List

import torch
from transformers import AutoTokenizer

from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict


Expand Down Expand Up @@ -45,8 +47,8 @@ def verify_right_padding(
)

assert pad_value is not None, (
"Tokenizer does not have a pad token assigned. \n"
"If the default tokenizer does not have a pad token, you can assign it the value of eos token by tokenizer.pad_token = tokenizer.eos_token"
"Tokenizer does not have a pad_token_id. \n"
"Please use the nemo_reinforcer.algorithms.utils.get_tokenizer(...) API which sets pad_token_id if absent."
)

# Determine which type of data we're dealing with
Expand Down Expand Up @@ -107,7 +109,28 @@ class GenerationConfig(TypedDict):
top_k: int
model_name: str
stop_token_ids: List[int]
pad_token: int
pad_token_id: int
Comment thread
yuki-97 marked this conversation as resolved.


def configure_generation_config(
    config: GenerationConfig, tokenizer: AutoTokenizer, is_eval=False
):
    """Fill tokenizer-derived fields on *config* and apply backend tweaks.

    Mutates and returns the same config mapping: sets ``pad_token_id`` and
    ``stop_token_ids`` from the tokenizer, and — for the vLLM backend —
    chooses tokenizer-init and weight-load behavior based on ``is_eval``.
    """
    # Derive token ids from the tokenizer.
    config["pad_token_id"] = tokenizer.pad_token_id
    # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below.
    config["stop_token_ids"] = [tokenizer.eos_token_id]

    if config["backend"] == "vllm":
        vllm_cfg = config["vllm_cfg"]
        # Eval keeps vLLM's own tokenizer and loads real weights; otherwise
        # tokenizer init is skipped and weights use the "dummy" load format.
        vllm_cfg["skip_tokenizer_init"] = not is_eval
        vllm_cfg["load_format"] = "auto" if is_eval else "dummy"

    return config


class GenerationDatumSpec(TypedDict):
Expand Down
12 changes: 8 additions & 4 deletions nemo_reinforcer/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def generate(
f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}"
)
is_right_padded, error_msg = verify_right_padding(
data, pad_value=self.cfg["pad_token"]
data, pad_value=self.cfg["pad_token_id"]
)
if not is_right_padded:
warnings.warn(
Expand Down Expand Up @@ -282,7 +282,7 @@ def generate(

# Create a new tensor with the right size and fill with padding token
full_output = torch.full(
(total_length,), self.cfg["pad_token"], dtype=input_ids.dtype
(total_length,), self.cfg["pad_token_id"], dtype=input_ids.dtype
)

# Copy original input (with padding) into the beginning
Expand Down Expand Up @@ -516,7 +516,9 @@ def generate(
results = self.worker_group.get_all_worker_results(future_bundle)

# Combine results from all tied worker groups
combined = BatchedDataDict.from_batches(results)
combined = BatchedDataDict.from_batches(
results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]}
)

# Verify the output has all required fields
required_keys = [
Expand Down Expand Up @@ -557,7 +559,9 @@ def generate_text(
results = self.worker_group.get_all_worker_results(future_bundle)

# Combine results from all tied worker groups
combined = BatchedDataDict.from_batches(results)
combined = BatchedDataDict.from_batches(
results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]}
)

# Verify the output has all required fields
required_keys = ["texts"]
Expand Down
Loading