14 changes: 12 additions & 2 deletions docs/design_docs/generation.md
@@ -95,20 +95,30 @@ The {py:class}`UpdatableVllmInternalWorker <nemo_reinforcer.models.generation.vl
To use a generation backend:

```python
from transformers import AutoTokenizer

from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig
from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict

# Set up the configuration
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

config = VllmConfig(
    backend="vllm",
    model_name="Qwen/Qwen2.5-1.5B",
    max_new_tokens=100,
    temperature=0.7,
    top_p=1,
    top_k=None,
    stop_token_ids=[tokenizer.eos_token_id],
    pad_token=tokenizer.pad_token_id,
    skip_tokenizer_init=True,
    vllm_cfg={
        "tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.8,
        "max_model_len": 2048,
    },
)
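
# A minimal sketch of wiring the config above to a generation backend.
# The RayVirtualCluster arguments are omitted as a placeholder, and the
# BatchedDataDict construction is an assumption for illustration only;
# check the cluster and data-dict docs for the exact APIs.
cluster = RayVirtualCluster(...)  # placeholder: constructor arguments depend on your cluster setup
generation = VllmGeneration(cluster, config)

# generate() expects a right-padded batch carrying "input_ids" and "input_lengths".
encoded = tokenizer(["What is 2 + 2?"], return_tensors="pt", padding=True)
data = BatchedDataDict(
    {
        "input_ids": encoded["input_ids"],
        "input_lengths": encoded["attention_mask"].sum(dim=1),
    }
)
outputs = generation.generate(data)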

4 changes: 3 additions & 1 deletion examples/run_grpo_math.py
@@ -188,6 +188,8 @@ def setup_data(data_config: DataConfig, policy_config: PolicyConfig, env_configs
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")

tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

task_data_processors = defaultdict(
lambda: (math_task_spec, openinstructmath2_data_processor)
@@ -270,7 +272,7 @@ def main():
checkpointer,
grpo_state,
master_config,
) = setup(config, dataset, val_dataset)
) = setup(config, tokenizer, dataset, val_dataset)
grpo_train(
policy,
policy_generation,
7 changes: 7 additions & 0 deletions nemo_reinforcer/algorithms/grpo.py
@@ -109,6 +109,7 @@ class MasterConfig(TypedDict):

def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
dataset: AllTaskProcessedDataset,
val_dataset: Optional[AllTaskProcessedDataset],
) -> Tuple[
@@ -219,6 +220,12 @@ def setup(
# vllm model loading prefers clean environment, initialize policy_generation before policy (#52 will fix this)
backend = generation_config["backend"]
generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM
generation_config["vllm_cfg"]["skip_tokenizer_init"] = True
# When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below.
generation_config["stop_token_ids"] = [tokenizer.eos_token_id]
generation_config["pad_token"] = tokenizer.pad_token_id
generation_config["vllm_cfg"]["load_format"] = "dummy"

if backend == "hf":
policy_generation = None
print(f" ✓ Using HF backend for generation with {policy_config['model_name']}")
4 changes: 3 additions & 1 deletion nemo_reinforcer/models/generation/interfaces.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, TypedDict, Union, Tuple
from typing import Any, TypedDict, Union, Tuple, List

import torch
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
@@ -101,6 +101,8 @@ class GenerationConfig(TypedDict):
top_p: float
top_k: int
model_name: str
stop_token_ids: List[int]
pad_token: int


class GenerationDatumSpec(TypedDict):
33 changes: 24 additions & 9 deletions nemo_reinforcer/models/generation/vllm.py
@@ -39,6 +39,9 @@ class VllmSpecificArgs(TypedDict):
tensor_parallel_size: int
gpu_memory_utilization: float
max_model_len: int
# Additional arguments for vLLM that Reinforcer sets based on the context in which vLLM is used
skip_tokenizer_init: bool
load_format: str


class VllmConfig(GenerationConfig):
@@ -110,6 +113,7 @@ def __init__(
Only needed for the first worker in each tied worker group.
"""
self.cfg = config

self.model_name = self.cfg["model_name"]
self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"]
self.gpu_memory_utilization = self.cfg["vllm_cfg"]["gpu_memory_utilization"]
@@ -166,23 +170,22 @@ def __init__(

self.llm = LLM(
model=self.model_name,
load_format="dummy",
tensor_parallel_size=self.tensor_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
# Training pipeline will set this to "dummy" and eval will load real weights using 'auto'
load_format=self.cfg["vllm_cfg"]["load_format"],
skip_tokenizer_init=self.cfg["vllm_cfg"]["skip_tokenizer_init"],
tensor_parallel_size=self.cfg["vllm_cfg"]["tensor_parallel_size"],
gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"],
enable_prefix_caching=True,
dtype="auto",
enforce_eager=True,
max_model_len=self.cfg["vllm_cfg"]["max_model_len"],
trust_remote_code=True,
worker_cls=UpdatableVllmInternalWorker,
enable_sleep_mode=True,
disable_log_stats=True,
**vllm_kwargs,
)

self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

def llm(self):
return self.llm

@@ -213,7 +216,7 @@ def generate(
f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}"
)
is_right_padded, error_msg = verify_right_padding(
data, pad_value=self.tokenizer.pad_token_id
data, pad_value=self.cfg["pad_token"]
)
if not is_right_padded:
warnings.warn(
@@ -251,6 +254,7 @@
max_tokens=self.cfg["max_new_tokens"],
logprobs=0, # Return logprobs for the generated tokens
stop=None,
stop_token_ids=self.cfg["stop_token_ids"],
)

# Generate outputs
@@ -276,7 +280,7 @@

# Create a new tensor with the right size and fill with padding token
full_output = torch.full(
(total_length,), self.tokenizer.pad_token_id, dtype=input_ids.dtype
(total_length,), self.cfg["pad_token"], dtype=input_ids.dtype
)

# Copy original input (with padding) into the beginning
@@ -402,6 +406,17 @@ def __init__(
"""Initialize a vLLM policy with distributed workers."""
# Store config
self.cfg = config
# Ensure all required VllmConfig fields are present
missing_keys = [
key for key in VllmConfig.__annotations__ if key not in self.cfg
]
assert not missing_keys, (
f"VLLM Configuration Error: Missing required keys in VllmConfig.\n"
f"Missing keys: {', '.join(missing_keys)}\n"
f"Provided keys: {', '.join(self.cfg.keys())}\n"
f"Please update your configuration to include all required VLLM parameters."
)

self.tensor_parallel_size = self.cfg["vllm_cfg"]["tensor_parallel_size"]

# Create worker builder for VllmGenerationWorker
44 changes: 42 additions & 2 deletions tests/unit/models/generation/test_vllm_generation.py
@@ -25,6 +25,7 @@

# Define basic vLLM test config
basic_vllm_test_config: VllmConfig = {
"backend": "vllm",
"model_name": "meta-llama/Llama-3.2-1B", # Small model for testing
"dtype": "bfloat16",
"max_new_tokens": 10,
@@ -39,6 +40,15 @@
}


def configure_vllm_with_tokenizer(vllm_config, tokenizer):
"""Apply tokenizer-specific configurations to vLLM config."""
vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True
vllm_config["vllm_cfg"]["load_format"] = "dummy"
vllm_config["pad_token"] = tokenizer.pad_token_id
vllm_config["stop_token_ids"] = [tokenizer.eos_token_id]
return vllm_config


@pytest.fixture(scope="module")
def check_vllm_available():
"""Skip tests if vLLM is not installed."""
@@ -74,9 +84,12 @@ def tokenizer():


@pytest.fixture(scope="function")
def policy(cluster, check_vllm_available):
def policy(cluster, tokenizer, check_vllm_available):
"""Initialize the vLLM policy."""
policy = VllmGeneration(cluster, basic_vllm_test_config)
# Create separate configs for each policy
vllm_config = basic_vllm_test_config.copy()
vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
policy = VllmGeneration(cluster, vllm_config)
yield policy

# Ensure policy is properly shutdown
@@ -121,6 +134,30 @@ def test_input_data(tokenizer):
)


def test_vllm_missing_required_config_key(cluster, check_vllm_available):
"""Test that an assertion error is raised when a required config key is missing."""
# Create a config missing a required key by removing 'model_name'
incomplete_config = basic_vllm_test_config.copy()
del incomplete_config["model_name"] # Remove a required key

# Also need to ensure skip_tokenizer_init and load_format are there
# since these are checked in VllmConfig.__annotations__
incomplete_config["skip_tokenizer_init"] = True
incomplete_config["load_format"] = "auto"

# Attempt to initialize VllmGeneration with incomplete config - should raise AssertionError
with pytest.raises(AssertionError) as excinfo:
VllmGeneration(cluster, incomplete_config)

# Verify the error message contains information about the missing key
error_message = str(excinfo.value)
assert "Missing required keys in VllmConfig" in error_message
assert "model_name" in error_message, (
"Error should mention the missing 'model_name' key"
)
print(f"Successfully caught missing config key with error: {error_message}")


def test_vllm_policy_generation(policy, test_input_data, tokenizer):
"""Test vLLM policy generation capabilities."""
# Test generation
@@ -171,6 +208,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer):

# Create separate configs for each policy
vllm_config = basic_vllm_test_config.copy()
vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)

# Create HF-specific config with required parameters
hf_config = {
@@ -359,6 +397,7 @@ def test_vllm_policy_tensor_parallel(cluster, tokenizer):
"""Test vLLM policy with tensor parallelism > 1."""
# Configure with tensor_parallel_size=2
tp_config = basic_vllm_test_config.copy()
tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer)
tp_config["tensor_parallel_size"] = 2

# Ensure we specify the distributed executor backend
@@ -420,6 +459,7 @@ def test_vllm_policy_weight_update(cluster, tokenizer, tensor_parallel_size):

# Create separate configs for each policy
vllm_config = basic_vllm_test_config.copy()
vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer)
vllm_config["tensor_parallel_size"] = tensor_parallel_size

# Add vllm_kwargs only if using tensor parallelism