diff --git a/pyproject.toml b/pyproject.toml
index afa83e33c..98dc0d400 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -122,7 +122,7 @@ multilingual = [
     "pyvi", # for vietnamese tokenizer
 ]
 math = ["latex2sympy2_extended==1.0.6"]
-translation = ["unbabel-comet>=2.2.0", "sentencepiece"]
+translation = ["unbabel-comet>=2.2.0"]
 wandb = ["wandb"]
 trackio = ["trackio"]
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 8ac6fb0af..0271e4fc0 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -28,7 +28,8 @@ from typing import Coroutine, Optional
 
 import torch
-from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt
+from packaging.version import Version
+from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt, model_validator
 from tqdm import tqdm
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset
@@ -98,6 +99,16 @@ class VLLMModelConfig(ModelConfig):
             Number of GPUs to use for data parallelism. Defaults to 1.
         pipeline_parallel_size (PositiveInt):
             Number of GPUs to use for pipeline parallelism. Defaults to 1.
+        prefill_context_parallel_size (PositiveInt):
+            Number of GPUs to use for prefill context parallelism. Splits long sequences across GPUs
+            during the prefill phase, reducing peak KV-cache memory. Requires vllm >= 0.15.0 and an
+            attention backend that sets supports_pcp=True (not available in vllm 0.15.1).
+            Increases total GPU count by this factor. Defaults to 1 (disabled).
+        decode_context_parallel_size (PositiveInt):
+            Number of context parallel groups for the decode phase. Shards the KV cache along
+            the token dimension, reusing the existing TP GPUs (does not require extra GPUs).
+            tensor_parallel_size must be divisible by this value. Requires vllm >= 0.15.0.
+            Defaults to 1 (disabled).
         gpu_memory_utilization (NonNegativeFloat):
             Fraction of GPU memory to use. Lower this if running out of memory. Defaults to 0.9.
         enable_prefix_caching (bool):
@@ -161,6 +172,18 @@ class VLLMModelConfig(ModelConfig):
     tensor_parallel_size: PositiveInt = 1  # how many GPUs to use for tensor parallelism
     data_parallel_size: PositiveInt = 1  # how many GPUs to use for data parallelism
     pipeline_parallel_size: PositiveInt = 1  # how many GPUs to use for pipeline parallelism
+    prefill_context_parallel_size: PositiveInt = 1  # context parallelism for prefill phase (requires vllm >= 0.15.0)
+    decode_context_parallel_size: PositiveInt = 1  # context parallelism for decode phase (requires vllm >= 0.15.0)
+
+    @model_validator(mode="after")
+    def validate_context_parallelism(self) -> "VLLMModelConfig":
+        if self.decode_context_parallel_size > 1:
+            if self.tensor_parallel_size % self.decode_context_parallel_size != 0:
+                raise ValueError(
+                    f"tensor_parallel_size ({self.tensor_parallel_size}) must be divisible by "
+                    f"decode_context_parallel_size ({self.decode_context_parallel_size})."
+                )
+        return self
     gpu_memory_utilization: NonNegativeFloat = 0.9  # lower this if you are running out of memory
     enable_prefix_caching: bool = None  # whether to enable prefix caching to speed up generation. May use more memory. Should be disabled for LFM2
     max_model_length: PositiveInt | None = (
@@ -196,6 +219,8 @@ def __init__(
         )
         self.data_parallel_size = config.data_parallel_size
         self.tensor_parallel_size = config.tensor_parallel_size
+        self.pipeline_parallel_size = config.pipeline_parallel_size
+        self.prefill_context_parallel_size = config.prefill_context_parallel_size
         self._add_special_tokens = config.add_special_tokens if config.add_special_tokens is not None else False
         self._tokenizer = self._create_auto_tokenizer(config)
@@ -274,8 +299,36 @@ def _create_auto_model(self, config: VLLMModelConfig) -> Optional[LLM]:
         if config.load_format is not None:
             self.model_args["load_format"] = config.load_format
 
+        if config.prefill_context_parallel_size > 1 or config.decode_context_parallel_size > 1:
+            from importlib.metadata import version as get_package_version
+
+            _VLLM_MIN_VERSION_CP = Version("0.15.0")
+            _vllm_version = Version(get_package_version("vllm"))
+            if _vllm_version < _VLLM_MIN_VERSION_CP:
+                raise ValueError(
+                    f"Context parallelism (prefill_context_parallel_size / decode_context_parallel_size) "
+                    f"requires vllm >= {_VLLM_MIN_VERSION_CP}, but the installed version is {_vllm_version}."
+                )
+            if config.prefill_context_parallel_size > 1:
+                # PCP requires attention backends to set supports_pcp=True. Check this early
+                # to avoid failing after several minutes of model loading.
+                try:
+                    from vllm.v1.attention.backend import AttentionImplBase
+
+                    if not AttentionImplBase.supports_pcp:
+                        raise NotImplementedError(
+                            f"prefill_context_parallel_size > 1 is not supported by any attention "
+                            f"backend in the installed vllm {_vllm_version}. "
+                            f"Consider using tensor_parallel_size or decode_context_parallel_size instead."
+                        )
+                except ImportError:
+                    pass  # older vllm layout; let vllm raise its own error
+                self.model_args["prefill_context_parallel_size"] = config.prefill_context_parallel_size
+            if config.decode_context_parallel_size > 1:
+                self.model_args["decode_context_parallel_size"] = config.decode_context_parallel_size
+
         if config.data_parallel_size > 1:
-            self.model_args["distributed_executor_backend"] = "ray"
+            self.model_args["distributed_executor_backend"] = "mp"
             self._batch_size = "auto"
 
         if self._max_length is None:
@@ -442,7 +495,7 @@ def _generate(
         if self.data_parallel_size > 1:
 
-            @ray.remote(num_gpus=self.tensor_parallel_size)
+            @ray.remote(num_gpus=self.tensor_parallel_size * self.pipeline_parallel_size * self.prefill_context_parallel_size)
             def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, requests):
                 llm = LLM(**model_args)
                 return llm.generate(
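
For context, a minimal usage sketch (not part of the patch) of how the new fields would be exercised, assuming the VLLMModelConfig definition above and the usual model_name field from lighteval's ModelConfig; the 8-GPU layout and model name are illustrative placeholders:

from lighteval.models.vllm.vllm_model import VLLMModelConfig

# Illustrative layout: TP=4 with the decode-phase KV cache sharded into 2 context
# groups (reusing the same 4 TP GPUs), replicated over DP=2 -> 8 GPUs total.
config = VLLMModelConfig(
    model_name="org/model",  # placeholder, not from this patch
    tensor_parallel_size=4,
    data_parallel_size=2,
    decode_context_parallel_size=2,  # must divide tensor_parallel_size
)

# The model_validator added above rejects an invalid split at config time,
# before any model loading starts:
try:
    VLLMModelConfig(
        model_name="org/model",
        tensor_parallel_size=4,
        decode_context_parallel_size=3,  # 4 % 3 != 0 -> ValueError
    )
except ValueError as err:
    print(err)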