diff --git a/docs/design_docs/generation.md b/docs/design_docs/generation.md index e9fa3ee2ee..8dda8a028a 100644 --- a/docs/design_docs/generation.md +++ b/docs/design_docs/generation.md @@ -95,20 +95,30 @@ The {py:class}`UpdatableVllmInternalWorker Tuple[ @@ -219,6 +220,12 @@ def setup( # vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this) backend = generation_config["backend"] generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM + generation_config["vllm_cfg"]["skip_tokenizer_init"] = True + # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below. + generation_config["stop_token_ids"] = [tokenizer.eos_token_id] + generation_config["pad_token"] = tokenizer.pad_token_id + generation_config["vllm_cfg"]["load_format"] = "dummy" + if backend == "hf": policy_generation = None print(f" ✓ Using HF backend for generation with {policy_config['model_name']}") diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py index 8ffa1d2945..138b70fbc1 100644 --- a/nemo_reinforcer/models/generation/interfaces.py +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from typing import Any, TypedDict, Union, Tuple +from typing import Any, TypedDict, Union, Tuple, List import torch from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -101,6 +101,8 @@ class GenerationConfig(TypedDict): top_p: float top_k: int model_name: str + stop_token_ids: List[int] + pad_token: int class GenerationDatumSpec(TypedDict): diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index ebbe53cff5..cb60b7fe8c 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -39,6 +39,9 @@ class VllmSpecificArgs(TypedDict): tensor_parallel_size: int gpu_memory_utilization: float max_model_len: int + # Additional arguments for vLLM inserted by reinforcer based on the context of when vllm is used + skip_tokenizer_init: bool + load_format: str class VllmConfig(GenerationConfig): @@ -110,6 +113,7 @@ def __init__( Only needed for the first worker in each tied worker group. """ self.cfg = config + self.model_name = self.cfg["model_name"] self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"] self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"] @@ -166,9 +170,11 @@ def __init__( self.llm = LLM( model=self.model_name, - load_format="dummy", - tensor_parallel_size=self.tensor_parallel_size, - gpu_memory_utilization=self.gpu_memory_utilization, + # Training pipeline will set this to "dummy" and eval will load real weights using 'auto' + load_format=self.cfg["vllm_cfg"]["load_format"], + skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"], + tensor_parallel_size=self.cfg["vllm_cfg"]["tensor_parallel_size"], + gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"], enable_prefix_caching=True, dtype="auto", enforce_eager=True, @@ -176,13 +182,10 @@ def __init__( trust_remote_code=True, worker_cls=UpdatableVllmInternalWorker, enable_sleep_mode=True, + disable_log_stats=True, **vllm_kwargs, ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - def llm(self): return self.llm @@ -213,7 +216,7 @@ def generate( f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" ) is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.tokenizer.pad_token_id + data, pad_value=self.cfg["pad_token"] ) if not is_right_padded: warnings.warn( @@ -251,6 +254,7 @@ def generate( max_tokens=self.cfg["max_new_tokens"], logprobs=0, # Return logprobs for the generated tokens stop=None, + stop_token_ids=self.cfg["stop_token_ids"], ) # Generate outputs @@ -276,7 +280,7 @@ def generate( # Create a new tensor with the right size and fill with padding token full_output = torch.full( - (total_length,), self.tokenizer.pad_token_id, dtype=input_ids.dtype + (total_length,), self.cfg["pad_token"], dtype=input_ids.dtype ) # Copy original input (with padding) into the beginning @@ -402,6 +406,17 @@ def __init__( """Initialize a vLLM policy with distributed workers.""" # Store config self.cfg = config + # Ensure all required VllmConfig fields are present + missing_keys = [ + key for key in VllmConfig.__annotations__ if key not in self.cfg + ] + assert not missing_keys, ( + f"VLLM Configuration Error: Missing required keys in VllmConfig.\n" + f"Missing keys: {', '.join(missing_keys)}\n" + f"Provided keys: {', '.join(self.cfg.keys())}\n" + f"Please update your configuration to include all required VLLM parameters." + ) + self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"] # Create worker builder for VllmGenerationWorker diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 9c89046af2..af8b5698e3 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -25,6 +25,7 @@ # Define basic vLLM test config basic_vllm_test_config: VllmConfig = { + "backend": "vllm", "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing "dtype": "bfloat16", "max_new_tokens": 10, @@ -39,6 +40,15 @@ } +def configure_vllm_with_tokenizer(vllm_config, tokenizer): + """Apply tokenizer-specific configurations to vLLM config.""" + vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True + vllm_config["vllm_cfg"]["load_format"] = "dummy" + vllm_config["pad_token"] = tokenizer.pad_token_id + vllm_config["stop_token_ids"] = [tokenizer.eos_token_id] + return vllm_config + + @pytest.fixture(scope="module") def check_vllm_available(): """Skip tests if vLLM is not installed.""" @@ -74,9 +84,12 @@ def tokenizer(): @pytest.fixture(scope="function") -def policy(cluster, check_vllm_available): +def policy(cluster, tokenizer, check_vllm_available): """Initialize the vLLM policy.""" - policy = VllmGeneration(cluster, basic_vllm_test_config) + # Create separate configs for each policy + vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer) + policy = VllmGeneration(cluster, vllm_config) yield policy # Ensure policy is properly shutdown @@ -121,6 +134,30 @@ def test_input_data(tokenizer): ) +def test_vllm_missing_required_config_key(cluster, check_vllm_available): + """Test that an assertion error is raised when a required config key is missing.""" + # Create a config missing a required key by removing 'model_name' + incomplete_config = basic_vllm_test_config.copy() + del incomplete_config["model_name"] # Remove a required key + + # Also need to ensure skip_tokenizer_init and load_format are there + # since these are checked in VllmConfig.__annotations__ + incomplete_config["skip_tokenizer_init"] = True + incomplete_config["load_format"] = "auto" + + # Attempt to initialize VllmGeneration with incomplete config - should raise AssertionError + with pytest.raises(AssertionError) as excinfo: + VllmGeneration(cluster, incomplete_config) + + # Verify the error message contains information about the missing key + error_message = str(excinfo.value) + assert "Missing required keys in VllmConfig" in error_message + assert "model_name" in error_message, ( + "Error should mention the missing 'model_name' key" + ) + print(f"Successfully caught missing config key with error: {error_message}") + + def test_vllm_policy_generation(policy, test_input_data, tokenizer): """Test vLLM policy generation capabilities.""" # Test generation @@ -171,6 +208,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Create separate configs for each policy vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer) # Create HF-specific config with required parameters hf_config = { @@ -359,6 +397,7 @@ def test_vllm_policy_tensor_parallel(cluster, tokenizer): """Test vLLM policy with tensor parallelism > 1.""" # Configure with tensor_parallel_size=2 tp_config = basic_vllm_test_config.copy() + tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer) tp_config["tensor_parallel_size"] = 2 # Ensure we specify the distributed executor backend @@ -420,6 +459,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size): # Create separate configs for each policy vllm_config = basic_vllm_test_config.copy() + vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer) vllm_config["tensor_parallel_size"] = tensor_parallel_size # Add vllm_kwargs only if using tensor parallelism