diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 4041f31bc3..a367beec76 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -137,7 +137,7 @@ jobs: if: ${{ needs.pre-flight.outputs.run_ci == 'true' }} with: RUNNER: self-hosted-azure - TIMEOUT: 15 + TIMEOUT: 20 UNIT_TEST_SCRIPT: | cd /opt/reinforcer uv run --no-sync bash -x ./tests/run_unit.sh diff --git a/examples/configs/eval.yaml b/examples/configs/eval.yaml index a867e9617f..e319276094 100644 --- a/examples/configs/eval.yaml +++ b/examples/configs/eval.yaml @@ -7,6 +7,8 @@ generation: top_k: -1 # disable num_prompts_per_step: -1 # -1 means pass all prompts at once model_name: "Qwen/Qwen2.5-Math-1.5B-Instruct" + stop_token_ids: null + stop_strings: null vllm_cfg: tensor_parallel_size: 1 gpu_memory_utilization: 0.9 diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 3d8fdfce43..7a256621e4 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -59,6 +59,8 @@ policy: temperature: 1.0 top_p: 1.0 top_k: null + stop_token_ids: null + stop_strings: null vllm_cfg: tensor_parallel_size: 1 gpu_memory_utilization: 0.6 @@ -69,7 +71,7 @@ data: prompt_file: "examples/prompts/cot.txt" system_prompt_file: null dataset_name: "OpenMathInstruct-2" - + env: math: num_workers: 8 diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index f2e0576fbc..46c8855c6e 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -41,11 +41,13 @@ policy: temperature: 1.0 top_p: 1.0 top_k: null + stop_token_ids: null + stop_strings: null vllm_cfg: tensor_parallel_size: 1 gpu_memory_utilization: 0.6 max_model_len: ${policy.max_total_sequence_length} - + cluster: gpus_per_node: 8 num_nodes: 1 diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py index f81d5d897d..468714899f 100644 --- a/nemo_reinforcer/models/generation/interfaces.py +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -118,17 +118,18 @@ def configure_generation_config( """Apply specific configurations to generation config.""" # tokenizer setting config["pad_token_id"] = tokenizer.pad_token_id - # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below. - config["stop_token_ids"] = [tokenizer.eos_token_id] + if config["stop_token_ids"] is None: + config["stop_token_ids"] = [tokenizer.eos_token_id] # vllm setting if config["backend"] == "vllm": - if is_eval: + # set load_format + config["vllm_cfg"]["load_format"] = "auto" if is_eval else "dummy" + # set skip_tokenizer_init + if is_eval or config["stop_strings"] is not None: config["vllm_cfg"]["skip_tokenizer_init"] = False - config["vllm_cfg"]["load_format"] = "auto" else: config["vllm_cfg"]["skip_tokenizer_init"] = True - config["vllm_cfg"]["load_format"] = "dummy" return config diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 3f8528f549..f8c527dd06 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -250,13 +250,13 @@ def generate( sampling_params = self.SamplingParams( temperature=self.cfg["temperature"] if not greedy else 0, top_p=self.cfg["top_p"], - top_k=top_k - if not greedy - else 1, # we use a default of -1 if unset so that 'null'/None is a common disable value + # we use a default of -1 if unset so that 'null'/None is a common disable value + top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], logprobs=0, # Return logprobs for the generated tokens - stop=None, stop_token_ids=self.cfg["stop_token_ids"], + stop=self.cfg["stop_strings"], + include_stop_str_in_output=True, # returning stop strings like hf ) # Generate outputs @@ -352,7 +352,9 @@ def generate_text( top_p=self.cfg["top_p"], top_k=top_k if not greedy else 1, max_tokens=self.cfg["max_new_tokens"], - stop=self.cfg.get("stop_sequences", None), + stop_token_ids=self.cfg["stop_token_ids"], + stop=self.cfg["stop_strings"], + include_stop_str_in_output=True, # returning stop strings like hf ) # Generate outputs diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 051e56e23f..9ba65aca38 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -544,8 +544,10 @@ def generate( temperature=gen_cfg["temperature"], top_p=gen_cfg["top_p"], top_k=gen_cfg["top_k"], - pad_token_id=self.tokenizer.pad_token_id, - eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=gen_cfg["pad_token_id"], + eos_token_id=gen_cfg["stop_token_ids"], + stop_strings=gen_cfg["stop_strings"], + tokenizer=self.tokenizer, # needs for stop_strings return_dict_in_generate=True, output_scores=True, synced_gpus=True, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index aadb1fec77..593b96852c 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -23,6 +23,7 @@ from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.models.generation.interfaces import configure_generation_config from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig +from nemo_reinforcer.models.policy import PolicyConfig # Define basic vLLM test config @@ -35,6 +36,8 @@ "temperature": 1.0, "top_p": 1.0, "top_k": None, + "stop_token_ids": None, + "stop_strings": None, "vllm_cfg": { "tensor_parallel_size": 1, "gpu_memory_utilization": 0.3, @@ -42,6 +45,29 @@ }, } +# Create HF-specific config with required parameters +basic_hf_test_config: PolicyConfig = { + "model_name": basic_vllm_test_config["model_name"], + "tokenizer_name": basic_vllm_test_config["tokenizer_name"], + # Required training parameters + "train_global_batch_size": 1, + "train_micro_batch_size": 1, + "learning_rate": 5e-6, + "logprob_batch_size": 1, + "max_new_tokens": 16, + "do_sample": False, + "precision": "float32", + "optimizer": { + "name": "torch.optim.AdamW", + "kwargs": { + "lr": 5e-6, + "weight_decay": 0.01, + "betas": [0.9, 0.999], + "eps": 1e-8, + }, + }, +} + @pytest.fixture(scope="module") def cluster(): @@ -193,28 +219,8 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): vllm_config = basic_vllm_test_config.copy() vllm_config = configure_generation_config(vllm_config, tokenizer) - # Create HF-specific config with required parameters - hf_config = { - "model_name": basic_vllm_test_config["model_name"], - "tokenizer_name": basic_vllm_test_config["tokenizer_name"], - # Required training parameters - "train_global_batch_size": 4, - "train_micro_batch_size": 1, - "learning_rate": 5e-6, - "logprob_batch_size": 1, - "max_new_tokens": 16, - "do_sample": False, - "precision": "float32", - "optimizer": { - "name": "torch.optim.AdamW", - "kwargs": { - "lr": 5e-6, - "weight_decay": 0.01, - "betas": [0.9, 0.999], - "eps": 1e-8, - }, - }, - } + hf_config = basic_hf_test_config.copy() + hf_config["train_global_batch_size"] = 4 vllm_policy = None hf_policy = None @@ -498,18 +504,7 @@ def test_vllm_weight_update_and_prefix_cache_reset( if tensor_parallel_size > 1: vllm_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"} - hf_config = { - "model_name": basic_vllm_test_config["model_name"], - "tokenizer_name": "meta-llama/Llama-3.2-1B", - "train_global_batch_size": 1, - "train_micro_batch_size": 1, - "learning_rate": 1e-6, - "logprob_batch_size": 1, - "max_new_tokens": 16, - "do_sample": False, - "precision": "float32", - "optimizer": {"name": "torch.optim.AdamW", "kwargs": {"lr": 1e-6}}, - } + hf_config = basic_hf_test_config.copy() # Create policies vllm_policy = None @@ -592,3 +587,67 @@ def test_vllm_weight_update_and_prefix_cache_reset( gc.collect() torch.cuda.empty_cache() + + +@pytest.mark.parametrize("is_eval", [True, False]) +def test_vllm_generation_with_stop(cluster, test_input_data, tokenizer, is_eval): + """Test vLLM generation with stop.""" + from nemo_reinforcer.models.policy.hf_policy import HfPolicy + + # Create separate configs for each policy + vllm_config = basic_vllm_test_config.copy() + vllm_config["stop_token_ids"] = [3363] + vllm_config["stop_strings"] = ["I am a"] + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=is_eval) + + # Ensure we can get same output + assert vllm_config["model_name"] == "meta-llama/Llama-3.2-1B", ( + "Model name should be meta-llama/Llama-3.2-1B to get expected output" + ) + assert vllm_config["vllm_cfg"]["tensor_parallel_size"] == 1, ( + "Tensor parallel size should be 1 to get expected output" + ) + + # Create policies + print("Creating vLLM policy...") + vllm_generation = VllmGeneration(cluster, vllm_config) + + # Get weights from HF policy if not in eval mode + if not is_eval: + # set to sleep first if not in eval mode + vllm_generation.finish_generation() + + print("Creating HF policy...") + hf_config = basic_hf_test_config.copy() + hf_policy = HfPolicy(cluster, hf_config) + + print(f"refitting vllm policy...") + ipc_handles = hf_policy.get_weights_ipc_handles() + vllm_generation.prepare_for_generation() + vllm_generation.update_weights(ipc_handles) + + # test generate + outputs = vllm_generation.generate(test_input_data, greedy=True) + output_ids = outputs["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Hello, my name is Kelsey and I am a", + "The capital of France is Paris. The city", + ], "Output should be the same as the expected output" + + # test generate_text + test_prompts = [ + "Hello, my name is", + "The capital of France is", + ] + test_prompts = BatchedDataDict({"prompts": test_prompts}) + output = vllm_generation.generate_text(test_prompts, greedy=True) + assert output["texts"] == [ + " Kelsey and I am a", + " Paris. The city", + ], "Output should be the same as the expected output" + + # Clean up + vllm_generation.shutdown() + if not is_eval: + hf_policy.shutdown() diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 7cde591049..449ed016d5 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -15,11 +15,12 @@ import pytest import pprint import torch +from copy import deepcopy from nemo_reinforcer.algorithms.interfaces import LossFunction from nemo_reinforcer.algorithms.utils import get_tokenizer -from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.models.generation.interfaces import configure_generation_config from nemo_reinforcer.models.policy import PolicyConfig from nemo_reinforcer.models.policy.hf_policy import HfPolicy @@ -41,6 +42,8 @@ "max_new_tokens": 16, # Small number of tokens for testing "top_p": 1.0, "top_k": None, + "stop_token_ids": None, + "stop_strings": None, }, "optimizer": { "name": "torch.optim.AdamW", @@ -76,6 +79,54 @@ def tokenizer(): return tokenizer +@pytest.fixture(scope="function") +def test_input_data(tokenizer): + """Create test input data for inference.""" + prompts = [ + "Write a story about a magical forest", + "Explain how photosynthesis works", + "What are the benefits of exercise?", + "Describe the water cycle", + "What is the capital of France?", + "Who is the president of the USA?", + "What is the capital of the moon?", + "Where is the sun?", + ] + + expected_generations = [ + "Write a story about a magical forest. The forest is magical because it is full of magical creatures. The creatures are", + "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis is the process by which plants", + "What are the benefits of exercise? The benefits of exercise are many and varied. It is a great way to improve", + "Describe the water cycle in your own words.\nDescribe the water cycle in your own words.\nDescribe the", + "What is the capital of France? A. Paris B. New York C. Washington D. Baton Rouge\nA", + "Who is the president of the USA? Who is the president of the USA? Who is the president of the USA?", + "What is the capital of the moon? A. Houston B. New York C. Washington D. Denver\nA.", + "Where is the sun? Where is the moon? Where is the earth? Where is the sky? Where", + ] + + # Tokenize the prompts + tokenized = tokenizer( + prompts, + padding=True, + truncation=True, + max_length=64, + return_tensors="pt", + padding_side="right", + ) + + # Calculate input lengths from attention mask + input_lengths = tokenized["attention_mask"].sum(dim=1).to(torch.int32) + + data = BatchedDataDict( + { + "input_ids": tokenized["input_ids"], + "input_lengths": input_lengths, + } + ) + + return data, prompts, expected_generations + + @pytest.fixture def policy_setup(tokenizer): """Setup and teardown for policy tests - creates a virtual cluster and policy.""" @@ -289,7 +340,7 @@ def verify_loss_tensor(loss_tensor): @pytest.fixture -def generation_setup(request, tokenizer): +def generation_setup(request, test_input_data, tokenizer): """Setup and teardown specifically for generation tests.""" policy = None cluster = None @@ -322,47 +373,8 @@ def generation_setup(request, tokenizer): print("Creating test batch...") torch.manual_seed(42) # For reproducibility - prompts = [ - "Write a story about a magical forest", - "Explain how photosynthesis works", - "What are the benefits of exercise?", - "Describe the water cycle", - "What is the capital of France?", - "Who is the president of the USA?", - "What is the capital of the moon?", - "Where is the sun?", - ] - - expected_generations = [ - "Write a story about a magical forest. The forest is magical because it is full of magical creatures. The creatures are", - "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis is the process by which plants", - "What are the benefits of exercise? The benefits of exercise are many and varied. It is a great way to improve", - "Describe the water cycle in your own words.\nDescribe the water cycle in your own words.\nDescribe the", - "What is the capital of France? A. Paris B. New York C. Washington D. Baton Rouge\nA", - "Who is the president of the USA? Who is the president of the USA? Who is the president of the USA?", - "What is the capital of the moon? A. Houston B. New York C. Washington D. Denver\nA.", - "Where is the sun? Where is the moon? Where is the earth? Where is the sky? Where", - ] - - # Tokenize the prompts - tokenized = tokenizer( - prompts, - padding=True, - truncation=True, - max_length=64, - return_tensors="pt", - padding_side="right", - ) - - # Calculate input lengths from attention mask - input_lengths = tokenized["attention_mask"].sum(dim=1).to(torch.int32) - - data = BatchedDataDict( - { - "input_ids": tokenized["input_ids"], - "input_lengths": input_lengths, - } - ) + # Prepare test data + data, prompts, expected_generations = test_input_data # Provide the resources to the test yield policy, cluster, data, prompts, expected_generations @@ -544,3 +556,64 @@ def test_all_hf_policy_generation_lps_ref_training(generation_setup): # Verify loss decreased during training assert losses[0] > losses[-1], "Loss should decrease over training iterations" + + +def test_hf_policy_generation_with_stop(test_input_data, tokenizer): + # Create resources with unique name + cluster_name = "test-generate-with-stop" + print(f"Creating training virtual cluster '{cluster_name}'...") + + cluster = RayVirtualCluster( + name=cluster_name, + bundle_ct_per_node_list=[2], # Single node, 2 gpus + use_gpus=True, + num_gpus_per_node=2, # Using both GPUs + max_colocated_worker_groups=1, # Only one worker group + ) + + # Create separate configs for each policy + config = deepcopy(basic_llama_test_config) + config["generation"] = configure_generation_config(config["generation"], tokenizer) + # Add stop strings for testing + config["generation"]["stop_token_ids"] = [1690, 1920] # [" process", "many"] + config["generation"]["stop_strings"] = ["because it is", "A. Houston"] + + # Ensure we can get same output + assert config["model_name"] == "meta-llama/Llama-3.2-1B", ( + "Model name should be meta-llama/Llama-3.2-1B to get expected output" + ) + + # Create policy + policy = HfPolicy(cluster=cluster, config=config) + + # Call prepare_for_generation if available + print("Preparing for generation...") + policy.prepare_for_generation() + + # Generate text + print("Generating text...") + data, _, _ = test_input_data + results = policy.generate(data, greedy=True) + output_ids = results["output_ids"] + + # Call finish_generation if available + print("Finishing generation...") + policy.finish_generation() + + # Check result + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == [ + "Write a story about a magical forest. The forest is magical because it is", + "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis is the process", + "What are the benefits of exercise? The benefits of exercise are many", + "Describe the water cycle in your own words.\nDescribe the water cycle in your own words.\nDescribe the", + "What is the capital of France? A. Paris B. New York C. Washington D. Baton Rouge\nA", + "Who is the president of the USA? Who is the president of the USA? Who is the president of the USA?", + "What is the capital of the moon? A. Houston", + "Where is the sun? Where is the moon? Where is the earth? Where is the sky? Where", + ], "Output should be the same as the expected output" + + # Clean up after the test + print("Cleaning up resources for test") + cluster.shutdown() + policy.worker_group.shutdown()