diff --git a/docs/design_docs/generation.md b/docs/design_docs/generation.md
index 8dda8a028a..84f450c7cc 100644
--- a/docs/design_docs/generation.md
+++ b/docs/design_docs/generation.md
@@ -95,26 +95,20 @@ The {py:class}`UpdatableVllmInternalWorker AutoTokenizer:
+    """Get the tokenizer and set pad token to eos token if it is not already set."""
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    return tokenizer
diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py
index 33d486a4d5..a1a4cad74b 100644
--- a/nemo_reinforcer/evals/eval.py
+++ b/nemo_reinforcer/evals/eval.py
@@ -105,12 +105,6 @@ def setup(
     backend = generation_config["backend"]
     assert backend == "vllm", "Only vLLM backend is supported for evaluation"
 
-    # set vllm config
-    generation_config["vllm_cfg"]["load_format"] = "auto"
-    generation_config["vllm_cfg"]["skip_tokenizer_init"] = False
-    generation_config["stop_token_ids"] = [tokenizer.eos_token_id]
-    generation_config["pad_token"] = tokenizer.pad_token_id
-
     # initialize vllm generation
     vllm_generation = VllmGeneration(cluster=cluster, config=generation_config)
     print(
diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py
index da7e737784..f81d5d897d 100644
--- a/nemo_reinforcer/models/generation/interfaces.py
+++ b/nemo_reinforcer/models/generation/interfaces.py
@@ -15,6 +15,8 @@
 from typing import Any, TypedDict, Union, Tuple, List
 
 import torch
+from transformers import AutoTokenizer
+
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
@@ -45,8 +47,8 @@ def verify_right_padding(
     )
 
     assert pad_value is not None, (
-        "Tokenizer does not have a pad token assigned. \n"
-        "If the default tokenizer does not have a pad token, you can assign it the value of eos token by tokenizer.pad_token = tokenizer.eos_token"
+        "Tokenizer does not have a pad_token_id. \n"
+        "Please use the nemo_reinforcer.algorithms.utils.get_tokenizer(...) API which sets pad_token_id if absent."
     )
 
     # Determine which type of data we're dealing with
@@ -107,7 +109,28 @@ class GenerationConfig(TypedDict):
     top_k: int
     model_name: str
     stop_token_ids: List[int]
-    pad_token: int
+    pad_token_id: int
+
+
+def configure_generation_config(
+    config: GenerationConfig, tokenizer: AutoTokenizer, is_eval=False
+):
+    """Apply specific configurations to generation config."""
+    # tokenizer setting
+    config["pad_token_id"] = tokenizer.pad_token_id
+    # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below.
+ config["stop_token_ids"] = [tokenizer.eos_token_id] + + # vllm setting + if config["backend"] == "vllm": + if is_eval: + config["vllm_cfg"]["skip_tokenizer_init"] = False + config["vllm_cfg"]["load_format"] = "auto" + else: + config["vllm_cfg"]["skip_tokenizer_init"] = True + config["vllm_cfg"]["load_format"] = "dummy" + + return config class GenerationDatumSpec(TypedDict): diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 4ffbb3e2ff..3f8528f549 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -218,7 +218,7 @@ def generate( f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" ) is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.cfg["pad_token"] + data, pad_value=self.cfg["pad_token_id"] ) if not is_right_padded: warnings.warn( @@ -282,7 +282,7 @@ def generate( # Create a new tensor with the right size and fill with padding token full_output = torch.full( - (total_length,), self.cfg["pad_token"], dtype=input_ids.dtype + (total_length,), self.cfg["pad_token_id"], dtype=input_ids.dtype ) # Copy original input (with padding) into the beginning @@ -516,7 +516,9 @@ def generate( results = self.worker_group.get_all_worker_results(future_bundle) # Combine results from all tied worker groups - combined = BatchedDataDict.from_batches(results) + combined = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + ) # Verify the output has all required fields required_keys = [ @@ -557,7 +559,9 @@ def generate_text( results = self.worker_group.get_all_worker_results(future_bundle) # Combine results from all tied worker groups - combined = BatchedDataDict.from_batches(results) + combined = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + ) # Verify the output has all required fields required_keys = ["texts"] diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 3a316ba3ae..c36bc0fec7 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -28,9 +28,10 @@ StateDictType, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup @@ -97,10 +98,7 @@ def __init__( ) else: self.reference_model = None - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - # If no pad token is defined, you might need: - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer = get_tokenizer(model_name) # ------------------------------------------------ # 3) Move to GPU + Composable FSDP @@ -519,7 +517,9 @@ def generate( batch_size, seq_len = input_ids.shape # Convert right padding to left padding - left_padded_input_ids = torch.zeros_like(input_ids) + left_padded_input_ids = torch.full_like( + input_ids, gen_cfg["pad_token_id"] + ) left_padded_attention_mask = torch.zeros( (batch_size, seq_len), dtype=torch.long, 
                 device=input_ids.device
             )
@@ -569,7 +569,12 @@ def generate(
             micro_batches.append(mb)
 
         # Get lengths, pad, and concatenate all batches
-        return_data = BatchedDataDict.from_batches(micro_batches)
+        return_data = BatchedDataDict.from_batches(
+            micro_batches,
+            pad_value_dict={
+                "left_padded_output_ids": self.cfg["generation"]["pad_token_id"]
+            },
+        )
 
         # Calculate the lengths of generations for each sequence by finding stop tokens
         generation_lengths = []
@@ -581,8 +586,9 @@ def generate(
         max_seq_len = max(
             [seq.size(0) for seq in return_data["left_padded_output_ids"]]
         )
-        right_padded_output_ids = torch.zeros(
+        right_padded_output_ids = torch.full(
             (batch_size, max_seq_len),
+            self.cfg["generation"]["pad_token_id"],
             dtype=return_data["left_padded_output_ids"][0].dtype,
             device=return_data["left_padded_output_ids"][0].device,
         )
@@ -1017,7 +1023,8 @@ def generate(
             "generate", sharded_data, common_kwargs={"greedy": greedy}
         )
         result = BatchedDataDict.from_batches(
-            self.worker_group.get_all_worker_results(futures)
+            self.worker_group.get_all_worker_results(futures),
+            pad_value_dict={"output_ids": self.cfg["generation"]["pad_token_id"]},
         )
 
         # Verify the output has all required fields
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index a5bcda1ff6..ed90267d10 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -12,15 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import pytest
 import torch
 import ray
-import numpy as np
-
-from transformers import AutoTokenizer
+from nemo_reinforcer.algorithms.utils import get_tokenizer
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+from nemo_reinforcer.models.generation.interfaces import configure_generation_config
 from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig
@@ -41,19 +42,6 @@
 }
 
 
-def configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=False):
-    """Apply tokenizer-specific configurations to vLLM config."""
-    if is_eval:
-        vllm_config["vllm_cfg"]["skip_tokenizer_init"] = False
-        vllm_config["vllm_cfg"]["load_format"] = "auto"
-    else:
-        vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True
-        vllm_config["vllm_cfg"]["load_format"] = "dummy"
-    vllm_config["pad_token"] = tokenizer.pad_token_id
-    vllm_config["stop_token_ids"] = [tokenizer.eos_token_id]
-    return vllm_config
-
-
 @pytest.fixture(scope="module")
 def check_vllm_available():
     """Skip tests if vLLM is not installed."""
@@ -82,9 +70,7 @@ def cluster():
 def tokenizer():
     """Initialize tokenizer for the test model."""
     model_name = basic_vllm_test_config["model_name"]
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer = get_tokenizer(model_name)
     return tokenizer
@@ -93,7 +79,7 @@ def policy(cluster, tokenizer, check_vllm_available):
     """Initialize the vLLM policy."""
     # Create separate configs for each policy
     vllm_config = basic_vllm_test_config.copy()
-    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
 
     policy = VllmGeneration(cluster, vllm_config)
     yield policy
@@ -213,7 +199,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Create separate configs for each policy
     vllm_config = basic_vllm_test_config.copy()
-    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
+    vllm_config = configure_generation_config(vllm_config, tokenizer)
 
     # Create HF-specific config with required parameters
     hf_config = {
@@ -252,6 +238,17 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         "Where is the sun?",
     ]
 
+    expected_generations = [
+        "Write a story about a magical forest. The forest is magical because it is full of",
+        "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis",
+        "What are the benefits of exercise? The benefits of exercise are many and varied. It",
+        "Describe the water cycle in your own words.\nDescribe the water cycle in",
+        "What is the capital of France? A. Paris B. New York C. Washington",
+        "Who is the president of the USA? Who is the president of the USA? Who is",
+        "What is the capital of the moon? A. Houston, Texas B. New York City",
+        "Where is the sun? Where is the moon? Where is the earth?",
+    ]
+
     # Tokenize the prompts the same way as in test_hf_ray_policy
     tokenized = tokenizer(
         prompts,
@@ -286,7 +283,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
     # Step 1: Use vLLM for generation
     print("Using vLLM policy for fast generation...")
-    generation_results = vllm_policy.generate(test_input_data)
+    generation_results = vllm_policy.generate(test_input_data, greedy=True)
     vllm_policy.finish_generation()
 
     # Validate generation outputs
     assert "output_ids" in generation_results, (
@@ -301,6 +298,9 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
         generation_results["output_ids"], skip_special_tokens=True
     )
     print(f"vLLM generated texts: {generated_texts}")
+    assert generated_texts == expected_generations, (
+        "Output should be the same as the expected output"
+    )
 
     # Run logprob calculation with HF policy to verify
@@ -401,9 +401,9 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):
 def test_vllm_policy_tensor_parallel(cluster, tokenizer):
     """Test vLLM policy with tensor parallelism > 1."""
     # Configure with tensor_parallel_size=2
-    tp_config = basic_vllm_test_config.copy()
-    tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer)
-    tp_config["tensor_parallel_size"] = 2
+    tp_config = deepcopy(basic_vllm_test_config)
+    tp_config = configure_generation_config(tp_config, tokenizer)
+    tp_config["vllm_cfg"]["tensor_parallel_size"] = 2
 
     # Ensure we specify the distributed executor backend
     tp_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"}
@@ -466,7 +466,7 @@ def test_vllm_generate_text(cluster, tokenizer):
     # Create separate configs for each policy
     vllm_config = basic_vllm_test_config.copy()
-    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=True)
+    vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
 
     # Ensure we can get same output
     assert vllm_config["model_name"] == "meta-llama/Llama-3.2-1B", (
@@ -499,8 +499,8 @@ def test_vllm_weight_update_and_prefix_cache_reset(
     from nemo_reinforcer.models.policy.hf_policy import HfPolicy
 
     # Create configs
-    vllm_config = basic_vllm_test_config.copy()
-    vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=True)
+    vllm_config = deepcopy(basic_vllm_test_config)
+    vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True)
     vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size
     if tensor_parallel_size > 1:
         vllm_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"}
diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py
index ded244feac..76926960cf 100644
--- a/tests/unit/models/policy/test_hf_ray_policy.py
+++ b/tests/unit/models/policy/test_hf_ray_policy.py
@@ -16,13 +16,14 @@
 import pprint
 import torch
 
-from nemo_reinforcer.models.policy import PolicyConfig
-from nemo_reinforcer.models.policy.hf_policy import HfPolicy
+from nemo_reinforcer.algorithms.interfaces import LossFunction
+from nemo_reinforcer.algorithms.utils import get_tokenizer
 from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
 from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
-from nemo_reinforcer.algorithms.interfaces import LossFunction
+from nemo_reinforcer.models.generation.interfaces import configure_generation_config
+from nemo_reinforcer.models.policy import PolicyConfig
+from nemo_reinforcer.models.policy.hf_policy import HfPolicy
 from tests.unit.test_utils import simple_loss, nll_loss
-from transformers import AutoTokenizer
 
 
 basic_llama_test_config: PolicyConfig = {
@@ -66,8 +67,16 @@ def gc_collect():
     gc.collect()
 
 
+@pytest.fixture(scope="function")
+def tokenizer():
+    """Initialize tokenizer for the test model."""
+    model_name = basic_llama_test_config["model_name"]
+    tokenizer = get_tokenizer(model_name)
+    return tokenizer
+
+
 @pytest.fixture
-def policy_setup():
+def policy_setup(tokenizer):
     """Setup and teardown for policy tests - creates a virtual cluster and policy."""
     policy = None
     cluster = None
@@ -84,6 +93,7 @@ def policy_setup():
         )
 
         config = basic_llama_test_config
+        config["generation"] = configure_generation_config(config["generation"], tokenizer)
 
         print("Creating HfPolicy...")
         policy = HfPolicy(cluster=cluster, config=config)
@@ -278,7 +288,7 @@ def verify_loss_tensor(loss_tensor):
 
 
 @pytest.fixture
-def generation_setup(request):
+def generation_setup(request, tokenizer):
     """Setup and teardown specifically for generation tests."""
     policy = None
     cluster = None
@@ -298,6 +308,9 @@ def generation_setup(request):
         )
 
         config = basic_llama_test_config
+        config["generation"] = configure_generation_config(
+            config["generation"], tokenizer
+        )
 
         print("Creating generation HfPolicy...")
         policy = HfPolicy(
@@ -331,8 +344,6 @@ def generation_setup(request):
         ]
 
         # Tokenize the prompts
-        tokenizer = AutoTokenizer.from_pretrained(config["model_name"])
-        tokenizer.pad_token = tokenizer.eos_token
         tokenized = tokenizer(
             prompts,
             padding=True,
@@ -353,7 +364,7 @@ def generation_setup(request):
         )
 
         # Provide the resources to the test
-        yield policy, cluster, data, tokenizer, prompts, expected_generations
+        yield policy, cluster, data, prompts, expected_generations
 
     except Exception as e:
         print(f"Error during generation setup: {e}")
@@ -367,8 +378,8 @@ def generation_setup(request):
 
 @pytest.mark.timeout(180)
 @pytest.mark.parametrize("generation_setup", [False], indirect=True)
-def test_hf_policy_generation(generation_setup, tracker):
-    policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup
+def test_hf_policy_generation(generation_setup, tokenizer, tracker):
+    policy, cluster, data, prompts, expected_generations = generation_setup
 
     # Verify resources were created properly
     assert policy is not None, "Generation policy was not created properly"
@@ -386,6 +397,10 @@ def test_hf_policy_generation(generation_setup, tracker):
     # Verify results
     assert "output_ids" in results, "Generation results should contain 'output_ids'"
     output_ids = results["output_ids"]
+    generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+    assert generated_texts == expected_generations, (
+        "Output should be the same as the expected output"
+    )
 
     # run logprob calculation manually to verify
     fprop_logprob_data = BatchedDataDict(
@@ -455,7 +470,7 @@ def test_hf_policy_generation(generation_setup, tracker):
 @pytest.mark.timeout(180)
 @pytest.mark.parametrize("generation_setup", [True], indirect=True)
 def test_all_hf_policy_generation_lps_ref_training(generation_setup):
-    policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup
+    policy, cluster, data, prompts, expected_generations = generation_setup
 
     # Verify resources were created properly
     assert policy is not None, "Generation policy was not created properly"