[2/2] Top-k and Top-p support for dtensor worker with vLLM V0 when TP>1 #774

zhandaz wants to merge 1 commit into zhanda/top-p-k

Conversation
Pull Request Overview
This PR adds support for top-k and top-p sampling in the dtensor policy worker with vLLM V0 when tensor parallelism (TP) is greater than 1. The implementation introduces a new distributed log softmax function that handles sampling parameters and modifies existing functions to propagate these parameters through the call stack.
- Implements `_compute_distributed_log_softmax_with_sampling` to handle top-k/top-p sampling in distributed environments
- Adds sampling parameter extraction and propagation through the dtensor policy worker
- Updates function signatures across the distributed model utilities to support sampling parameters
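To make the top-k/top-p part of the change concrete, here is a minimal single-process sketch of the kind of logit filtering such a sampling-aware log softmax applies before normalizing. The function name and the pure-Python implementation are hypothetical, for illustration only; the PR's actual code operates on vocab-parallel torch tensors.

```python
import math

def top_k_top_p_filter(logits, top_k=-1, top_p=1.0):
    """Hypothetical sketch: mask logits outside the top-k / top-p set to -inf.

    top_k == -1 and top_p == 1.0 are the usual "disabled" sentinels, in which
    case the logits pass through unchanged.
    """
    filtered = list(logits)
    # Top-k: keep only the k largest logits.
    if top_k is not None and top_k > 0:
        kth = sorted(filtered, reverse=True)[min(top_k, len(filtered)) - 1]
        filtered = [x if x >= kth else float("-inf") for x in filtered]
    # Top-p (nucleus): keep the smallest set of tokens whose softmax
    # probability mass reaches top_p, starting from the largest logit.
    if top_p is not None and top_p < 1.0:
        order = sorted(range(len(filtered)), key=lambda i: filtered[i], reverse=True)
        z = sum(math.exp(x) for x in filtered if x != float("-inf"))
        cum, keep = 0.0, set()
        for i in order:
            if filtered[i] == float("-inf"):
                break
            keep.add(i)  # the top token is always kept
            cum += math.exp(filtered[i]) / z
            if cum >= top_p:
                break
        filtered = [x if i in keep else float("-inf") for i, x in enumerate(filtered)]
    return filtered
```

With defaults (`top_k=-1`, `top_p=1.0`) the function is a no-op, matching the fallback behavior discussed in the review comments below.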
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| nemo_rl/models/policy/dtensor_policy_worker.py | Adds sampling parameter extraction and passes them to logprob computation functions |
| nemo_rl/models/dtensor/parallelize.py | Updates function signature to accept and forward sampling parameters |
| nemo_rl/distributed/model_utils.py | Implements new sampling-aware distributed log softmax and updates all related functions |
```python
    Returns:
        Log softmax output with sampling applied, same shape as input
    """
    if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):
```
The condition uses `and` between the two parenthesized checks, but logically this should be `or`, since either condition being true (top_k disabled OR top_p disabled) should trigger the fallback to the regular log softmax.
```diff
-    if (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0):
+    if (top_k is not None and top_k == -1) or (top_p is not None and top_p == 1.0):
```
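The two combinators only disagree when exactly one of the parameters is active. A small illustrative check (helper names are made up for this example, not from the PR):

```python
# -1 and 1.0 are the "disabled" sentinels used in the diff above.
def fallback_and(top_k, top_p):
    return (top_k is not None and top_k == -1) and (top_p is not None and top_p == 1.0)

def fallback_or(top_k, top_p):
    return (top_k is not None and top_k == -1) or (top_p is not None and top_p == 1.0)

# top_k disabled but top_p active: the combinators disagree.
print(fallback_and(-1, 0.9))  # False -> takes the sampling-aware path
print(fallback_or(-1, 0.9))   # True  -> falls back to plain log softmax
```

With `or`, disabling either parameter alone would skip sampling entirely even though the other is still active, which is why the choice of combinator matters here.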
```python
        log_softmax_output = _compute_distributed_log_softmax(
            vocab_parallel_logits, group=group
        )
    # Use sampling-aware distributed log softmax if sampling parameters are provided
```
The condition on line 142 uses `or` logic, but the comment suggests both parameters need to be provided. The logic is correct (either parameter being active should use sampling); only the comment is misleading.
```diff
-    # Use sampling-aware distributed log softmax if sampling parameters are provided
+    # Use sampling-aware distributed log softmax if either top_k or top_p is provided
```
```diff
     Args:
-        vocab_parallel_logits (orch.Tensor): Logits distributed across tensor parallel workers,
+        vocab_parallel_logits (torch.Tensor): Logits distributed across tensor parallel workers,
```
The docstring had a typo ('orch.Tensor' instead of 'torch.Tensor'), but the diff shows it is already corrected in this PR.
What does this PR do?

tldr: Support top-k and top-p for the dtensor worker with vLLM V0. This PR supports TP>1 on top of #773. Instead of using `_compute_distributed_log_softmax`, we implement `_compute_distributed_log_softmax_with_sampling` to support top-k and top-p when TP is enabled.

Note: This change depends on #773 and should be merged after it. We should also decide whether to merge this implementation or, alternatively, add a warning to users about a potential mismatch between this inference logic and the logic used in policy training for vLLM engine V0 and dtensor with TP>1.
Tests for distributed functionalities and docs will be added after we make the decision.
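For readers unfamiliar with `_compute_distributed_log_softmax`, the idea it builds on can be simulated in a single process: each TP rank holds one shard of the vocabulary dimension, and the all-reduces over the tensor-parallel group become a plain max and sum over shards. This is a hypothetical stand-in for illustration, not the PR's torch.distributed implementation.

```python
import math

def simulated_distributed_log_softmax(vocab_shards):
    """Single-process simulation of a vocab-parallel log softmax.

    Each inner list plays the role of one TP rank's slice of the vocabulary;
    max() and sum() over shards stand in for all-reduce(MAX) and
    all-reduce(SUM) over the tensor-parallel group.
    """
    # all-reduce(MAX): global max over the full vocabulary, for stability.
    global_max = max(max(shard) for shard in vocab_shards)
    # all-reduce(SUM): shared normalizer from locally shifted logits.
    denom = sum(math.exp(x - global_max) for shard in vocab_shards for x in shard)
    log_denom = math.log(denom)
    # Each rank keeps only the log softmax of its own shard.
    return [[x - global_max - log_denom for x in shard] for shard in vocab_shards]
```

The sampling-aware variant this PR adds has to mask out tokens (top-k/top-p) before computing the shared normalizer, which is what makes the TP>1 case nontrivial: the set of kept tokens depends on logits living on other ranks.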
Issues
Related Issue: #69
Usage
```python
# Add a code snippet demonstrating how to use this
```

Before your PR is "Ready for review"
Pre checks: