diff --git a/src/lighteval/metrics/utils/llm_as_judge.py b/src/lighteval/metrics/utils/llm_as_judge.py
index 0f9b3315c..7a0364975 100644
--- a/src/lighteval/metrics/utils/llm_as_judge.py
+++ b/src/lighteval/metrics/utils/llm_as_judge.py
@@ -296,9 +296,7 @@ def __call_transformers(self, prompt):
 
     def __call_vllm(self, prompt):
         tokenized = [self.tokenizer.apply_chat_template(p) for p in prompt]
-        # Convert token IDs to TokensPrompt format for vLLM v0.15+
-        prompts = [{"prompt_token_ids": token_ids} for token_ids in tokenized]
-        output = self.pipe.generate(prompts=prompts, sampling_params=self.sampling_params, use_tqdm=True)
+        output = self.pipe.generate(prompts=tokenized, sampling_params=self.sampling_params, use_tqdm=True)
         outputs = [output.outputs[0].text for output in output]
         return outputs
 
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 3100c56b7..a403c1f75 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -439,9 +439,7 @@ def _generate(
         @ray.remote(num_gpus=self.tensor_parallel_size)
         def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
             llm = LLM(**model_args)
-            # Convert token IDs to TokensPrompt format for vLLM v0.15+
-            prompts = [{"prompt_token_ids": req} for req in requests]
-            return llm.generate(prompts=prompts, sampling_params=sampling_params)
+            return llm.generate(prompts=requests, sampling_params=sampling_params)
 
         # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
         # interleaved important to balance context lengths across workers
@@ -458,12 +456,8 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r
                 if x is not None
             ]
         else:
-            from vllm.inputs import TokenInputs
-
-            # Convert token IDs to TokensPrompt format for vLLM v0.15+
-            prompts = [TokenInputs(prompt_token_ids=token_ids) for token_ids in inputs]
             outputs = self.model.generate(
-                prompts=prompts,
+                prompts=inputs,
                 sampling_params=sampling_params,
                 use_tqdm=True,
             )