168 changes: 168 additions & 0 deletions proposals/disaggregated-prefill-orchestrated-routing.md
@@ -0,0 +1,168 @@
# Disaggregated Prefill Orchestrated Routing

## Table of Contents

- [Summary](#summary)
- [Motivation](#motivation)
- [Proposal](#proposal)

## Summary

This proposal adds a new routing algorithm `disaggregated_prefill_orchestrated` to the vLLM Production Stack router. This enables prefill/decode disaggregation where the router orchestrates the request flow between dedicated prefill and decode pods, forwarding KV cache transfer metadata between them. This complements LMCache-based disaggregated inference by supporting backends with custom `kv_connector` implementations (e.g., NIXL, NCCL).

## Motivation

Disaggregated inference separates compute-heavy prefill from memory-bound decode phases. This architectural pattern is increasingly important for:

- **Independent scaling** - Prefill and decode pods can scale based on different metrics (prompt throughput vs. generation throughput)
- **Heterogeneous hardware** - Prefill and decode can run on different hardware profiles optimized for their workloads
- **Better resource utilization** - Under high concurrency, avoiding co-located P/D reduces resource contention

### Goals

- Add `disaggregated_prefill_orchestrated` as a new routing logic option
- Enable the router to identify and route to prefill vs. decode pods via labels
- Orchestrate the P→D request flow, extracting and forwarding KV transfer metadata
- Leverage existing K8s service discovery infrastructure
- Support streaming responses from decode phase

### Non-Goals

- Modifying LMCache-based disaggregated inference
- Implementing the underlying KV cache transfer mechanism (handled by vLLM backends)
- Autoscaling logic (handled by KEDA with vLLM metrics)
- Supporting non-Kubernetes deployments in this initial implementation

## Proposal

### Two Disaggregated Inference Approaches

| Approach | KV Transfer | Router Role | Use Case |
|----------|-------------|-------------|----------|
| **LMCache-based DI** | LMCache + NIXL | Transparent routing | GPU clusters with LMCache |
| **Router-orchestrated DI** (this proposal) | vLLM native `kv_transfer_config` | Orchestrates P→D flow | Any backend with kv_connector |

### Proposed Changes

**Architecture:**

```
┌──────────┐ ① ┌─────────────────────────────────────┐
│ Client │────────────────────▶│ Router (disaggregated_prefill_ │
│ Request │ │ orchestrated) │
└──────────┘ └──────────────────┬──────────────────┘
② │ ③
┌──────────────┐ │ ┌──────────────┐
│ Prefill │◀─────────┼─────│ Decode │
│ Pod │ │ │ Pod │
│ │──────────┼────▶│ │
│ (producer) │ KV ID │ │ (consumer) │
└──────────────┘ │ └──────────────┘
┌──────────┐ ④ │
│ Stream │◀───────────────────────────────────────┘
│ Response │
└──────────┘
```

**Request Flow:**
1. Client sends `/v1/chat/completions` to Router
2. Router forwards to Prefill pod with `max_tokens=1`
3. Prefill returns KV transfer ID in `kv_transfer_params` field
4. Router forwards to Decode pod with original `max_tokens` + transfer metadata
5. Decode streams response back to client
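
The request-body rewrites in steps 2 and 4 can be sketched as small pure helpers (the function names here are illustrative, not the router's actual API; the `kv_transfer_params` field name follows the flow described above):

```python
def build_prefill_request(body: dict) -> dict:
    """Step 2: copy the client request, forcing max_tokens=1 and no streaming
    so the prefill pod only fills the KV cache and returns transfer metadata."""
    prefill_body = dict(body)
    prefill_body["max_tokens"] = 1
    prefill_body["stream"] = False
    return prefill_body


def build_decode_request(body: dict, prefill_response: dict) -> dict:
    """Step 4: keep the original max_tokens and attach the KV transfer
    metadata returned by the prefill pod."""
    kv_params = prefill_response.get("kv_transfer_params")
    if kv_params is None:
        raise ValueError("prefill response missing kv_transfer_params")
    decode_body = dict(body)
    decode_body["kv_transfer_params"] = kv_params
    return decode_body
```

The client's original body is never mutated, so the decode request (step 4) can be built from it directly after the prefill round-trip completes.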

### Implementation Details/Notes/Constraints

**Architecture / Components:**
- `src/vllm_router/routers/routing_logic.py` - New `DisaggregatedPrefillOrchestratedRouter` class
- `src/vllm_router/parsers/parser.py` - New CLI arguments for prefill/decode labels
- `src/vllm_router/services/request_service/request.py` - New `route_orchestrated_disaggregated_request()` function

**Interface Changes:**

New CLI arguments:
| Argument | Description |
|----------|-------------|
| `--routing-logic=disaggregated_prefill_orchestrated` | Enable orchestrated disaggregated routing |
| `--prefill-model-labels=prefill` | Model label to identify prefill pods |
| `--decode-model-labels=decode` | Model label to identify decode pods |
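
For illustration, a hypothetical router Deployment fragment wiring up these flags (the image name and container name are assumptions, not part of this proposal):

```yaml
containers:
  - name: router
    image: vllm/production-stack-router:latest  # illustrative image
    args:
      - "--routing-logic=disaggregated_prefill_orchestrated"
      - "--prefill-model-labels=prefill"
      - "--decode-model-labels=decode"
```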

Pod labels required:
```yaml
# Prefill deployment
metadata:
  labels:
    app: prefill
    model: prefill

# Decode deployment
metadata:
  labels:
    app: decode
    model: decode
```

**Performance Considerations:**
- Adds one HTTP round-trip (router→prefill) before decode streaming begins
- Prefill request is non-streaming (`max_tokens=1`) to get KV transfer ID
- Decode request streams normally
- No additional memory overhead in router

**Resource Constraints:**
- Minimal CPU overhead for JSON parsing of prefill response
- No GPU resources required by router

### Test Plans

**Unit Tests:**
- Test `DisaggregatedPrefillOrchestratedRouter.route()` returns correct endpoints
- Test prefill/decode endpoint filtering based on model labels
- Test KV transfer params extraction from prefill response
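
A sketch of the first unit test, using a minimal stand-in for the router's round-robin selector rather than the real class (names are illustrative):

```python
class RoundRobin:
    """Stand-in for the router's per-pool round-robin selection."""

    def __init__(self):
        self.idx = 0

    def select(self, urls: list[str]) -> str:
        # Sort for consistency across requests, then rotate through the pool.
        ordered = sorted(urls)
        chosen = ordered[self.idx % len(ordered)]
        self.idx += 1
        return chosen
```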

**Integration/E2E Tests:**
- Deploy prefill + decode + router pods
- Send chat completion request
- Verify response contains decode output
- Verify logs show correct P→D flow

**Negative Tests:**
- No prefill endpoints available → 503 Service Unavailable
- No decode endpoints available → 503 Service Unavailable
- Prefill response missing `kv_transfer_params` → Error handling
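
The first two negative paths can be sketched as a label-filtering check that maps an empty prefill or decode pool to a 503 (the `EndpointInfo` dataclass here is a stand-in for the router's real class):

```python
from dataclasses import dataclass


@dataclass
class EndpointInfo:
    url: str
    model_label: str


def check_endpoints(endpoints, prefill_labels, decode_labels):
    """Return (status_code, detail); 503 when either pool is empty."""
    if not any(e.model_label in prefill_labels for e in endpoints):
        return 503, "No prefill endpoints available"
    if not any(e.model_label in decode_labels for e in endpoints):
        return 503, "No decode endpoints available"
    return 200, "ok"
```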

## Drawbacks

- **Added latency** - One additional HTTP round-trip for prefill phase
- **Complexity** - Users must configure prefill/decode pods with correct labels
- **Backend dependency** - Requires vLLM backends to support `kv_transfer_config`

## Alternatives

1. **Do nothing** - Users would need a separate proxy (e.g., toy_proxy_server.py) outside production-stack
2. **Transparent routing only** - Let LMCache handle everything, but this doesn't support custom kv_connectors
3. **gRPC between P/D** - More complex, requires protocol changes

This proposal is the best approach because it:
- Leverages existing router infrastructure
- Follows established routing_logic patterns
- Supports any kv_connector backend
- Enables KEDA-based independent scaling

## Implementation Timeline / Phases

**Phase 1 (Complete):** Core implementation
- DisaggregatedPrefillOrchestratedRouter class
- CLI argument parsing
- Orchestrated request flow

**Phase 2 (TODO):** Testing & Documentation
- Unit tests
- E2E tests
- Documentation update

## References

- [2025 Q1 Roadmap - Support for disaggregated prefill](https://github.com/vllm-project/production-stack/issues/26)
- [vLLM Disaggregated Prefill](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -3,7 +3,7 @@ name = "vllm-router"
dynamic = ["version"]
description = "The router for vLLM"
readme = "README.md"
-requires-python = ">=3.12"
+requires-python = ">=3.10"
license = {text = "Apache-2.0"}
classifiers = [
"Operating System :: OS Independent",
1 change: 1 addition & 0 deletions src/vllm_router/parsers/parser.py
@@ -209,6 +209,7 @@ def parse_args():
            "kvaware",
            "prefixaware",
            "disaggregated_prefill",
            "disaggregated_prefill_orchestrated",
        ],
        help="The routing logic to use",
    )
101 changes: 101 additions & 0 deletions src/vllm_router/routers/routing_logic.py
@@ -55,6 +55,7 @@ class RoutingLogic(str, enum.Enum):
    KVAWARE = "kvaware"
    PREFIXAWARE = "prefixaware"
    DISAGGREGATED_PREFILL = "disaggregated_prefill"
    DISAGGREGATED_PREFILL_ORCHESTRATED = "disaggregated_prefill_orchestrated"


class RoutingInterface(metaclass=SingletonABCMeta):
@@ -515,6 +516,99 @@ def route_request(
        return decoder_endpoints[0].url


class DisaggregatedPrefillOrchestratedRouter(RoutingInterface):
    """
    Orchestrates disaggregated inference in a single request by chaining Prefill → Decode.

    Unlike DisaggregatedPrefillRouter (which requires 2 separate client requests),
    this router handles the entire flow internally:
    1. Receives request from client
    2. Forwards to Prefill endpoint
    3. Gets prefill response with KV cache metadata
    4. Adds disagg_prefill_resp to request and forwards to Decode
    5. Streams decode response back to client

    This is designed for NxDI (Neuronx Distributed Inference) on AWS Trainium,
    similar to NxDI's toy_proxy_server.py pattern.

    Load balancing: Uses round-robin across available prefill and decode pods.
    """

    def __init__(self, prefill_model_labels: List[str], decode_model_labels: List[str]):
        if hasattr(self, "_initialized"):
            return
        self.prefill_model_labels = prefill_model_labels or []
        self.decode_model_labels = decode_model_labels or []
        # Round-robin counters for load balancing across xPyD pods
        self.prefill_idx = 0
        self.decode_idx = 0
        self._initialized = True
        logger.info(
            f"Initialized DisaggregatedPrefillOrchestratedRouter with "
            f"prefill_labels={self.prefill_model_labels}, "
            f"decode_labels={self.decode_model_labels}"
        )

    def _find_endpoints(self, endpoints: List[EndpointInfo]):
        """Find prefill and decode endpoints based on model labels."""
        prefiller_endpoints = [
            e for e in endpoints if e.model_label in self.prefill_model_labels
        ]
        decoder_endpoints = [
            e for e in endpoints if e.model_label in self.decode_model_labels
        ]

        if not prefiller_endpoints:
            raise ValueError(
                f"No prefill endpoints found with labels {self.prefill_model_labels}. "
                f"Available endpoints: {[(e.url, e.model_label) for e in endpoints]}"
            )
        if not decoder_endpoints:
            raise ValueError(
                f"No decode endpoints found with labels {self.decode_model_labels}. "
                f"Available endpoints: {[(e.url, e.model_label) for e in endpoints]}"
            )

        return prefiller_endpoints, decoder_endpoints

    def select_prefill_endpoint(self, prefiller_endpoints: List[EndpointInfo]) -> EndpointInfo:
        """Select prefill endpoint using round-robin load balancing."""
        if not prefiller_endpoints:
            raise ValueError("No prefill endpoints available")
        # Sort for consistency across requests
        sorted_endpoints = sorted(prefiller_endpoints, key=lambda e: e.url)
        selected = sorted_endpoints[self.prefill_idx % len(sorted_endpoints)]
        self.prefill_idx += 1
        return selected

    def select_decode_endpoint(self, decoder_endpoints: List[EndpointInfo]) -> EndpointInfo:
        """Select decode endpoint using round-robin load balancing."""
        if not decoder_endpoints:
            raise ValueError("No decode endpoints available")
        # Sort for consistency across requests
        sorted_endpoints = sorted(decoder_endpoints, key=lambda e: e.url)
        selected = sorted_endpoints[self.decode_idx % len(sorted_endpoints)]
        self.decode_idx += 1
        return selected

    async def route_request(
        self,
        endpoints: List[EndpointInfo],
        engine_stats: Dict[str, EngineStats],
        request_stats: Dict[str, RequestStats],
        request: Request,
        request_json: Dict,
    ) -> str:
        """
        This method is called by the router framework, but for orchestrated routing
        we need to handle the full flow differently. This returns the prefill URL
        as a placeholder - the actual orchestration happens in
        route_orchestrated_disaggregated_request.
        """
        prefiller_endpoints, _ = self._find_endpoints(endpoints)
        # Return prefill URL - actual orchestration is done in request.py
        return prefiller_endpoints[0].url


# Instead of managing a global _global_router, we can define the initialization functions as:
def initialize_routing_logic(
    routing_logic: RoutingLogic, *args, **kwargs
@@ -542,6 +636,11 @@ def initialize_routing_logic(
        return DisaggregatedPrefillRouter(
            kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
        )
    elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL_ORCHESTRATED:
        logger.info("Initializing disaggregated prefill orchestrated routing logic (NxDI)")
        return DisaggregatedPrefillOrchestratedRouter(
            kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
        )
    else:
        raise ValueError(f"Invalid routing logic {routing_logic}")

@@ -562,6 +661,7 @@ def get_routing_logic() -> RoutingInterface:
        KvawareRouter,
        PrefixAwareRouter,
        DisaggregatedPrefillRouter,
        DisaggregatedPrefillOrchestratedRouter,
    ):
        if cls in SingletonABCMeta._instances:
            return cls()
@@ -576,6 +676,7 @@ def cleanup_routing_logic():
        KvawareRouter,
        PrefixAwareRouter,
        DisaggregatedPrefillRouter,
        DisaggregatedPrefillOrchestratedRouter,
    ):
        if cls in SingletonABCMeta._instances:
            instance = cls()