From 857eb0261a446af7071aeb663ee0af0d362544b6 Mon Sep 17 00:00:00 2001
From: Luis Vega
Date: Fri, 13 Jun 2025 11:03:25 -0700
Subject: [PATCH 1/5] Create 2.long_generation_decode_vs_prefill.py test

Signed-off-by: Luis Vega
---
 .../2.long_generation_decode_vs_prefill.py    | 105 ++++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 tools/model_diagnostics/2.long_generation_decode_vs_prefill.py

diff --git a/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py
new file mode 100644
index 0000000000..9f69179f0e
--- /dev/null
+++ b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import argparse
+from vllm import LLM, SamplingParams
+
+
+def extract_logprobs(logprobs):
+    output = []
+    for lp in logprobs:
+        if lp is not None:
+            output.append(list(lp.values())[0].logprob)
+    return output
+
+
+def calculate_error(a, b):
+    return torch.exp(torch.abs(a - b)).mean().item()
+
+
+def main():
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "model", type=str, nargs="?", default="nvidia/Nemotron-H-8B-Base-8K"
+    )
+    args = parser.parse_args()
+
+    seed = 0
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        top_p=1.0,
+        max_tokens=8192,
+        prompt_logprobs=0,
+        logprobs=0,
+        seed=seed,
+    )
+
+    # Examples as of 0.9.1
+    # model="meta-llama/Meta-Llama-3-8B", # pass
+    # model="nvidia/Nemotron-H-8B-Base-8K", # fail
+    # model="ibm-ai-platform/Bamba-9B-v1", # pass
+    llm = LLM(
+        model=args.model,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_prefix_caching=False,
+        enable_chunked_prefill=False,
+        tensor_parallel_size=2,
+        gpu_memory_utilization=0.8,
+        seed=seed,
+    )
+
+    num_batches = 2
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    outputs = llm.generate(prompts * num_batches, sampling_params)
+
+    for i, output in enumerate(outputs):
+
+        sequence = output.prompt_token_ids + list(output.outputs[0].token_ids)
+        prompt_logprobs = extract_logprobs(output.prompt_logprobs)
+        logprobs = extract_logprobs(output.outputs[0].logprobs)
+        decode_lp = prompt_logprobs + logprobs
+        decode_lp = torch.tensor(decode_lp)
+
+        sampling_params = SamplingParams(
+            temperature=0.0, max_tokens=1, prompt_logprobs=0
+        )
+        score = llm.generate({"prompt_token_ids": sequence}, sampling_params)
+
+        prefill_lp = extract_logprobs(score[0].prompt_logprobs)
+        prefill_lp = torch.tensor(prefill_lp)
+
+        lp_error = calculate_error(decode_lp, prefill_lp)
+        max_abs_error = torch.abs(decode_lp - prefill_lp).max().item()
+        print(
+            f"Processed sequence length {len(sequence)} with lp error {lp_error} and max abs error {max_abs_error}"
+        )
+        assert (
+            lp_error < 1.0636
+        ), f"lp error is higher than expected (1.0636): {lp_error}"
+
+    print(f"[{args.model}] ALL GOOD!")
+
+
+if __name__ == "__main__":
"__main__": + main() From e40b1d3a9c56e4be2228863958f7b4fe2cdc6bfc Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Mon, 16 Jun 2025 10:45:19 -0700 Subject: [PATCH 2/5] Change threshold value Signed-off-by: Luis Vega --- tools/model_diagnostics/2.long_generation_decode_vs_prefill.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py index 9f69179f0e..cc5ff2a70a 100644 --- a/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py +++ b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py @@ -95,7 +95,7 @@ def main(): f"Processed sequence length {len(sequence)} with lp error {lp_error} and max abs error {max_abs_error}" ) assert ( - lp_error < 1.0636 + lp_error < 1.05 ), f"lp error is higher than expected (1.0636): {lp_error}" print(f"[{args.model}] ALL GOOD!") From bc99af41693586997dbad61d9937a4843c8a83d1 Mon Sep 17 00:00:00 2001 From: Luis Vega Date: Mon, 16 Jun 2025 10:53:13 -0700 Subject: [PATCH 3/5] Update adding-new-models.md Signed-off-by: Luis Vega --- docs/adding-new-models.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/adding-new-models.md b/docs/adding-new-models.md index c73d494907..26ab6c5934 100644 --- a/docs/adding-new-models.md +++ b/docs/adding-new-models.md @@ -140,4 +140,15 @@ uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py Qwen/Qw # Generated tokens: 12 # Total tokens: 20 # [Qwen/Qwen2.5-1.5B] ALL GOOD! -``` \ No newline at end of file +``` + +## [2.long_generation_decode_vs_prefill](https://github.com/NVIDIA/NeMo-RL/blob/main/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py) + +Test that vLLM yields near-identical token log-probabilities when comparing decoding with a single prefill pass across multiple prompts + +```sh +# Run that is expected to pass +uv run --extra vllm tools/model_diagnostics/2.long_generation_decode_vs_prefill.py Qwen/Qwen2.5-1.5B +# ... +# [Qwen/Qwen2.5-1.5B] ALL GOOD! +``` From 92a4e05c94160b512c18071cf8f1cee87a9f08a0 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Mon, 16 Jun 2025 10:58:25 -0700 Subject: [PATCH 4/5] Update docs/adding-new-models.md Signed-off-by: Terry Kong --- docs/adding-new-models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/adding-new-models.md b/docs/adding-new-models.md index 26ab6c5934..155a012f47 100644 --- a/docs/adding-new-models.md +++ b/docs/adding-new-models.md @@ -144,7 +144,7 @@ uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py Qwen/Qw ## [2.long_generation_decode_vs_prefill](https://github.com/NVIDIA/NeMo-RL/blob/main/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py) -Test that vLLM yields near-identical token log-probabilities when comparing decoding with a single prefill pass across multiple prompts +Test that vLLM yields near-identical token log-probabilities when comparing decoding with a single prefill pass across multiple prompts. 
 
 ```sh
 # Run that is expected to pass

From d83072363b08758c8f40973e4299c7e0a7cf2fd7 Mon Sep 17 00:00:00 2001
From: Luis Vega <2478335+vegaluisjose@users.noreply.github.com>
Date: Thu, 26 Jun 2025 15:44:21 -0700
Subject: [PATCH 5/5] fixed formatting in test

Signed-off-by: Luis Vega <2478335+vegaluisjose@users.noreply.github.com>
---
 .../2.long_generation_decode_vs_prefill.py    | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py
index cc5ff2a70a..69c153fd53 100644
--- a/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py
+++ b/tools/model_diagnostics/2.long_generation_decode_vs_prefill.py
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import torch
 import argparse
+
+import torch
 from vllm import LLM, SamplingParams
 
 
@@ -29,7 +30,6 @@ def calculate_error(a, b):
 
 
 def main():
-
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "model", type=str, nargs="?", default="nvidia/Nemotron-H-8B-Base-8K"
@@ -74,7 +74,6 @@ def main():
     outputs = llm.generate(prompts * num_batches, sampling_params)
 
     for i, output in enumerate(outputs):
-
         sequence = output.prompt_token_ids + list(output.outputs[0].token_ids)
         prompt_logprobs = extract_logprobs(output.prompt_logprobs)
         logprobs = extract_logprobs(output.outputs[0].logprobs)
@@ -94,9 +93,7 @@ def main():
         print(
             f"Processed sequence length {len(sequence)} with lp error {lp_error} and max abs error {max_abs_error}"
         )
-        assert (
-            lp_error < 1.05
-        ), f"lp error is higher than expected (1.0636): {lp_error}"
+        assert lp_error < 1.05, f"lp error is higher than expected (1.05): {lp_error}"
 
     print(f"[{args.model}] ALL GOOD!")
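
For quick reference, the check this series adds boils down to the sketch below: a minimal, single-prompt condensation of the script from PATCH 1/5, using the same vLLM calls. The model name, prompt, and token budget here are illustrative placeholders, not what the diagnostic pins down. Note that `calculate_error` computes the mean of `exp(|logprob_decode - logprob_prefill|)`, i.e. the average per-token probability ratio between the two passes, so the final `1.05` threshold tolerates roughly 5% average per-token disagreement.

```python
# Minimal sketch of the decode-vs-prefill comparison, condensed from the
# diagnostic in PATCH 1/5. Same vLLM API calls as the script; the model and
# max_tokens are placeholders chosen for a quick run.
import torch
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B", enforce_eager=True, seed=0)

# Decode pass: sample a continuation, recording a logprob for every prompt
# token (prompt_logprobs=0) and every generated token (logprobs=0).
decode = llm.generate(
    ["The capital of France is"],
    SamplingParams(
        temperature=1.0, max_tokens=64, prompt_logprobs=0, logprobs=0, seed=0
    ),
)[0]
sequence = decode.prompt_token_ids + list(decode.outputs[0].token_ids)
decode_lp = torch.tensor(
    [
        list(lp.values())[0].logprob
        for lp in decode.prompt_logprobs + decode.outputs[0].logprobs
        if lp is not None
    ]
)

# Prefill pass: re-score the full token sequence in a single forward pass and
# read back the same per-token logprobs via prompt_logprobs.
score = llm.generate(
    {"prompt_token_ids": sequence},
    SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=0),
)[0]
prefill_lp = torch.tensor(
    [list(lp.values())[0].logprob for lp in score.prompt_logprobs if lp is not None]
)

# exp(|delta logprob|) averaged over tokens is the mean per-token probability
# ratio between the two passes; 1.0 means they agree exactly.
print(torch.exp(torch.abs(decode_lp - prefill_lp)).mean().item())
```

Models whose decode-path and prefill-path numerics diverge show this ratio drifting upward as generations get longer, which is what the 8192-token budget in the full diagnostic is designed to expose.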