From 19a83ffae75f01403b7093de0b67b0a9a1af00c9 Mon Sep 17 00:00:00 2001 From: KiddoZhu Date: Mon, 9 Jun 2025 16:32:48 -0700 Subject: [PATCH] allow uneven shards in generate_text Signed-off-by: KiddoZhu --- nemo_rl/models/generation/vllm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 59fcc26320..85b5182148 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -592,11 +592,8 @@ def generate_text( f"data must be a BatchedDataDict, got type: {type(data)}" ) - # Get total batch size - batch_size = len(data["prompts"]) - # Shard the data across the tied worker groups - sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=batch_size) + sharded_data = data.shard_by_batch_size(self.dp_size, allow_uneven_shards=True) future_bundle = self.worker_group.run_all_workers_multiple_data( "generate_text", sharded_data,