From 8053a40a3a5cc94dbf640cf94ac2674590b2aedc Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 4 Apr 2025 06:46:35 -0700 Subject: [PATCH 01/11] WIP migration to v1 runtime Signed-off-by: Parth Chadha --- nemo_reinforcer/models/generation/vllm.py | 19 +++++++++------- .../models/generation/vllm_backend.py | 20 ++++++++--------- pyproject.toml | 4 ++-- uv.lock | 22 +++++++++++++++---- 4 files changed, 41 insertions(+), 24 deletions(-) diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 4ffbb3e2ff..f69f258327 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -95,7 +95,8 @@ def configure_worker( init_kwargs["fraction_of_gpus"] = num_gpus # Force vllm to use v0 runtime (will be enabled by default in #51) - env_vars["VLLM_USE_V1"] = "0" + # env_vars["VLLM_USE_V1"] = "0" + env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" return resources, env_vars, init_kwargs @@ -134,12 +135,13 @@ def __init__( self.world_size = 1 try: - from vllm import LLM, SamplingParams - from nemo_reinforcer.models.generation.vllm_backend import ( - UpdatableVllmInternalWorker, - ) + import vllm + # from vllm import LLM, SamplingParams + # from nemo_reinforcer.models.generation.vllm_backend import ( + # UpdatableVllmInternalWorker, + # ) - self.SamplingParams = SamplingParams + self.SamplingParams = vllm.SamplingParams except ImportError: raise ImportError( "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " @@ -168,7 +170,7 @@ def __init__( # For non-TP mode, explicitly set executor to None to avoid Ray issues vllm_kwargs["distributed_executor_backend"] = None - self.llm = LLM( + self.llm = vllm.LLM( model=self.model_name, # Training pipeline will set this to "dummy" and eval will load real weights using 'auto' load_format=self.cfg["vllm_cfg"]["load_format"], @@ -181,7 +183,8 @@ def __init__( enforce_eager=False, max_model_len=self.cfg["vllm_cfg"]["max_model_len"], trust_remote_code=True, - worker_cls=UpdatableVllmInternalWorker, + worker_extension_cls="nemo_reinforcer.models.generation.vllm_backend.VllmInternalWorkerExtension", + # worker_cls=UpdatableVllmInternalWorker, enable_sleep_mode=True, disable_log_stats=True, **vllm_kwargs, diff --git a/nemo_reinforcer/models/generation/vllm_backend.py b/nemo_reinforcer/models/generation/vllm_backend.py index 09e94f2815..6114194bc8 100644 --- a/nemo_reinforcer/models/generation/vllm_backend.py +++ b/nemo_reinforcer/models/generation/vllm_backend.py @@ -13,17 +13,17 @@ # limitations under the License. import torch -try: - from vllm.worker.worker import Worker -except ImportError: - raise ImportError( - "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " - "or `pip install vllm` separately. This issue may also occur if worker is using incorrect " - "py_executable." - ) +# try: +# from vllm.worker.worker import Worker +# except ImportError: +# raise ImportError( +# "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " +# "or `pip install vllm` separately. This issue may also occur if worker is using incorrect " +# "py_executable." 
+# ) -class UpdatableVllmInternalWorker(Worker): +class VllmInternalWorkerExtension: def report_device_id(self) -> str: from nemo_reinforcer.utils.nvml import get_device_uuid @@ -60,6 +60,6 @@ def update_weights_from_ipc_handles(self, ipc_handles): return True except Exception as e: print( - f"Error in UpdatableVllmInternalWorker.update_weights_from_ipc_handles: {e}" + f"Error in VllmInternalWorkerExtension.update_weights_from_ipc_handles: {e}" ) return False diff --git a/pyproject.toml b/pyproject.toml index b257b7aed5..07e23c9f5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ readme = {file = "README.md", content-type = "text/markdown"} [project.optional-dependencies] vllm = [ - "vllm==0.8.0", + "vllm==0.8.2", ] [dependency-groups] @@ -108,4 +108,4 @@ convention = "google" # --link-mode=copy (slower but more reliable; supresses warning) # --link-mode=symlink (fastest option when uv cache and venv on different file-system; caveat: venv is brittle since it depends on the environment/container) # -#link-mode = "symlink" \ No newline at end of file +#link-mode = "symlink" diff --git a/uv.lock b/uv.lock index b5d6bbb4f0..d546f25e64 100644 --- a/uv.lock +++ b/uv.lock @@ -1361,6 +1361,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6d/eb/a5e8b06b924b4149cf498e1598116bad1e91ab23046c2dfc2c498154d393/latex2sympy2_extended-1.10.1-py3-none-any.whl", hash = "sha256:917a23e8f3b6edea88a56978fbbe87ed9fca4197f8277646be57b4660710347c", size = 207460 }, ] +[[package]] +name = "llguidance" +version = "0.7.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7c/4b/92f81aa9d98e2c0721e2760e0fa1ae1691380bd27f2bf530310671a777d9/llguidance-0.7.11.tar.gz", hash = "sha256:226409610f1d1e0ecd62f15d1dd47851879513eb1eb56129c56de8188b80fa8d", size = 384121 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/38/bb5e0e185f84e4702ca079b0874de88b0d1b7245c48fc6449b766bce6103/llguidance-0.7.11-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c1639466113196cf6d274461deaafbe6011b60d459f773ca97045df1ee87e195", size = 3065620 }, + { url = "https://files.pythonhosted.org/packages/c3/c3/14f1173407a0ba18e1f57d26eae4da49d6336d5e0405336b9cbcb749848b/llguidance-0.7.11-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e6899df33f3372ec86d7c1939e33891fda9e9a533dcd7f7f8c556897446765b", size = 2957459 }, + { url = "https://files.pythonhosted.org/packages/5f/07/6064f1253708c879c96ce0b74bacd7ab2845c0e8199ff13d84681a5041ad/llguidance-0.7.11-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32edcdc60922bdc97dcbae4d18e2d6dca451571959303ced7b7821dbbd344c0f", size = 13561497 }, + { url = "https://files.pythonhosted.org/packages/e1/9e/96d96fab0c27adb9f51dabc42682d12dfe4602e7637a71614b916879ae7a/llguidance-0.7.11-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b167f7d4da85747378c0c58393cd078b459a90d6e8a60e676692784a78a6f61", size = 13687114 }, + { url = "https://files.pythonhosted.org/packages/c3/72/f5ed95fd29faf6b197d6af543671306ef154741f804b197c3e3f7ad15a8b/llguidance-0.7.11-cp39-abi3-win_amd64.whl", hash = "sha256:585cb3b52a702303240ae91cc0633735dab3a1db2c062af8ffb4ef3ca4737236", size = 2611515 }, +] + [[package]] name = "llvmlite" version = "0.43.0" @@ -1823,7 +1836,7 @@ requires-dist = [ { name = "torch", specifier = "==2.6.0" }, { name = "torchdata" }, { name = "transformers" }, - { name = "vllm", marker = "extra == 'vllm'", specifier = "==0.8.0" }, + { name = "vllm", marker = 
"extra == 'vllm'", specifier = "==0.8.2" }, { name = "wandb" }, ] provides-extras = ["vllm"] @@ -4153,7 +4166,7 @@ wheels = [ [[package]] name = "vllm" -version = "0.8.0" +version = "0.8.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -4168,6 +4181,7 @@ dependencies = [ { name = "gguf" }, { name = "importlib-metadata" }, { name = "lark" }, + { name = "llguidance", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "lm-format-enforcer" }, { name = "mistral-common", extra = ["opencv"] }, { name = "msgspec" }, @@ -4205,9 +4219,9 @@ dependencies = [ { name = "xformers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'x86_64'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d2/27/73a54707964c5160067e253398cc328943e3ddbaa3099265ab593e6ec766/vllm-0.8.0.tar.gz", hash = "sha256:449e6651d30d6d5025d0d42499cf1a02d983915ef3b3670547db14a0431aa9bd", size = 6407594 } +sdist = { url = "https://files.pythonhosted.org/packages/df/4d/6b27cc14d0c35e578a743a767953500a801ba296694b7e44cca709738b41/vllm-0.8.2.tar.gz", hash = "sha256:9b337b1c4072ccb94b1bf2b716593fadbe2dcb8d091f9bcbd6b5c6d37f9842ac", size = 6450146 } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/77/7beca2061aadfdfd2d81411102e6445b459bcfedfc46671d4712de6a00fb/vllm-0.8.0-cp38-abi3-manylinux1_x86_64.whl", hash = "sha256:d3660eda448560b0ce6a1524466d7d36ec0024e772c9dbf562dbead980e7d480", size = 265290109 }, + { url = "https://files.pythonhosted.org/packages/57/49/207364110b96d76139a4e80617e5831d46884abe824941b15c8a748ca5e0/vllm-0.8.2-cp38-abi3-manylinux1_x86_64.whl", hash = "sha256:32442b686c5dad8e6ddcf5a8b0cf3f741359fed6a9e9e940009f1daf80ae15de", size = 293643693 }, ] [[package]] From f80899392f52e7035bceb0b709b27cf877de7b53 Mon Sep 17 00:00:00 2001 From: Charlie Truong Date: Fri, 4 Apr 2025 12:44:11 -0500 Subject: [PATCH 02/11] ci: Fix unit test summary (#128) Signed-off-by: Charlie Truong Co-authored-by: Terry Kong Signed-off-by: Parth Chadha --- .github/workflows/_run_test.yml | 3 +-- .github/workflows/cicd-main.yml | 4 ++-- docker/Dockerfile | 3 +++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/_run_test.yml b/.github/workflows/_run_test.yml index 4e2131629a..89827efdad 100644 --- a/.github/workflows/_run_test.yml +++ b/.github/workflows/_run_test.yml @@ -81,7 +81,6 @@ jobs: --env HF_DATASETS_CACHE=/home/TestData/reinforcer/hf_datasets_cache \ --env REINFORCER_REPO_DIR=/opt/reinforcer \ --env HF_TOKEN \ - --env GITHUB_STEP_SUMMARY \ --volume $GITHUB_ACTION_DIR:$GITHUB_ACTION_DIR \ --volume /mnt/datadrive/TestData/reinforcer/datasets:/opt/reinforcer/datasets:ro \ --volume /mnt/datadrive/TestData/reinforcer/checkpoints:/home/TestData/reinforcer/checkpoints:ro \ @@ -112,7 +111,7 @@ jobs: ${{ inputs.AFTER_SCRIPT }} RUN_TEST_EOF ) - docker exec nemo_container_${{ github.run_id }} bash -eux -o pipefail -c "$cmd" + docker exec --env GITHUB_STEP_SUMMARY nemo_container_${{ github.run_id }} bash -eux -o pipefail -c "$cmd" - name: final_script_external if: always() && inputs.FINAL_SCRIPT_EXTERNAL != ':' diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 92e55f62cd..0bfc2aae9d 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -152,12 +152,12 @@ jobs: # uv run --no-sync bash ./tests/functional/grpo.sh 
AFTER_SCRIPT: | cd /opt/reinforcer - cat <>> conda initialize >>>/,/# <<< conda initialize << Date: Tue, 8 Apr 2025 01:58:09 +0800 Subject: [PATCH 03/11] fix: fix error padding (#87) Signed-off-by: Yuki Huang Co-authored-by: Terry Kong Signed-off-by: Parth Chadha --- docs/design_docs/generation.md | 18 +++--- examples/run_eval.py | 19 ++++--- examples/run_grpo_math.py | 31 ++++++---- examples/run_sft.py | 29 ++++++---- nemo_reinforcer/algorithms/grpo.py | 5 -- nemo_reinforcer/algorithms/utils.py | 9 +++ nemo_reinforcer/evals/eval.py | 6 -- .../models/generation/interfaces.py | 29 +++++++++- nemo_reinforcer/models/generation/vllm.py | 12 ++-- nemo_reinforcer/models/policy/hf_policy.py | 25 ++++++--- .../models/generation/test_vllm_generation.py | 56 +++++++++---------- .../unit/models/policy/test_hf_ray_policy.py | 39 +++++++++---- 12 files changed, 168 insertions(+), 110 deletions(-) diff --git a/docs/design_docs/generation.md b/docs/design_docs/generation.md index 8dda8a028a..84f450c7cc 100644 --- a/docs/design_docs/generation.md +++ b/docs/design_docs/generation.md @@ -95,26 +95,20 @@ The {py:class}`UpdatableVllmInternalWorker AutoTokenizer: + """Get the tokenizer and set pad token to eos token if it is not already set.""" + tokenizer = AutoTokenizer.from_pretrained(model_name) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer diff --git a/nemo_reinforcer/evals/eval.py b/nemo_reinforcer/evals/eval.py index 33d486a4d5..a1a4cad74b 100644 --- a/nemo_reinforcer/evals/eval.py +++ b/nemo_reinforcer/evals/eval.py @@ -105,12 +105,6 @@ def setup( backend = generation_config["backend"] assert backend == "vllm", "Only vLLM backend is supported for evaluation" - # set vllm config - generation_config["vllm_cfg"]["load_format"] = "auto" - generation_config["vllm_cfg"]["skip_tokenizer_init"] = False - generation_config["stop_token_ids"] = [tokenizer.eos_token_id] - generation_config["pad_token"] = tokenizer.pad_token_id - # initialize vllm generation vllm_generation = VllmGeneration(cluster=cluster, config=generation_config) print( diff --git a/nemo_reinforcer/models/generation/interfaces.py b/nemo_reinforcer/models/generation/interfaces.py index da7e737784..f81d5d897d 100644 --- a/nemo_reinforcer/models/generation/interfaces.py +++ b/nemo_reinforcer/models/generation/interfaces.py @@ -15,6 +15,8 @@ from typing import Any, TypedDict, Union, Tuple, List import torch +from transformers import AutoTokenizer + from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict @@ -45,8 +47,8 @@ def verify_right_padding( ) assert pad_value is not None, ( - "Tokenizer does not have a pad token assigned. \n" - "If the default tokenizer does not have a pad token, you can assign it the value of eos token by tokenizer.pad_token = tokenizer.eos_token" + "Tokenizer does not have a pad_token_id. \n" + "Please use the nemo_reinforcer.algorithms.utils.get_tokenizer(...) API which sets pad_token_id if absent." ) # Determine which type of data we're dealing with @@ -107,7 +109,28 @@ class GenerationConfig(TypedDict): top_k: int model_name: str stop_token_ids: List[int] - pad_token: int + pad_token_id: int + + +def configure_generation_config( + config: GenerationConfig, tokenizer: AutoTokenizer, is_eval=False +): + """Apply specific configurations to generation config.""" + # tokenizer setting + config["pad_token_id"] = tokenizer.pad_token_id + # When https://github.com/NVIDIA/reinforcer/issues/57 is fixed, we should update stop_token_ids below. 
+ config["stop_token_ids"] = [tokenizer.eos_token_id] + + # vllm setting + if config["backend"] == "vllm": + if is_eval: + config["vllm_cfg"]["skip_tokenizer_init"] = False + config["vllm_cfg"]["load_format"] = "auto" + else: + config["vllm_cfg"]["skip_tokenizer_init"] = True + config["vllm_cfg"]["load_format"] = "dummy" + + return config class GenerationDatumSpec(TypedDict): diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index f69f258327..2f5041b8ee 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -221,7 +221,7 @@ def generate( f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" ) is_right_padded, error_msg = verify_right_padding( - data, pad_value=self.cfg["pad_token"] + data, pad_value=self.cfg["pad_token_id"] ) if not is_right_padded: warnings.warn( @@ -285,7 +285,7 @@ def generate( # Create a new tensor with the right size and fill with padding token full_output = torch.full( - (total_length,), self.cfg["pad_token"], dtype=input_ids.dtype + (total_length,), self.cfg["pad_token_id"], dtype=input_ids.dtype ) # Copy original input (with padding) into the beginning @@ -519,7 +519,9 @@ def generate( results = self.worker_group.get_all_worker_results(future_bundle) # Combine results from all tied worker groups - combined = BatchedDataDict.from_batches(results) + combined = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + ) # Verify the output has all required fields required_keys = [ @@ -560,7 +562,9 @@ def generate_text( results = self.worker_group.get_all_worker_results(future_bundle) # Combine results from all tied worker groups - combined = BatchedDataDict.from_batches(results) + combined = BatchedDataDict.from_batches( + results, pad_value_dict={"output_ids": self.cfg["pad_token_id"]} + ) # Verify the output has all required fields required_keys = ["texts"] diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index 3a316ba3ae..c36bc0fec7 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -28,9 +28,10 @@ StateDictType, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoModelForCausalLM from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup @@ -97,10 +98,7 @@ def __init__( ) else: self.reference_model = None - self.tokenizer = AutoTokenizer.from_pretrained(model_name) - # If no pad token is defined, you might need: - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer = get_tokenizer(model_name) # ------------------------------------------------ # 3) Move to GPU + Composable FSDP @@ -519,7 +517,9 @@ def generate( batch_size, seq_len = input_ids.shape # Convert right padding to left padding - left_padded_input_ids = torch.zeros_like(input_ids) + left_padded_input_ids = torch.full_like( + input_ids, gen_cfg["pad_token_id"] + ) left_padded_attention_mask = torch.zeros( (batch_size, seq_len), dtype=torch.long, 
device=input_ids.device ) @@ -569,7 +569,12 @@ def generate( micro_batches.append(mb) # Get lengths, pad, and concatenate all batches - return_data = BatchedDataDict.from_batches(micro_batches) + return_data = BatchedDataDict.from_batches( + micro_batches, + pad_value_dict={ + "left_padded_output_ids": self.cfg["generation"]["pad_token_id"] + }, + ) # Calculate the lengths of generations for each sequence by finding stop tokens generation_lengths = [] @@ -581,8 +586,9 @@ def generate( max_seq_len = max( [seq.size(0) for seq in return_data["left_padded_output_ids"]] ) - right_padded_output_ids = torch.zeros( + right_padded_output_ids = torch.full( (batch_size, max_seq_len), + self.cfg["generation"]["pad_token_id"], dtype=return_data["left_padded_output_ids"][0].dtype, device=return_data["left_padded_output_ids"][0].device, ) @@ -1017,7 +1023,8 @@ def generate( "generate", sharded_data, common_kwargs={"greedy": greedy} ) result = BatchedDataDict.from_batches( - self.worker_group.get_all_worker_results(futures) + self.worker_group.get_all_worker_results(futures), + pad_value_dict={"output_ids": self.cfg["generation"]["pad_token_id"]}, ) # Verify the output has all required fields diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index a5bcda1ff6..ed90267d10 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy + import pytest import torch import ray -import numpy as np - -from transformers import AutoTokenizer +from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.models.generation.interfaces import configure_generation_config from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig @@ -41,19 +42,6 @@ } -def configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=False): - """Apply tokenizer-specific configurations to vLLM config.""" - if is_eval: - vllm_config["vllm_cfg"]["skip_tokenizer_init"] = False - vllm_config["vllm_cfg"]["load_format"] = "auto" - else: - vllm_config["vllm_cfg"]["skip_tokenizer_init"] = True - vllm_config["vllm_cfg"]["load_format"] = "dummy" - vllm_config["pad_token"] = tokenizer.pad_token_id - vllm_config["stop_token_ids"] = [tokenizer.eos_token_id] - return vllm_config - - @pytest.fixture(scope="module") def check_vllm_available(): """Skip tests if vLLM is not installed.""" @@ -82,9 +70,7 @@ def cluster(): def tokenizer(): """Initialize tokenizer for the test model.""" model_name = basic_vllm_test_config["model_name"] - tokenizer = AutoTokenizer.from_pretrained(model_name) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + tokenizer = get_tokenizer(model_name) return tokenizer @@ -93,7 +79,7 @@ def policy(cluster, tokenizer, check_vllm_available): """Initialize the vLLM policy.""" # Create separate configs for each policy vllm_config = basic_vllm_test_config.copy() - vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer) + vllm_config = configure_generation_config(vllm_config, tokenizer) policy = VllmGeneration(cluster, vllm_config) yield policy @@ -213,7 +199,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Create 
separate configs for each policy vllm_config = basic_vllm_test_config.copy() - vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer) + vllm_config = configure_generation_config(vllm_config, tokenizer) # Create HF-specific config with required parameters hf_config = { @@ -252,6 +238,17 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): "Where is the sun?", ] + expected_generations = [ + "Write a story about a magical forest. The forest is magical because it is full of", + "Explain how photosynthesis works\nExplain how photosynthesis works\nPhotosynthesis", + "What are the benefits of exercise? The benefits of exercise are many and varied. It", + "Describe the water cycle in your own words.\nDescribe the water cycle in", + "What is the capital of France? A. Paris B. New York C. Washington", + "Who is the president of the USA? Who is the president of the USA? Who is", + "What is the capital of the moon? A. Houston, Texas B. New York City", + "Where is the sun? Where is the moon? Where is the earth?", + ] + # Tokenize the prompts the same way as in test_hf_ray_policy tokenized = tokenizer( prompts, @@ -286,7 +283,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Step 1: Use vLLM for generation print("Using vLLM policy for fast generation...") - generation_results = vllm_policy.generate(test_input_data) + generation_results = vllm_policy.generate(test_input_data, greedy=True) vllm_policy.finish_generation() # Validate generation outputs assert "output_ids" in generation_results, ( @@ -301,6 +298,9 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): generation_results["output_ids"], skip_special_tokens=True ) print(f"vLLM generated texts: {generated_texts}") + assert generated_texts == expected_generations, ( + "Output should be the same as the expected output" + ) # Run logprob calculation with HF policy to verify @@ -401,9 +401,9 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): def test_vllm_policy_tensor_parallel(cluster, tokenizer): """Test vLLM policy with tensor parallelism > 1.""" # Configure with tensor_parallel_size=2 - tp_config = basic_vllm_test_config.copy() - tp_config = configure_vllm_with_tokenizer(tp_config, tokenizer) - tp_config["tensor_parallel_size"] = 2 + tp_config = deepcopy(basic_vllm_test_config) + tp_config = configure_generation_config(tp_config, tokenizer) + tp_config["vllm_cfg"]["tensor_parallel_size"] = 2 # Ensure we specify the distributed executor backend tp_config["vllm_kwargs"] = {"distributed_executor_backend": "ray"} @@ -466,7 +466,7 @@ def test_vllm_generate_text(cluster, tokenizer): # Create separate configs for each policy vllm_config = basic_vllm_test_config.copy() - vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=True) + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) # Ensure we can get same output assert vllm_config["model_name"] == "meta-llama/Llama-3.2-1B", ( @@ -499,8 +499,8 @@ def test_vllm_weight_update_and_prefix_cache_reset( from nemo_reinforcer.models.policy.hf_policy import HfPolicy # Create configs - vllm_config = basic_vllm_test_config.copy() - vllm_config = configure_vllm_with_tokenizer(vllm_config, tokenizer, is_eval=True) + vllm_config = deepcopy(basic_vllm_test_config) + vllm_config = configure_generation_config(vllm_config, tokenizer, is_eval=True) vllm_config["vllm_cfg"]["tensor_parallel_size"] = tensor_parallel_size if tensor_parallel_size > 1: vllm_config["vllm_kwargs"] = 
{"distributed_executor_backend": "ray"} diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index ded244feac..76926960cf 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -16,13 +16,14 @@ import pprint import torch -from nemo_reinforcer.models.policy import PolicyConfig -from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.algorithms.utils import get_tokenizer from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict -from nemo_reinforcer.algorithms.interfaces import LossFunction +from nemo_reinforcer.models.generation.interfaces import configure_generation_config +from nemo_reinforcer.models.policy import PolicyConfig +from nemo_reinforcer.models.policy.hf_policy import HfPolicy from tests.unit.test_utils import simple_loss, nll_loss -from transformers import AutoTokenizer basic_llama_test_config: PolicyConfig = { @@ -66,8 +67,16 @@ def gc_collect(): gc.collect() +@pytest.fixture(scope="function") +def tokenizer(): + """Initialize tokenizer for the test model.""" + model_name = basic_llama_test_config["model_name"] + tokenizer = get_tokenizer(model_name) + return tokenizer + + @pytest.fixture -def policy_setup(): +def policy_setup(tokenizer): """Setup and teardown for policy tests - creates a virtual cluster and policy.""" policy = None cluster = None @@ -84,6 +93,7 @@ def policy_setup(): ) config = basic_llama_test_config + config["generation"] = configure_generation_config(config["generation"], tokenizer) print("Creating HfPolicy...") policy = HfPolicy(cluster=cluster, config=config) @@ -278,7 +288,7 @@ def verify_loss_tensor(loss_tensor): @pytest.fixture -def generation_setup(request): +def generation_setup(request, tokenizer): """Setup and teardown specifically for generation tests.""" policy = None cluster = None @@ -298,6 +308,9 @@ def generation_setup(request): ) config = basic_llama_test_config + config["generation"] = configure_generation_config( + config["generation"], tokenizer + ) print("Creating generation HfPolicy...") policy = HfPolicy( @@ -331,8 +344,6 @@ def generation_setup(request): ] # Tokenize the prompts - tokenizer = AutoTokenizer.from_pretrained(config["model_name"]) - tokenizer.pad_token = tokenizer.eos_token tokenized = tokenizer( prompts, padding=True, @@ -353,7 +364,7 @@ def generation_setup(request): ) # Provide the resources to the test - yield policy, cluster, data, tokenizer, prompts, expected_generations + yield policy, cluster, data, prompts, expected_generations except Exception as e: print(f"Error during generation setup: {e}") @@ -367,8 +378,8 @@ def generation_setup(request): @pytest.mark.timeout(180) @pytest.mark.parametrize("generation_setup", [False], indirect=True) -def test_hf_policy_generation(generation_setup, tracker): - policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup +def test_hf_policy_generation(generation_setup, tokenizer, tracker): + policy, cluster, data, prompts, expected_generations = generation_setup # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" @@ -386,6 +397,10 @@ def test_hf_policy_generation(generation_setup, tracker): # Verify results assert "output_ids" in results, "Generation results should contain 'output_ids'" output_ids = 
results["output_ids"] + generated_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True) + assert generated_texts == expected_generations, ( + "Output should be the same as the expected output" + ) # run logprob calculation manually to verify fprop_logprob_data = BatchedDataDict( @@ -455,7 +470,7 @@ def test_hf_policy_generation(generation_setup, tracker): @pytest.mark.timeout(180) @pytest.mark.parametrize("generation_setup", [True], indirect=True) def test_all_hf_policy_generation_lps_ref_training(generation_setup): - policy, cluster, data, tokenizer, prompts, expected_generations = generation_setup + policy, cluster, data, prompts, expected_generations = generation_setup # Verify resources were created properly assert policy is not None, "Generation policy was not created properly" From 01e674f4eb35bd84f60e7e4cd8f1abfcd8149b70 Mon Sep 17 00:00:00 2001 From: Anna Shors Date: Mon, 7 Apr 2025 15:14:39 -0700 Subject: [PATCH 04/11] feat: Distributed checkpointing (#99) Signed-off-by: ashors1 Signed-off-by: Anna Shors Signed-off-by: Parth Chadha --- docs/design_docs/checkpointing.md | 22 ++ docs/index.md | 1 + examples/configs/grpo_math_1B.yaml | 1 + examples/configs/grpo_math_8B.yaml | 1 + examples/configs/sft.yaml | 1 + examples/convert_dcp_to_hf.py | 92 ++++++ nemo_reinforcer/algorithms/grpo.py | 18 +- nemo_reinforcer/algorithms/sft.py | 27 +- nemo_reinforcer/models/policy/__init__.py | 1 + nemo_reinforcer/models/policy/hf_policy.py | 132 ++++---- nemo_reinforcer/utils/native_checkpoint.py | 204 +++++++++++++ tests/functional/sft.sh | 7 +- .../models/generation/test_vllm_generation.py | 3 + .../unit/models/policy/test_hf_ray_policy.py | 1 + tests/unit/utils/test_native_checkpoint.py | 288 ++++++++++++++++++ 15 files changed, 707 insertions(+), 92 deletions(-) create mode 100644 docs/design_docs/checkpointing.md create mode 100644 examples/convert_dcp_to_hf.py create mode 100644 nemo_reinforcer/utils/native_checkpoint.py create mode 100755 tests/unit/utils/test_native_checkpoint.py diff --git a/docs/design_docs/checkpointing.md b/docs/design_docs/checkpointing.md new file mode 100644 index 0000000000..9b9a6f6826 --- /dev/null +++ b/docs/design_docs/checkpointing.md @@ -0,0 +1,22 @@ +# Checkpointing with HuggingFace Models + +## Checkpoint Format +Reinforcer provides two checkpoint formats for HuggingFace models: Torch distributed and HuggingFace format. Torch distributed is used by default for efficiency, and HuggingFace format is provided for compatibility with HuggingFace's `AutoModel.from_pretrained` API. Note that HuggingFace format checkpoints save only the model weights, ignoring the optimizer states. It is recommended to use Torch distributed format to save intermediate checkpoints and to save a HuggingFace checkpoint only at the end of training. + +There are two ways to get a Reinforcer checkpoint in HuggingFace format. + +1. (Recommended) Save the HuggingFace checkpoint directly by passing `save_hf=True` to `HFPolicy`'s `save_checkpoint`: + + ```python + policy.save_checkpoint( + weights_path=, + optimizer_path=, + save_torch_dist=True, + save_hf=True, + ) + ``` +2. Convert a Torch distributed checkpoint checkpoint to HuggingFace format after training. We provide a conversion script for this purpose. 
+ + ```python + uv run examples/convert_dcp_to_hf.py --config=<path_to_config.json> --dcp-ckpt-path=<dcp_checkpoint_path> --hf-ckpt-path=<hf_checkpoint_path> + ``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 0628f19953..553778ff98 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,4 +47,5 @@ design_docs/logger.md design_docs/uv.md design_docs/chat_datasets.md design_docs/generation.md +design_docs/checkpointing.md ``` diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 422e869f56..3d8fdfce43 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -25,6 +25,7 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B-Instruct" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 4 generation_batch_size: 32 # Only used when generating using HF backend diff --git a/examples/configs/grpo_math_8B.yaml b/examples/configs/grpo_math_8B.yaml index 261db927b1..f2e0576fbc 100644 --- a/examples/configs/grpo_math_8B.yaml +++ b/examples/configs/grpo_math_8B.yaml @@ -7,6 +7,7 @@ grpo: policy: model_name: "meta-llama/Llama-3.1-8B-Instruct" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 512 train_micro_batch_size: 1 generation_batch_size: 32 # Only used when generating using HF backend diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index e4b116a351..bb8467165f 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -18,6 +18,7 @@ checkpointing: policy: model_name: "meta-llama/Llama-3.2-1B" + tokenizer_name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default train_global_batch_size: 32 train_micro_batch_size: 1 max_total_sequence_length: 1024 diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py new file mode 100644 index 0000000000..ee347eeb9e --- /dev/null +++ b/examples/convert_dcp_to_hf.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import os +import json + +from nemo_reinforcer.distributed.virtual_cluster import init_ray, RayVirtualCluster +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from nemo_reinforcer.utils.config import load_config + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Convert Torch DCP checkpoint to HF checkpoint" + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to config.json file in the checkpoint directory", + ) + parser.add_argument( + "--dcp-ckpt-path", type=str, default=None, help="Path to DCP checkpoint" + ) + parser.add_argument( + "--hf-ckpt-path", type=str, default=None, help="Path to save HF checkpoint" + ) + # Parse known args for the script + args = parser.parse_args() + + return args + + +def main(): + """Main entry point.""" + args = parse_args() + + with open(args.config, "r") as f: + config = json.load(f) + + dcp_ckpt = args.dcp_ckpt_path + hf_ckpt = args.hf_ckpt_path + + # Extract individual configs for easier access + policy_config = config["policy"] + cluster_config = config["cluster"] + + init_ray() + + cluster = RayVirtualCluster( + name="convert_cluster", + bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] + * cluster_config["num_nodes"], + use_gpus=True, + num_gpus_per_node=cluster_config["gpus_per_node"], + max_colocated_worker_groups=1, + ) + + policy = HfPolicy( + cluster=cluster, + config=policy_config, + weights_path=dcp_ckpt, + init_optimizer=False, + ) + + policy.save_checkpoint( + weights_path=os.path.abspath(hf_ckpt), + save_hf=True, + save_torch_dist=False, + ) + + print(f"Saved HF checkpoint to: {hf_ckpt}-hf") + + cluster.shutdown() + policy.worker_group.shutdown() + + +if __name__ == "__main__": + main() diff --git a/nemo_reinforcer/algorithms/grpo.py b/nemo_reinforcer/algorithms/grpo.py index 6ea5a56dd6..0eda853375 100644 --- a/nemo_reinforcer/algorithms/grpo.py +++ b/nemo_reinforcer/algorithms/grpo.py @@ -236,10 +236,10 @@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - weights_path=Path(last_checkpoint_path) / "policy.pt" + weights_path=Path(last_checkpoint_path) / "policy" / "weights" if last_checkpoint_path else None, - optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" if last_checkpoint_path else None, init_optimizer=True, @@ -608,6 +608,13 @@ def grpo_train( and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed policy.prepare_for_training() + + is_last_checkpoint = ( + min(len(dataloader), master_config["grpo"]["max_num_steps"]) + - (step + 1) + < master_config["checkpointing"]["save_period"] + ) + grpo_save_state["step"] = step + 1 grpo_save_state["val_reward"] = val_metrics["accuracy"] grpo_save_state["consumed_samples"] = consumed_samples @@ -617,8 +624,11 @@ def grpo_train( step + 1, grpo_save_state, master_config ) policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), + weights_path=os.path.join(checkpoint_path, "policy", "weights"), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + save_hf=is_last_checkpoint, ) torch.save( dataloader.state_dict(), diff --git a/nemo_reinforcer/algorithms/sft.py b/nemo_reinforcer/algorithms/sft.py index 8f9e34f9da..b5bb41aec5 100644 --- a/nemo_reinforcer/algorithms/sft.py +++ b/nemo_reinforcer/algorithms/sft.py @@ -175,10 +175,10 
@@ def setup( policy = HfPolicy( cluster=cluster, config=policy_config, - weights_path=Path(last_checkpoint_path) / "policy.pt" + weights_path=Path(last_checkpoint_path) / "policy" / "weights" if last_checkpoint_path else None, - optimizer_path=Path(last_checkpoint_path) / "policy_optimizer.pt" + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" if last_checkpoint_path else None, init_optimizer=True, @@ -311,9 +311,7 @@ def sft_train( sft_save_state = _default_sft_save_state() step = 0 else: - step = ( - sft_save_state["step"] + 1 - ) # N+1 because the checkpoint is _after_ SFT iteration N + step = sft_save_state["step"] sft_config = master_config["sft"] # Validation configuration @@ -399,19 +397,26 @@ def sft_train( master_config["checkpointing"]["enabled"] and (step + 1) % master_config["checkpointing"]["save_period"] == 0 ): # +1 because step is 0-indexed - sft_save_state["step"] = step + is_last_checkpoint = ( + min(len(train_dataloader), master_config["sft"]["max_num_steps"]) + - (step + 1) + < master_config["checkpointing"]["save_period"] + ) + + sft_save_state["step"] = step + 1 sft_save_state["val_loss"] = val_metrics["val_loss"] with timer.time("checkpointing"): print(f"Saving checkpoint for step {step + 1}...") checkpoint_path = checkpointer.init_tmp_checkpoint( step + 1, sft_save_state, master_config ) + policy.save_checkpoint( - os.path.join(checkpoint_path, "policy.pt"), - os.path.join(checkpoint_path, "policy_optimizer.pt"), - ## NOTE: below is a workaround to avoid a bug with checkpointing - ## this should be removed once the bug is fixed - offload_to_cpu=False, + weights_path=os.path.join(checkpoint_path, "policy", "weights"), + optimizer_path=os.path.join( + checkpoint_path, "policy", "optimizer" + ), + save_hf=is_last_checkpoint, ) torch.save( train_dataloader.state_dict(), diff --git a/nemo_reinforcer/models/policy/__init__.py b/nemo_reinforcer/models/policy/__init__.py index ee2bf2389e..24390b9670 100644 --- a/nemo_reinforcer/models/policy/__init__.py +++ b/nemo_reinforcer/models/policy/__init__.py @@ -19,6 +19,7 @@ class PolicyConfig(TypedDict): model_name: str + tokenizer_name: str train_global_batch_size: int train_micro_batch_size: int learning_rate: float diff --git a/nemo_reinforcer/models/policy/hf_policy.py b/nemo_reinforcer/models/policy/hf_policy.py index c36bc0fec7..051e56e23f 100644 --- a/nemo_reinforcer/models/policy/hf_policy.py +++ b/nemo_reinforcer/models/policy/hf_policy.py @@ -23,9 +23,7 @@ from torch.distributed.device_mesh import init_device_mesh from torch.distributed.fsdp import ( FullyShardedDataParallel, - FullStateDictConfig, MixedPrecision, - StateDictType, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy from transformers import AutoModelForCausalLM @@ -47,6 +45,10 @@ from nemo_reinforcer.distributed.virtual_cluster import ( PY_EXECUTABLES, ) +from nemo_reinforcer.utils.native_checkpoint import ( + save_checkpoint, + load_checkpoint, +) @ray.remote @@ -77,6 +79,7 @@ def __init__( rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() model_name = self.cfg["model_name"] + tokenizer_name = self.cfg["tokenizer_name"] if self.cfg["precision"] == "float32": self.dtype = torch.float32 elif self.cfg["precision"] == "bfloat16": @@ -98,7 +101,7 @@ def __init__( ) else: self.reference_model = None - self.tokenizer = get_tokenizer(model_name) + self.tokenizer = get_tokenizer(tokenizer_name) # ------------------------------------------------ # 3) Move to GPU + Composable FSDP @@ -139,7 
+142,7 @@ def do_fsdp(model): else: self.optimizer = None - if "scheduler" in self.cfg: + if "scheduler" in self.cfg and self.optimizer is not None: if isinstance(self.cfg["scheduler"], dict): scheduler_cls = import_class_from_path(self.cfg["scheduler"]["name"]) self.scheduler = scheduler_cls( @@ -165,7 +168,7 @@ def do_fsdp(model): self.optimizer, schedulers, milestones ) - else: + elif self.optimizer is not None: ## default to a passthrough LR schedule self.scheduler = torch.optim.lr_scheduler.LambdaLR( self.optimizer, lr_lambda=lambda epoch: 1 @@ -173,7 +176,10 @@ def do_fsdp(model): # restore if weights_path: - self.load_checkpoint(weights_path, optimizer_path) + self.load_checkpoint( + weights_path, + optimizer_path, + ) else: print( "No weights path provided. Starting from scratch (default policy init)" @@ -817,78 +823,50 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, - offload_to_cpu: bool = True, + save_torch_dist: bool = True, + save_hf: bool = False, ): - # Config to save full state dict on rank 0, offloaded to CPU - state_dict_config = FullStateDictConfig( - offload_to_cpu=offload_to_cpu, rank0_only=True + """Save a checkpoint of the model. + + The checkpoint is saved in the following format: + + weights_path/ + __0_1.distcp + __1_0.distcp + ... + weights_path-hf/ + config.json + generation_config.json + model-00001-of-<num_shards>.safetensors + ... + model.safetensors.index.json + optimizer_path/ + __0_0.distcp + __1_0.distcp + ... + + the HuggingFace checkpoint is saved only if `save_hf` is True, + and the optimizer states are saved only if `optimizer` and `optimizer_path` are provided. + """ + save_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, + save_torch_dist=save_torch_dist, + save_hf=save_hf, ) - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Save model state dict - model_state_dict = self.model.state_dict() - optim_state_dict = FullyShardedDataParallel.optim_state_dict( - self.model, self.optimizer - ) - scheduler_state_dict = self.scheduler.state_dict() - - optim_and_scheduler_state_dict = { - "optimizer": optim_state_dict, - "scheduler": scheduler_state_dict, - } - - if torch.distributed.get_rank() == 0: - # check if weights_path dir exists - weights_dir = os.path.dirname(weights_path) - if not os.path.exists(weights_dir): - print( - f"Creating weights directory {weights_dir} DOESN'T EXIST SOMEHOW" - ) - os.makedirs(weights_dir) - torch.save(model_state_dict, weights_path) - if optimizer_path is not None: - torch.save(optim_and_scheduler_state_dict, optimizer_path) - def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None): - print(f"Loading Policy from {weights_path} and optimizer from {optimizer_path}") - state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - - state_dict = torch.load(weights_path) - if optimizer_path is not None: - optim_data = torch.load(optimizer_path) - optimizer_state_dict = optim_data["optimizer"] - scheduler_state_dict = optim_data.get("scheduler") - else: - optimizer_state_dict = None - scheduler_state_dict = None - with FullyShardedDataParallel.state_dict_type( - self.model, - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=state_dict_config, - ): - # Load model weights - 
self.model.load_state_dict(state_dict if state_dict else None) - - # Load optimizer state - if optimizer_state_dict is not None: - optim_state_dict = FullyShardedDataParallel.shard_full_optim_state_dict( - optimizer_state_dict, self.model - ) - if self.optimizer is not None: - self.optimizer.load_state_dict(optim_state_dict) - else: - print("WARNING: initializing without optimizer") - else: - print("WARNING: No optimizer checkpoint provided") - - if scheduler_state_dict is not None: - self.scheduler.load_state_dict(scheduler_state_dict) - else: - print("WARNING: No scheduler checkpoint provided") + """Load a checkpoint into the model.""" + load_checkpoint( + model=self.model, + weights_path=weights_path, + optimizer=self.optimizer if optimizer_path else None, + scheduler=self.scheduler if optimizer_path else None, + optimizer_path=optimizer_path, + ) def shutdown(self): """Shutdown the policy.""" @@ -1107,14 +1085,16 @@ def save_checkpoint( self, weights_path: str, optimizer_path: Optional[str] = None, - offload_to_cpu: bool = True, + save_torch_dist: bool = True, + save_hf: bool = False, ): """Save a checkpoint of the model.""" futures = self.worker_group.run_all_workers_single_data( "save_checkpoint", weights_path, optimizer_path, - offload_to_cpu=offload_to_cpu, + save_torch_dist, + save_hf, respect_tied_workers=True, ) ray.get(futures) diff --git a/nemo_reinforcer/utils/native_checkpoint.py b/nemo_reinforcer/utils/native_checkpoint.py new file mode 100644 index 0000000000..6f22ea82fd --- /dev/null +++ b/nemo_reinforcer/utils/native_checkpoint.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint management utilities for HF models.""" + +import os +from pathlib import Path +from typing import Any, Optional +import torch + +import torch.distributed.checkpoint as dcp +from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.state_dict import ( + get_model_state_dict, + set_model_state_dict, + get_optimizer_state_dict, + set_optimizer_state_dict, +) + + +## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html +class ModelState(Stateful): + """Helper class for tracking model state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model to track. + """ + + def __init__(self, model): + self.model = model + + def state_dict(self): + """Get the model's state dictionary. + + Returns: + dict: Dictionary containing the model's state dict with CPU offloading enabled. 
+ """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + model_state_dict = get_model_state_dict( + self.model, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + return model_state_dict + + def load_state_dict(self, state_dict): + """Load the state dictionary into the model. + + Args: + state_dict (dict): State dictionary to load. + """ + # sets our state dicts on the model, now that we've loaded + set_model_state_dict( + self.model, + state_dict, + ) + + +class OptimizerState(Stateful): + """Helper class for tracking optimizer state in distributed checkpointing. + + This class is compliant with the Stateful protocol, allowing DCP to automatically + call state_dict/load_state_dict as needed in the dcp.save/load APIs. + + Args: + model: The PyTorch model associated with the optimizer. + optimizer: The optimizer to track. + scheduler: Optional learning rate scheduler. + """ + + def __init__(self, model, optimizer, scheduler=None): + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + + def state_dict(self): + """Get the optimizer and scheduler state dictionaries. + + Returns: + dict: Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. + """ + # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT + optimizer_state_dict = get_optimizer_state_dict( + self.model, + self.optimizer, + options=torch.distributed.checkpoint.state_dict.StateDictOptions( + cpu_offload=True + ), + ) + + state_dict = { + "optim": optimizer_state_dict, + } + if self.scheduler is not None: + state_dict["sched"] = self.scheduler.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Load the state dictionaries into the optimizer and scheduler. + + Args: + state_dict (dict): State dictionary containing optimizer and scheduler states to load. + """ + # sets our state dicts on the optimizer, now that we've loaded + set_optimizer_state_dict( + self.model, + self.optimizer, + state_dict["optim"], + ) + + ## load the scheduler state if it exists + if "sched" in state_dict: + self.scheduler.load_state_dict(state_dict["sched"]) + + +def save_checkpoint( + model, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, + save_torch_dist: bool = True, + save_hf: bool = False, +) -> None: + """Save a checkpoint of the model and optionally optimizer state. 
+ + Args: + model: The PyTorch model to save + weights_path: Path to save model weights + optimizer: Optional optimizer to save + scheduler: Optional scheduler to save + optimizer_path: Path to save optimizer state (required if optimizer provided) + save_torch_dist: Whether to save in PyTorch distributed format + save_hf: Whether to save in HuggingFace format + """ + if save_hf: + model_state_dict = model._fsdp_wrapped_module.state_dict() + + if torch.distributed.get_rank() == 0: + # Create a new path by appending "-hf" to the weights path + hf_weights_path = f"{Path(weights_path)}-hf" + + model.save_pretrained( + hf_weights_path, + state_dict=model_state_dict, + ) + + if save_torch_dist: + model_state = {"model": ModelState(model)} + dcp.save(model_state, checkpoint_id=weights_path) + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when saving optimizer state" + ) + optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)} + dcp.save(optimizer_state, checkpoint_id=optimizer_path) + + +def load_checkpoint( + model, + weights_path: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[Any] = None, + optimizer_path: Optional[str] = None, +) -> None: + """Load model weights and optionally optimizer state. + + Args: + model: The PyTorch model whose weights to update + weights_path: Path to load model weights from + optimizer: Optional optimizer to load state into + scheduler: Optional scheduler to load state into + optimizer_path: Path to load optimizer state from (required if optimizer provided) + """ + print(f"Loading weights from {weights_path}") + model_state_dict = {"model": ModelState(model)} + dcp.load(state_dict=model_state_dict, checkpoint_id=weights_path) + + if optimizer is not None: + if optimizer_path is None: + raise ValueError( + "optimizer_path must be provided when loading optimizer state" + ) + print(f"Loading optimizer from {optimizer_path}") + optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} + dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path) diff --git a/tests/functional/sft.sh b/tests/functional/sft.sh index 82d263c9da..85282d64f4 100755 --- a/tests/functional/sft.sh +++ b/tests/functional/sft.sh @@ -1,5 +1,8 @@ #!/bin/bash +## clean up checkpoint directory on exit +trap "rm -rf /tmp/sft_checkpoints" EXIT + SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) 
# Mark the current repo as safe, since wandb fetchs metadata about the repo @@ -26,7 +29,9 @@ python -u $PROJECT_ROOT/examples/run_sft.py \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \ - checkpointing.enabled=false \ + checkpointing.enabled=true \ + checkpointing.save_every_n_steps=10 \ + checkpointing.checkpoint_dir=/tmp/sft_checkpoints \ $@ \ 2>&1 | tee $RUN_LOG diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index ed90267d10..ba1bade3fc 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -29,6 +29,7 @@ basic_vllm_test_config: VllmConfig = { "backend": "vllm", "model_name": "meta-llama/Llama-3.2-1B", # Small model for testing + "tokenizer_name": "meta-llama/Llama-3.2-1B", "dtype": "bfloat16", "max_new_tokens": 10, "temperature": 1.0, @@ -204,6 +205,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer): # Create HF-specific config with required parameters hf_config = { "model_name": basic_vllm_test_config["model_name"], + "tokenizer_name": basic_vllm_test_config["tokenizer_name"], # Required training parameters "train_global_batch_size": 4, "train_micro_batch_size": 1, @@ -507,6 +509,7 @@ def test_vllm_weight_update_and_prefix_cache_reset( hf_config = { "model_name": basic_vllm_test_config["model_name"], + "tokenizer_name": "meta-llama/Llama-3.2-1B", "train_global_batch_size": 1, "train_micro_batch_size": 1, "learning_rate": 1e-6, diff --git a/tests/unit/models/policy/test_hf_ray_policy.py b/tests/unit/models/policy/test_hf_ray_policy.py index 76926960cf..7cde591049 100644 --- a/tests/unit/models/policy/test_hf_ray_policy.py +++ b/tests/unit/models/policy/test_hf_ray_policy.py @@ -28,6 +28,7 @@ basic_llama_test_config: PolicyConfig = { "model_name": "meta-llama/Llama-3.2-1B", + "tokenizer_name": "meta-llama/Llama-3.2-1B", "generation_batch_size": 1, # Small batch size for testing "train_global_batch_size": 4, "train_micro_batch_size": 1, diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py new file mode 100755 index 0000000000..8f71badea1 --- /dev/null +++ b/tests/unit/utils/test_native_checkpoint.py @@ -0,0 +1,288 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import os +import pytest +import torch +from tempfile import TemporaryDirectory + +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.distributed.virtual_cluster import RayVirtualCluster +from nemo_reinforcer.models.policy.hf_policy import HfPolicy +from transformers import AutoTokenizer, AutoModelForCausalLM +from nemo_reinforcer.utils.native_checkpoint import ( + load_checkpoint, + save_checkpoint, + ModelState, + OptimizerState, +) + +# Define basic test config +simple_policy_config = { + "model_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", + "tokenizer_name": "meta-llama/Llama-3.2-1B", # "hf-internal-testing/tiny-random-Gemma3ForCausalLM", + "train_global_batch_size": 32, + "train_micro_batch_size": 1, + "logprob_batch_size": 1, + "max_total_sequence_length": 1024, + "precision": "float32", +} + + +@pytest.fixture +def mock_experiment(): + model = torch.nn.ModuleList( + [ + torch.nn.Linear(4, 4), + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1) + + return model, optimizer, scheduler + + +@pytest.fixture(scope="module") +def cluster(): + """Create a virtual cluster for testing.""" + # Create a cluster with 2 GPU + virtual_cluster = RayVirtualCluster( + bundle_ct_per_node_list=[2], # 1 node with 2 GPU bundle + use_gpus=True, + max_colocated_worker_groups=1, + num_gpus_per_node=2, # Use available GPUs + name="test-cluster", + ) + yield virtual_cluster + virtual_cluster.shutdown() + + +@pytest.fixture(scope="function") +def tokenizer(): + """Initialize tokenizer for the test model.""" + tokenizer_name = simple_policy_config["tokenizer_name"] + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return tokenizer + + +@pytest.fixture(scope="function") +def policy(cluster, tokenizer): + """Initialize the policy.""" + return HfPolicy( + cluster=cluster, + config=simple_policy_config, + init_optimizer=False, + init_reference_model=False, + ) + + +def get_dummy_state_dict(state_dict, dummy_dict={}): + """Recursively get the dummy state dict + by replacing tensors with random ones of the same shape. 
+ """ + for k in state_dict.keys(): + if isinstance(state_dict[k], dict): + dummy_dict[k] = get_dummy_state_dict(state_dict[k], {}) + elif isinstance(state_dict[k], torch.Tensor): + dummy_dict[k] = torch.randn(state_dict[k].shape) + else: + dummy_dict[k] = state_dict[k] + return dummy_dict + + +def check_dict_equality(dict1, dict2): + """Recursively check equality of two dictionaries""" + for k in dict1.keys(): + if isinstance(dict1[k], dict): + check_dict_equality(dict1[k], dict2[k]) + elif isinstance(dict1[k], torch.Tensor): + assert torch.allclose(dict1[k], dict2[k]) + else: + assert dict1[k] == dict2[k] + + +def test_model_state(mock_experiment): + test_model, _, _ = mock_experiment + model_state = ModelState(test_model) + state_dict = model_state.state_dict() + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + assert set(state_dict.keys()) == expected_keys + + dummy_model_state_dict = get_dummy_state_dict(state_dict, {}) + + ## update the model's state dict and verify that the model's parameters are updated + model_state.load_state_dict(dummy_model_state_dict) + new_model_state_dict = model_state.state_dict() + check_dict_equality(new_model_state_dict, dummy_model_state_dict) + + +def test_optimizer_state(mock_experiment): + test_model, optimizer, scheduler = mock_experiment + + optim_state = OptimizerState(test_model, optimizer, scheduler) + state_dict = optim_state.state_dict() + + assert set(state_dict.keys()) == {"optim", "sched"} + + ## relu has no parameters + expected_keys = { + "0.bias", + "0.weight", + "1.bias", + "1.weight", + "3.bias", + "3.weight", + } + + assert set(state_dict["optim"]["state"].keys()) == expected_keys + + dummy_state_dict = get_dummy_state_dict(state_dict, {}) + + optim_state.load_state_dict(dummy_state_dict) + new_state_dict = optim_state.state_dict() + check_dict_equality(new_state_dict, dummy_state_dict) + + +def test_save_and_load_model_only(mock_experiment): + test_model, _, _ = mock_experiment + + with TemporaryDirectory() as tmp_dir: + save_checkpoint(test_model, os.path.join(tmp_dir, "test_model_only")) + assert os.path.exists(os.path.join(tmp_dir, "test_model_only")) + assert not os.path.exists(os.path.join(tmp_dir, "test_model_only-hf")) + assert set(os.listdir(os.path.join(tmp_dir, "test_model_only"))) == { + ".metadata", + "__0_0.distcp", + } + + +def test_save_and_load_model_and_optimizer(mock_experiment): + test_model, optimizer, scheduler = mock_experiment + for _ in range(5): + scheduler.step() + + with TemporaryDirectory() as tmp_dir: + save_checkpoint( + test_model, + os.path.join(tmp_dir, "model_and_optimizer/model"), + optimizer, + scheduler, + optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"), + ) + + assert set(os.listdir(os.path.join(tmp_dir, "model_and_optimizer/model"))) == { + ".metadata", + "__0_0.distcp", + } + assert set( + os.listdir(os.path.join(tmp_dir, "model_and_optimizer/optimizer")) + ) == { + ".metadata", + "__0_0.distcp", + } + + ## modify the model, optimizer, and scheduler and verify that loading the checkpoint overrides the values + new_linear = torch.nn.Linear(4, 4) + new_linear.weight = torch.nn.Parameter(torch.ones([4, 4]).to("cuda")) + new_linear.bias = torch.nn.Parameter(torch.ones(4).to("cuda")) + new_model = torch.nn.ModuleList( + [ + new_linear, + torch.nn.LayerNorm(4), + torch.nn.ReLU(), + torch.nn.Linear(4, 1), + ] + ).to("cuda") + + new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.001) + 
new_scheduler = torch.optim.lr_scheduler.StepLR(
+            new_optimizer, step_size=4, gamma=0.2
+        )
+        load_checkpoint(
+            new_model,
+            os.path.join(tmp_dir, "model_and_optimizer/model"),
+            new_optimizer,
+            new_scheduler,
+            optimizer_path=os.path.join(tmp_dir, "model_and_optimizer/optimizer"),
+        )
+
+        assert scheduler.state_dict() == new_scheduler.state_dict()
+        check_dict_equality(new_model.state_dict(), test_model.state_dict())
+        check_dict_equality(new_optimizer.state_dict(), optimizer.state_dict())
+
+
+def test_save_and_load_hf_checkpoint(policy):
+    ## warm up with a forward pass
+    ## this is needed before saving a checkpoint because FSDP does some lazy initialization
+    input_ids = torch.randint(0, 16000, (4, 128))  # 4 sequences, each of length 128
+    attention_mask = torch.ones(4, 128)
+    input_lengths = attention_mask.sum(dim=1).to(torch.int32)
+    dummy_fwd_dict = BatchedDataDict(
+        {
+            "input_ids": input_ids,
+            "input_lengths": input_lengths,
+            "attention_mask": attention_mask,
+            "labels": torch.randint(0, 16000, (4, 128)),
+        }
+    )
+    policy.get_logprobs(dummy_fwd_dict)
+
+    with TemporaryDirectory() as tmp_dir:
+        policy.save_checkpoint(
+            os.path.join(tmp_dir, "test_hf_and_dcp"),
+            save_hf=True,
+            save_torch_dist=True,
+        )
+
+        ## make sure we save both HF and DCP checkpoints
+        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp"))) == {
+            "__0_0.distcp",
+            "__1_0.distcp",
+            ".metadata",
+        }
+        ## 1B model has two shards
+        assert set(os.listdir(os.path.join(tmp_dir, "test_hf_and_dcp-hf"))) == {
+            "config.json",
+            "generation_config.json",
+            "model-00001-of-00002.safetensors",
+            "model-00002-of-00002.safetensors",
+            "model.safetensors.index.json",
+        }
+
+        converted_model = AutoModelForCausalLM.from_pretrained(
+            os.path.join(tmp_dir, "test_hf_and_dcp-hf")
+        )
+        original_model = AutoModelForCausalLM.from_pretrained(
+            simple_policy_config["model_name"]
+        )
+
+        ## make sure converted model matches the original
+        check_dict_equality(converted_model.state_dict(), original_model.state_dict())
+
+    policy.worker_group.shutdown()

From d69ebb525ec2b24b8cd0f9a42afbf52fa6c7190d Mon Sep 17 00:00:00 2001
From: Charlie Truong
Date: Tue, 8 Apr 2025 12:33:37 -0700
Subject: [PATCH 05/11] ci: Add DCO placeholder check for merge queue (#147)

Signed-off-by: Charlie Truong
Signed-off-by: Parth Chadha
---
 .github/workflows/cicd-main.yml | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml
index 0bfc2aae9d..0c46a55f16 100644
--- a/.github/workflows/cicd-main.yml
+++ b/.github/workflows/cicd-main.yml
@@ -182,3 +182,10 @@ jobs:
           echo "$SUMMARY" >> $GITHUB_STEP_SUMMARY

           test "$ALL_SUCCESS" = "true" || test "$CI_SKIP" = "true"
+
+  DCO_merge_group:
+    name: DCO
+    if: github.event_name == 'merge_group'
+    runs-on: ubuntu-latest
+    steps:
+      - run: echo "The actual DCO check happens on PRs only. This is a placeholder for the merge queue to keep the DCO check as a required status check."
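
For reference, the round trip exercised by test_native_checkpoint.py above boils down to the following minimal sketch. This is an illustration, not code from the patch: it assumes nemo_reinforcer is installed and a CUDA device is available, and it reuses the save_checkpoint/load_checkpoint argument order shown in the test calls.

```python
import os
import torch
from tempfile import TemporaryDirectory

from nemo_reinforcer.utils.native_checkpoint import load_checkpoint, save_checkpoint

# Toy module, optimizer, and scheduler to round-trip through a checkpoint.
model = torch.nn.Linear(4, 4).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

with TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "model")
    optim_path = os.path.join(tmp_dir, "optimizer")
    save_checkpoint(model, model_path, optimizer, scheduler, optimizer_path=optim_path)

    # Fresh objects with different hyperparameters; loading the checkpoint
    # restores the saved weights, optimizer state, and scheduler state.
    restored = torch.nn.Linear(4, 4).to("cuda")
    restored_opt = torch.optim.Adam(restored.parameters(), lr=1e-3)
    restored_sched = torch.optim.lr_scheduler.StepLR(restored_opt, step_size=4, gamma=0.2)
    load_checkpoint(
        restored, model_path, restored_opt, restored_sched, optimizer_path=optim_path
    )
```
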
From a5f1a51e3f4ab6d2ec5d6480d108b3aa2adc2eb2 Mon Sep 17 00:00:00 2001 From: Charlie Truong Date: Wed, 9 Apr 2025 13:38:22 -0700 Subject: [PATCH 06/11] ci: Clarify DCO check in merge_group (#154) Signed-off-by: Charlie Truong Signed-off-by: Parth Chadha --- .github/workflows/cicd-main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index 0c46a55f16..4041f31bc3 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -188,4 +188,4 @@ jobs: if: github.event_name == 'merge_group' runs-on: ubuntu-latest steps: - - run: echo "The actual DCO check happens on PRs only. This is a placeholder for the merge queue to keep the DCO check as a required status check." + - run: echo "The real DCO check happens on PRs only. This is a placeholder for the merge queue to keep the DCO check as a required status check." From 6998632e41f525aedba10a44c32ce65846e49528 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 9 Apr 2025 15:16:39 -0700 Subject: [PATCH 07/11] fix: host ip resolution uses ray vs socket (#153) Signed-off-by: Terry Kong Signed-off-by: Parth Chadha --- .../distributed/virtual_cluster.py | 3 +- .../unit/distributed/test_virtual_cluster.py | 32 +++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) create mode 100644 tests/unit/distributed/test_virtual_cluster.py diff --git a/nemo_reinforcer/distributed/virtual_cluster.py b/nemo_reinforcer/distributed/virtual_cluster.py index 86c9ca4661..4f19fb821f 100644 --- a/nemo_reinforcer/distributed/virtual_cluster.py +++ b/nemo_reinforcer/distributed/virtual_cluster.py @@ -52,8 +52,7 @@ def _get_node_ip_and_free_port(): import socket # Get the IP address of the current node - # Use socket.gethostbyname(socket.gethostname()) as a fallback - node_ip = socket.gethostbyname(socket.gethostname()) + node_ip = ray._private.services.get_node_ip_address() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) # Bind to port 0 to get a random free port diff --git a/tests/unit/distributed/test_virtual_cluster.py b/tests/unit/distributed/test_virtual_cluster.py new file mode 100644 index 0000000000..4d01dd24f0 --- /dev/null +++ b/tests/unit/distributed/test_virtual_cluster.py @@ -0,0 +1,32 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo_reinforcer.distributed.virtual_cluster import ( + _get_node_ip_and_free_port, + PY_EXECUTABLES, +) +import ray + + +def test_get_node_ip_and_free_port_does_not_start_with_zero(): + # This test covers a case where the hostname was an integer like "255" + # and socket returned an ip address equivalent to this hostname, i.e., "0.0.0.255". + # It's not possible to mock the way the hostname is actually set on other platforms, + # so we leave this test so we can ask users to run on their environment if needed. 
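+    # Below, the helper runs as a Ray task under the system Python
+    # (PY_EXECUTABLES.SYSTEM), so it resolves the node IP the same way
+    # production workers do.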
+ + node_ip, _ = ray.get( + _get_node_ip_and_free_port.options( + runtime_env={"py_executable": PY_EXECUTABLES.SYSTEM} + ).remote() + ) + assert not node_ip.startswith("0."), "Node IP should not start with 0.*.*.*" From c5544fe689891a4baeb15094d7c5dd6804beb69f Mon Sep 17 00:00:00 2001 From: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Date: Thu, 10 Apr 2025 22:38:40 -0700 Subject: [PATCH 08/11] test: Add grpo/reinforce/ppo loss tests (prep for incoming vocab parallel changes) (#162) Signed-off-by: Sahil Jain Signed-off-by: Parth Chadha --- tests/unit/algorithms/test_loss_functions.py | 325 ++++++++++++++++++- 1 file changed, 320 insertions(+), 5 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index fe874ecc26..af78baf34d 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -13,7 +13,14 @@ # limitations under the License. import pytest import torch -from nemo_reinforcer.algorithms.loss_functions import NLLLoss +import numpy as np + +from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.utils import ( + calculate_kl_penalty_joschu2020, + masked_mean, +) def test_nll_loss(): @@ -46,7 +53,7 @@ def test_nll_loss(): .to("cuda") ) loss, metrics_dict = loss_fn(next_token_logits, data) - torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + torch.testing.assert_close(loss.cpu(), torch.tensor(0.0)) # Check the metrics dictionary contains the expected values assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 @@ -66,8 +73,316 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) ## loss per token is 999, and we have two unmasked tokens - ## with the updated loss function, we now average the loss over unmasked tokens - torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0)) - # Check the metrics dictionary contains the expected values + ## NLLLoss averages the loss over unmasked tokens + torch.testing.assert_close(loss.cpu(), torch.tensor(999.0)) assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 + + +def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): + """Sets up basic mock data structure. 
Tests should fill values.""" + input_ids = torch.randint( # Input IDs only needed if original loss fn used + 0, vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device + ) + # Default mask: Mask first token [[0, 1, 1, 1]] + token_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) + token_mask[:, 0] = 0 + # sample_mask needs shape [B] + sample_mask = torch.ones(batch_size, dtype=torch.int64, device=device) + + # Simple default values, tests overwrite these + advantages = torch.zeros((batch_size, seq_len), device=device) + prev_logprobs = torch.zeros((batch_size, seq_len), device=device) + reference_policy_logprobs = torch.zeros((batch_size, seq_len), device=device) + generation_logprobs = torch.zeros((batch_size, seq_len), device=device) + + data = BatchedDataDict( + { + "input_ids": input_ids, # Include for completeness + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + "generation_logprobs": generation_logprobs, + } + ) + # Return seq_len and vocab_size needed by tests + return data, seq_len, vocab_size + + +# Helper to create logits that yield specific target log probs after log_softmax +def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): + """Constructs logits such that log_softmax results in target_curr_lp_masked.""" + dummy_logits = torch.full( + (1, seq_len, vocab_size), -100.0, device=device + ) # Start very low + + # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] + # We need to set logits for indices i=0..S-2 of the sliced logits tensor. + # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. 
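+    # Construction identity: with the target-token logit pinned at 0 and a
+    # single distractor at d = log(exp(-lp) - 1), the softmax denominator is
+    # 1 + exp(d) = exp(-lp), so log_softmax at the target token is exactly lp
+    # (every other logit sits at -100 and contributes negligibly).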
+ num_effective_pos = target_curr_lp_masked.shape[1] + for i in range(num_effective_pos): + logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) + data_idx = i + 1 # Index in the original input_ids to find the target token + + target_token_id = input_ids[0, data_idx].item() + # Keep target_lp as a 0-dim tensor for torch ops + target_lp = target_curr_lp_masked[0, i] + + # Handle target_lp = 0 case separately + if torch.isclose(target_lp, torch.tensor(0.0, device=device)): + dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + elif target_lp < 0: + # Set target token logit to 0 + dummy_logits[0, logit_idx, target_token_id] = 0.0 + # Set one distractor token logit using the formula + distractor_token_id = (target_token_id + 1) % vocab_size + # Ensure distractor isn't same as target if vocab_size=1 (edge case) + if distractor_token_id == target_token_id: + distractor_token_id = (target_token_id + 2) % vocab_size + distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) + dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit + else: # target_lp > 0 is not supported by this method + raise ValueError( + "Target log probability must be negative or zero for this construction" + ) + return dummy_logits + + +# Simplified PPO Clipping Test using original Loss +def test_clipped_pg_loss_ppo_clipping(): + """Tests PPO clipping calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_eps = 0.2 + cfg = { + "ratio_eps_min": ratio_eps, + "ratio_eps_max": ratio_eps, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand Calculation --- + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ) # [0.8, 1.0, 1.2] + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + expected_loss = torch.mean( + max_loss + ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified REINFORCE Test using original Loss +def test_clipped_pg_loss_reinforce_mode(): + """Tests REINFORCE mode calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = { + "disable_ppo_ratio": 
True, + "reference_policy_kl_penalty": 0.0, + "ratio_eps_min": 0.0, # Placeholder, ignored + "ratio_eps_max": 0.0, # Placeholder, ignored + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["_test_curr_logprobs"] = curr_lp_masked + data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked) + + # --- Hand Calculation --- + expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified KL Penalty Test using original Loss +def test_clipped_pg_loss_kl_penalty(): + """Tests KL penalty calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + # --- Test Setup --- + kl_beta = 0.1 + cfg = { + "reference_policy_kl_penalty": kl_beta, + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + curr_lp_masked = torch.tensor([[0.0, -1.0, -2.0]], device=device) + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + prev_lp_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["_test_curr_logprobs"] = curr_lp_masked + + # --- Hand Calculation --- + # Actor loss is 0. Total loss = kl_beta * mean(kl_term) + # kl_term = exp(ref - curr) - (ref - curr) - 1 + r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + expected_loss = kl_beta * expected_kl_mean # 0.0362 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Masking tests - Should work with original Loss Fn if needed, but less critical +def test_clipped_pg_loss_masking(): + """Tests the effect of token_mask and sample_mask.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + batch_size = 2 + seq_len = 4 + device = "cuda" + # Use original loss function for masking tests, as it involves interactions + # that the Testable class might obscure slightly. 
+ data, seq_len, vocab_size = _setup_clipped_pg_test_data( + batch_size=batch_size, seq_len=seq_len, device=device + ) + # Need some realistic-ish logits and logprobs for masking test + dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + # Ensure logprobs used by the loss fn make sense relative to advantages + data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 + data["reference_policy_logprobs"] = ( + torch.randn_like(data["reference_policy_logprobs"]) * 0.1 + ) + # Make advantages non-zero + data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 + + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # --- Test 1: Token Mask --- + # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample + loss_default, _ = loss_fn(dummy_logits, data) + + # Modify token_mask for batch item 0 to mask one more token (pos 1) + data_mod_token = data.copy() + data_mod_token["token_mask"] = data["token_mask"].clone() + data_mod_token["token_mask"][0, 1] = ( + 0 # New mask: [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens sample 0, 3 tokens sample 1 + ) + + loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token) + # Loss should change if a potentially contributing token is masked + assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), ( + "Token mask did not change loss as expected" + ) + + # --- Test 2: Sample Mask --- + data_mod_sample = data.copy() + data_mod_sample["sample_mask"] = torch.tensor( + [1, 0], dtype=torch.int64, device=device + ) # Ignore item 1 + + loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample) + + # Manually create data dict for only batch 0 + data_only_b0_dict = {} + for key, value in data.items(): + if isinstance(value, torch.Tensor): + if key == "sample_mask": + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value + data_only_b0 = BatchedDataDict(data_only_b0_dict) + + logits_only_b0 = dummy_logits[0:1] + loss_only_b0, _ = loss_fn(logits_only_b0, data_only_b0) + + torch.testing.assert_close(loss_sample_masked, loss_only_b0) + + +def test_clipped_pg_loss_zero_mask(): + """Tests the case where the combined mask sum is zero.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + # Need dummy logits + dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) + + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # Set token mask to all zeros + data["token_mask"] = torch.zeros_like(data["token_mask"]) + + loss, _ = loss_fn(dummy_logits, data) + + # Loss should be exactly zero + torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) From ef83276439cae02b35cd01ec1d150c3a80ec6606 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 11 Apr 2025 10:18:03 -0700 Subject: [PATCH 09/11] fix: always test vllm (#167) Signed-off-by: Parth Chadha --- .../unit/models/generation/test_vllm_generation.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index ba1bade3fc..aadb1fec77 100644 --- 
a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -43,15 +43,6 @@ } -@pytest.fixture(scope="module") -def check_vllm_available(): - """Skip tests if vLLM is not installed.""" - try: - import vllm # noqa: F401 - except ImportError: - pytest.skip("vLLM not installed") - - @pytest.fixture(scope="module") def cluster(): """Create a virtual cluster for testing.""" @@ -76,7 +67,7 @@ def tokenizer(): @pytest.fixture(scope="function") -def policy(cluster, tokenizer, check_vllm_available): +def policy(cluster, tokenizer): """Initialize the vLLM policy.""" # Create separate configs for each policy vllm_config = basic_vllm_test_config.copy() @@ -126,7 +117,7 @@ def test_input_data(tokenizer): ) -def test_vllm_missing_required_config_key(cluster, check_vllm_available): +def test_vllm_missing_required_config_key(cluster): """Test that an assertion error is raised when a required config key is missing.""" # Create a config missing a required key by removing 'model_name' incomplete_config = basic_vllm_test_config.copy() From 787949d02b0b232293aceb0c7daeec63a8069472 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 11 Apr 2025 11:31:26 -0700 Subject: [PATCH 10/11] fix: remove dead code Signed-off-by: Parth Chadha --- nemo_reinforcer/models/generation/vllm.py | 8 -------- .../models/generation/vllm_backend.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 2f5041b8ee..1dd9b9847c 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -18,7 +18,6 @@ import ray import torch -from transformers import AutoTokenizer from nemo_reinforcer.models.generation.interfaces import ( GenerationInterface, @@ -94,8 +93,6 @@ def configure_worker( env_vars["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1" init_kwargs["fraction_of_gpus"] = num_gpus - # Force vllm to use v0 runtime (will be enabled by default in #51) - # env_vars["VLLM_USE_V1"] = "0" env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" return resources, env_vars, init_kwargs @@ -136,10 +133,6 @@ def __init__( try: import vllm - # from vllm import LLM, SamplingParams - # from nemo_reinforcer.models.generation.vllm_backend import ( - # UpdatableVllmInternalWorker, - # ) self.SamplingParams = vllm.SamplingParams except ImportError: @@ -184,7 +177,6 @@ def __init__( max_model_len=self.cfg["vllm_cfg"]["max_model_len"], trust_remote_code=True, worker_extension_cls="nemo_reinforcer.models.generation.vllm_backend.VllmInternalWorkerExtension", - # worker_cls=UpdatableVllmInternalWorker, enable_sleep_mode=True, disable_log_stats=True, **vllm_kwargs, diff --git a/nemo_reinforcer/models/generation/vllm_backend.py b/nemo_reinforcer/models/generation/vllm_backend.py index 6114194bc8..a7fd12aa26 100644 --- a/nemo_reinforcer/models/generation/vllm_backend.py +++ b/nemo_reinforcer/models/generation/vllm_backend.py @@ -13,14 +13,14 @@ # limitations under the License. import torch -# try: -# from vllm.worker.worker import Worker -# except ImportError: -# raise ImportError( -# "vLLM is not installed. Please install it with `pip install nemo-reinforcer[vllm]` " -# "or `pip install vllm` separately. This issue may also occur if worker is using incorrect " -# "py_executable." -# ) +try: + import vllm +except ImportError: + raise ImportError( + "vLLM is not installed. 
Please install it with `pip install nemo-reinforcer[vllm]` " + "or `pip install vllm` separately. This issue may also occur if worker is using incorrect " + "py_executable." + ) class VllmInternalWorkerExtension: From 146c085590c8b44a963aece8997fd848c8b044a2 Mon Sep 17 00:00:00 2001 From: Parth Chadha Date: Fri, 11 Apr 2025 14:22:37 -0700 Subject: [PATCH 11/11] feat: add a unique seed for each vllm llm engine Signed-off-by: Parth Chadha --- nemo_reinforcer/distributed/worker_groups.py | 6 +++--- nemo_reinforcer/models/generation/vllm.py | 20 ++++++++++++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/nemo_reinforcer/distributed/worker_groups.py b/nemo_reinforcer/distributed/worker_groups.py index d4ec9d7f1a..4e3bbbf2a6 100644 --- a/nemo_reinforcer/distributed/worker_groups.py +++ b/nemo_reinforcer/distributed/worker_groups.py @@ -91,7 +91,7 @@ def __call__( placement_group: PlacementGroup, placement_group_bundle_index: int, num_gpus: int, - bundle_indices: Optional[list] = None, + bundle_indices: Optional[tuple] = None, **extra_options: Dict[str, Any], ): """Create a Ray worker with the specified configuration. @@ -108,7 +108,7 @@ def __call__( placement_group: Ray placement group for resource allocation placement_group_bundle_index: Index of the bundle in the placement group num_gpus: Number of GPUs to allocate to this worker - bundle_indices: List of bundle indices for tensor parallelism (if applicable) + bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) Returns: @@ -300,7 +300,7 @@ def _create_workers_from_bundle_indices( # For tensor parallel groups, only the first worker gets bundle_indices worker_bundle_indices = ( - local_bundle_indices if local_rank == 0 else None + (node_idx, local_bundle_indices) if local_rank == 0 else None ) # Create a descriptive name based on group structure diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index 3f8528f549..9c2402f6db 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -61,7 +61,7 @@ def __repr__(self): @staticmethod def configure_worker( - num_gpus: int | float, bundle_indices: Optional[list] = None + num_gpus: int | float, bundle_indices: Optional[tuple] = None ) -> tuple[dict, dict, dict]: """Provides complete worker configuration for vLLM tensor parallelism. 
@@ -70,7 +70,7 @@ def configure_worker( Args: num_gpus: Original GPU allocation for this worker based on the placement group - bundle_indices: Bundle indices for tensor parallelism (if applicable) + bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) Returns: tuple with complete worker configuration: @@ -83,8 +83,22 @@ def configure_worker( init_kwargs = {} env_vars = {} + node_idx = bundle_indices[0] + bundle_indices = bundle_indices[1] + init_kwargs["bundle_indices"] = bundle_indices + """ + compute a unique seed from the node_idx and bundle_indices: + node_idx = 0, bundle_indices = [0, 1, 2, 3] -> seed = 0*1024 + 0 + node_idx = 0, bundle_indices = [4, 5, 6, 7] -> seed = 0*1024 + 1 + node_idx = 1, bundle_indices = [0, 1, 2, 3] -> seed = 1*1024 + 0 + node_idx = 1, bundle_indices = [4, 5, 6, 7] -> seed = 1*1024 + 1 + """ + bundle_id = bundle_indices[0] // len(bundle_indices) + seed = node_idx * 1024 + bundle_id + init_kwargs["seed"] = seed + is_part_of_tp_workers = ( bundle_indices is not None and len(bundle_indices) > 1 ) or bundle_indices is None @@ -104,6 +118,7 @@ def __init__( config: VllmConfig, bundle_indices: Optional[list] = None, fraction_of_gpus: float = 1.0, + seed: Optional[int] = None, ): """Initialize a vLLM worker for distributed inference. @@ -177,6 +192,7 @@ def __init__( gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"], enable_prefix_caching=True, dtype="auto", + seed=seed, # Use cuda-graph by default for performance, set to True to use eager execution enforce_eager=False, max_model_len=self.cfg["vllm_cfg"]["max_model_len"],
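
The seed computation above can be sanity-checked in isolation. A standalone restatement follows; derive_vllm_seed is a hypothetical name used only for illustration, mirroring the arithmetic inside configure_worker.

```python
def derive_vllm_seed(node_idx: int, bundle_indices: list) -> int:
    """Map a (node, tensor-parallel group) pair to a unique engine seed.

    Bundle indices within a node are contiguous per TP group, so the first
    index divided by the group size identifies the group on that node.
    """
    bundle_id = bundle_indices[0] // len(bundle_indices)
    return node_idx * 1024 + bundle_id


# Matches the worked examples in the comment above.
assert derive_vllm_seed(0, [0, 1, 2, 3]) == 0
assert derive_vllm_seed(0, [4, 5, 6, 7]) == 1
assert derive_vllm_seed(1, [0, 1, 2, 3]) == 1024
assert derive_vllm_seed(1, [4, 5, 6, 7]) == 1025
```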