diff --git a/requirements.txt b/requirements.txt
index a9ab80b7c..d31f2e3a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,7 +5,7 @@ tokenizers>=0.13.3
 peft>=0.10.0
 torch>=2.0.1
 wandb==0.14.0
-deepspeed==0.10.0
+deepspeed<=0.14.0
 trl>=0.7.11
 sentencepiece
 transformers>=4.31.0
@@ -18,8 +18,8 @@ scikit-learn==1.2.2
 lm-eval==0.3.0
 dill<0.3.5
 bitsandbytes>=0.40.0
-pydantic<=1.10.9
+pydantic
 gradio
 accelerate>=0.27.2
 einops>=0.6.1
-scikit-learn==1.2.2
+vllm>=0.4.1
\ No newline at end of file
diff --git a/src/lmflow/args.py b/src/lmflow/args.py
index ef91ac1f2..7935a89b5 100644
--- a/src/lmflow/args.py
+++ b/src/lmflow/args.py
@@ -12,15 +12,15 @@ extracted from the MODEL_CONFIG_CLASSES.
 """
 
 import logging
-from dataclasses import dataclass, field
-from typing import Optional, List
-
-from transformers.utils.versions import require_version
+from dataclasses import dataclass, field, fields, Field, make_dataclass
+from pathlib import Path
+from typing import Optional, List, Union, Dict
 
 from transformers import (
     MODEL_FOR_CAUSAL_LM_MAPPING,
     TrainingArguments,
 )
+from transformers.utils.versions import require_version
 
 MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -308,6 +308,7 @@ class ModelArguments:
             "choices": ["right", "left", "auto"],
         }
     )
+
     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -838,6 +839,40 @@ class InferencerArguments:
     repetition_penalty : float
         An argument of model.generate in huggingface to penalize repetitions.
+    use_beam_search : Optional[bool]
+        Whether to use beam search during inference. By default False.
+    num_output_sequences : Optional[int]
+        Number of output sequences to return for the given prompt,
+        currently only used in vllm inference. By default 8.
+    top_p : Optional[float]
+        top_p for sampling. By default 1.0.
+    top_k : Optional[int]
+        top_k for sampling. By default -1 (no top_k).
+    additional_stop_token_ids : Optional[List[int]]
+        The ids of the end-of-sentence tokens. By default [].
+    apply_chat_template : Optional[bool]
+        Whether to apply the chat template. By default True.
+    save_results : Optional[bool]
+        Whether to save inference results. By default False.
+    results_path : Optional[str]
+        The **json file** path of inference results. By default None.
+    memory_safe_vllm_inference_detokenize : Optional[bool]
+        Whether to detokenize the memory-safe vllm inference results.
+
+        NOTE: For iterative align pipelines, whether to detokenize depends on
+        the homogeneity of the policy model and the reward model
+        (i.e., whether they share the same tokenizer).
+        `detokenize` is exposed as an argument for memory-safe vllm inference
+        because of its implementation (a subprocess rather than an in-process
+        Python call), so the option has to be passed through command line
+        arguments.
+    use_vllm: bool, optional
+        Whether to use VLLM for inference. By default False.
+    vllm_tensor_parallel_size: int, optional
+        The tensor parallel size for VLLM inference.
+    vllm_gpu_memory_utilization: float, optional
+        The GPU memory utilization for VLLM inference: the proportion of GPU
+        memory (per GPU) to use for VLLM inference.
""" device: str = field( default="gpu", @@ -902,6 +937,69 @@ class InferencerArguments: use_accelerator: bool = field( default=False, metadata={"help": "Whether to use Huggingface Accelerator instead of Deepspeed"}, ) + use_beam_search: Optional[bool] = field( + default=False, + metadata={"help": "whether to use beam search during inference."}, + ) + num_output_sequences: Optional[int] = field( + default=8, + metadata={"help": ( + "number of output sequences to return for the given prompt, " + "currently only used in vllm inference." + )}, + ) + top_p: Optional[float] = field( + default=1.0, + metadata={"help": "top_p for sampling."}, + ) + top_k: Optional[int] = field( + default=-1, + metadata={"help": "top_k for sampling."}, + ) + additional_stop_token_ids: Optional[List[int]] = field( + default_factory=lambda: [], + metadata={"help": "the ids of the end of sentence tokens"}, + ) + apply_chat_template: Optional[bool] = field( + default=True, + metadata={"help": "whether to apply chat template"}, + ) + memory_safe_vllm_inference_detokenize: Optional[bool] = field( + default=False, + metadata={"help": "Whether to detokenize the memory safe vllm inference results."}, + ) + + # vllm inference args + use_vllm: bool = field( + default=False, + metadata={"help": "Whether to use VLLM for inference, By default False."} + ) + vllm_tensor_parallel_size: Optional[int] = field( + default=1, + metadata={"help": "The tensor parallel size for VLLM inference."} + ) + vllm_gpu_memory_utilization: Optional[float] = field( + default=0.95, + metadata={"help": "The GPU memory utilization for VLLM inference."} + ) + + # Args for result saving + save_results: Optional[bool] = field( + default=False, metadata={"help": "Whether to save inference results."} + ) + results_path: Optional[str] = field( + default=None, metadata={"help": "The path of inference results."} + ) + + def __post_init__(self): + if self.save_results: + if self.results_path is None: + raise ValueError("Need to specify results_path when save_results is True.") + else: + if not self.results_path.endswith(".json"): + raise ValueError("The results_path must be a json file.") + else: + Path(self.results_path).parent.mkdir(parents=True, exist_ok=True) @dataclass @@ -1144,13 +1242,21 @@ class DPOAlignerArguments: ) +@dataclass +class IterativeAlignerArguments(InferencerArguments): + """ + Arguments for iterative aligners. + """ + pass + + PIPELINE_ARGUMENT_MAPPING = { "finetuner": FinetunerArguments, "evaluator": EvaluatorArguments, "inferencer": InferencerArguments, "raft_aligner": RaftAlignerArguments, "dpo_aligner": DPOAlignerArguments, - "rm_tuner": RewardModelingArguments + "rm_tuner": RewardModelingArguments, } diff --git a/src/lmflow/datasets/dataset.py b/src/lmflow/datasets/dataset.py index 5ad4fb89d..826217f48 100644 --- a/src/lmflow/datasets/dataset.py +++ b/src/lmflow/datasets/dataset.py @@ -61,7 +61,7 @@ class Dataset: kwargs : Optional. Keyword arguments. 
""" - def __init__(self, data_args=None, backend: str="huggingface", *args, **kwargs): + def __init__(self, data_args: DatasetArguments=None, backend: str="huggingface", *args, **kwargs): self.data_args = data_args self.backend = backend self.backend_dataset = None @@ -263,7 +263,7 @@ def from_dict(self, dict_obj: dict, *args, **kwargs): return self else: raise NotImplementedError( - f'Currently .from_dict is not supported for backend "{backend}"' + f'Currently .from_dict is not supported for backend "{self.backend}"' ) @@ -331,7 +331,7 @@ def to_dict(self): return dict_obj else: raise NotImplementedError( - f'Current .to_dict is not supported for backend "{backend}"' + f'Current .to_dict is not supported for backend "{self.backend}"' ) @@ -347,7 +347,7 @@ def to_list(self): return instance_list else: raise NotImplementedError( - f'Current .to_list is not supported for backend "{backend}"' + f'Current .to_list is not supported for backend "{self.backend}"' ) @@ -376,7 +376,7 @@ def map(self, *args, **kwargs): else: # If the backend is not Hugging Face, raise a NotImplementedError raise NotImplementedError( - f'Currently .map is not supported for backend "{backend}"' + f'Currently .map is not supported for backend "{self.backend}"' ) diff --git a/src/lmflow/models/hf_decoder_model.py b/src/lmflow/models/hf_decoder_model.py index 0f5b1c4d5..ee4e94ff5 100644 --- a/src/lmflow/models/hf_decoder_model.py +++ b/src/lmflow/models/hf_decoder_model.py @@ -21,7 +21,7 @@ import hashlib import logging import os, shutil -from typing import List, Union +from typing import List, Union, Optional, Dict from pathlib import Path import torch @@ -44,6 +44,7 @@ get_peft_model, prepare_model_for_kbit_training ) +from vllm import SamplingParams from lmflow.datasets.dataset import Dataset from lmflow.models.hf_model_mixin import HFModelMixin @@ -332,16 +333,61 @@ def decode(self, input, *args, **kwargs ) -> Union[str, List[str]]: else: # Can be list of ints or a Tensor return self.tokenizer.decode(input, *args, **kwargs) + + + def inference( + self, + inputs, + release_gpu: bool = False, + use_vllm: bool = False, + **kwargs + ): + """ + Perform generation process of the model. + + Parameters + ------------ + inputs : + The sequence used as a prompt for the generation or as model inputs to the model. + When using vllm inference, this should be a string or a list of strings. + When using normal inference, this should be a tensor. + release_gpu : bool, optional + Whether to release the GPU resource after inference, by default False. + use_vllm : bool, optional + Whether to use VLLM for inference, by default False. + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + outputs : + The generated sequence output + """ + if not self._activated: + self.activate_model_for_inference( + use_vllm=use_vllm, + **kwargs, + ) + + if use_vllm: + res = self.__vllm_inference(inputs, **kwargs) + else: + res = self.__inference(inputs, **kwargs) + + if release_gpu: + self.deactivate_model_for_inference(use_vllm=use_vllm) + + return res - def inference(self, inputs, use_accelerator=False, *args, **kwargs): + def __inference(self, inputs, use_accelerator=False, *args, **kwargs): """ Perform generation process of the model. Parameters ------------ inputs : - The sequence used as a prompt for the generation or as model inputs to the model. + The **tokenized** sequence used as a prompt for the generation or as model inputs to the model. args : Optional. Positional arguments. 
@@ -354,8 +400,6 @@ def inference(self, inputs, use_accelerator=False, *args, **kwargs): outputs : The generated sequence output """ - - with torch.no_grad(): if use_accelerator: outputs = self.backend_model.generate( @@ -386,6 +430,167 @@ def inference(self, inputs, use_accelerator=False, *args, **kwargs): f"device \"{self.device}\" is not supported" ) return outputs + + + def __vllm_inference( + self, + inputs: Union[str, List[str]], + sampling_params: Optional[SamplingParams] = None, + **kwargs, + ) -> Union[List[List[str]], List[List[List[int]]]]: + """Perform VLLM inference process of the model. + + Parameters + ---------- + inputs : Union[str, List[str]] + Prompt(s), string or a list of strings. + sampling_params : Optional[SamplingParams], optional + vllm SamplingParams object, by default None. + + Returns + ------- + Union[List[List[str]], List[List[List[int]]]] + When `sampling_params.detokenize = True`, return a list of list of strings. Inner list + contains sampling_params.n samples for a single prompt (i.e., `len(res[i]) = sampling_params.n`). + Outer list contains the results for all prompts (i.e., `len(res) = len(inputs)`). + + When `sampling_params.detokenize = False`, return a list of list of list of ints + (token ids, no decoding after generation). + """ + vllm_outputs = self.backend_model_for_inference.generate( + inputs, + sampling_params=sampling_params, + use_tqdm=True, + ) + final_output = [] + if sampling_params.detokenize: + for output in vllm_outputs: + final_output.append([sentence.text for sentence in output.outputs]) + else: + for output in vllm_outputs: + final_output.append([sentence.token_ids for sentence in output.outputs]) + + return final_output + + + def prepare_inputs_for_inference( + self, + dataset: Dataset, + apply_chat_template: bool = True, + use_vllm: bool = False, + ) -> Union[List[str], Dict[str, torch.Tensor]]: + """ + Prepare inputs for inference. + + Parameters + ------------ + dataset : lmflow.datasets.Dataset. + The dataset used for inference. + + args : Optional. + Positional arguments. + + kwargs : Optional. + Keyword arguments. + + Returns + ------------ + outputs : + The prepared inputs for inference. 
+ """ + if use_vllm: + inference_inputs = self.__prepare_inputs_for_vllm_inference( + dataset=dataset, + apply_chat_template=apply_chat_template + ) + else: + inference_inputs = self.__prepare_inputs_for_inference(dataset) + + return inference_inputs + + + def __prepare_inputs_for_vllm_inference( + self, + dataset: Dataset, + apply_chat_template: bool = True, + ) -> List[str]: + if dataset.get_type() == 'text_only': + if apply_chat_template: + dataset = dataset.map( + lambda sample: { + "templated": self.tokenizer.apply_chat_template( + [{"role":"user", "content": sample['text']}], + tokenize=False, + add_generation_prompt=True + ) + }, + num_proc=dataset.data_args.preprocessing_num_workers, + ) + inference_inputs = dataset.get_backend_dataset()['templated'] + else: + inference_inputs = dataset.get_backend_dataset()['text'] + + elif dataset.get_type() == "text2text": + logger.warning(f"For a text2text dataset, only `input` will be used as the model input.") + if apply_chat_template: + dataset = dataset.map( + lambda sample: { + "templated": self.tokenizer.apply_chat_template( + conversation=[{"role":"user", "content": sample['input']}], + tokenize=False, + add_generation_prompt=True + ) + }, + num_proc=dataset.data_args.preprocessing_num_workers, + ) + inference_inputs = dataset.get_backend_dataset()['templated'] + else: + inference_inputs = dataset.get_backend_dataset()['input'] + + elif dataset.get_type() == 'conversation': + if apply_chat_template: + def preprocess_conversation(sample): + if len(sample['messages'])%2 == 0: + conversation = sample['messages'][:-1] + + if sample['messages'][-1]['role'] != 'assistant': + logger.warning("Not a valid conversation, skip.") + sample_out = {"templated": ""} + else: + sample_out = {"templated": self.tokenizer.apply_chat_template( + conversation=conversation, + tokenize=False, + add_generation_prompt=True, + )} + + return sample_out + dataset = dataset.map( + preprocess_conversation, + num_proc=dataset.data_args.preprocessing_num_workers, + ) + inference_inputs = dataset.get_backend_dataset()['templated'] + else: + logger.warning( + "Your dataset is `conversation` type but `apply_chat_template` is set to False. " + "Will use the first user input in conversation as model input." + ) + inference_inputs = [conversation[0]['content'] for conversation in dataset.get_backend_dataset()['messages']] + + else: + raise NotImplementedError( + f"Currently `{dataset.get_type()}` data are not supported for vllm inference." + ) + + inference_inputs = [sentence for sentence in inference_inputs if len(sentence) > 0] + + return inference_inputs + + + def __prepare_inputs_for_inference( + self, + dataset: Dataset, + ): + raise NotImplementedError("prepare_inputs_for_inference is not implemented") def merge_lora_weights(self): diff --git a/src/lmflow/models/hf_model_mixin.py b/src/lmflow/models/hf_model_mixin.py index 0d7e73dd0..2721c1947 100644 --- a/src/lmflow/models/hf_model_mixin.py +++ b/src/lmflow/models/hf_model_mixin.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # coding=utf-8 # Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. 
+import gc import os import logging from typing import Union, Optional, Dict @@ -25,6 +26,8 @@ prepare_model_for_kbit_training ) from peft.utils.constants import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING +from vllm import LLM +from vllm.distributed.parallel_state import destroy_model_parallel from lmflow.models.base_model import BaseModel from lmflow.utils.constants import ( @@ -88,20 +91,22 @@ def __init__( self.device = device self.model_args = model_args + self.hf_auto_model = HF_AUTOMODEL_MAPPING[model_args.arch_type] + self.use_accelerator = use_accelerator + self.ds_config = ds_config + self.tokenizer = self.__prepare_tokenizer(model_args) self.torch_dtype = self.__prepare_dtype(model_args) self.hf_model_config = self.__prepare_model_config(model_args, hf_auto_model_additional_args) self.quant_config = self.__prepare_quant_config(model_args) self.peft_config = self.__prepare_peft_config(model_args) + self._activated = False # for inference load and offload # Some implementations require custom modules to be injected into the model. self.__model_module_inject(model_args) - hf_auto_model = HF_AUTOMODEL_MAPPING[model_args.arch_type] if do_train: - self.__prepare_model_for_training(model_args, hf_auto_model) - else: - self.__prepare_model_for_inference(model_args, hf_auto_model, use_accelerator, ds_config) + self.__prepare_model_for_training(model_args, self.hf_auto_model) # some post processing if self.tokenizer.eos_token_id is None: @@ -352,81 +357,157 @@ def __prepare_model_for_inference( self, model_args: ModelArguments, hf_auto_model: HF_AUTOMODEL_TYPE, - use_accelerator, + use_accelerator: bool, ds_config ): - # TODO: change to accelerate - logger.info("Preparing model for inference") - if use_accelerator: - peft_model_id = model_args.lora_model_path - self.backend_model = hf_auto_model.from_pretrained( - model_args.model_name_or_path, - config=self.hf_model_config, - device_map="auto", - offload_folder="offload", - offload_state_dict=True, - load_in_8bit = model_args.use_int8, - ) - if peft_model_id is not None: - self.backend_model = PeftModel.from_pretrained( - self.backend_model, - peft_model_id, - ) - else: - from transformers.integrations import HfDeepSpeedConfig - dschf = HfDeepSpeedConfig(ds_config) - peft_model_id = model_args.lora_model_path - # NOTE: Currently offload is not supported by llama - if self.hf_model_config.model_type == "llama" and model_args.use_ram_optimized_load: - logger.warning( - "llama does not support RAM optimized load. Automatically" - " use original load instead." 
- ) - model_args.use_ram_optimized_load = False - - if model_args.use_ram_optimized_load and peft_model_id is None: - try: - # RAM-optimized load - self.backend_model = hf_auto_model.from_pretrained( + if not hasattr(self, "backend_model"): + # TODO: change to accelerate + logger.info("Preparing model for inference") + if use_accelerator: + peft_model_id = model_args.lora_model_path + self.backend_model = hf_auto_model.from_pretrained( model_args.model_name_or_path, config=self.hf_model_config, device_map="auto", offload_folder="offload", offload_state_dict=True, + load_in_8bit = model_args.use_int8, ) - except: + if peft_model_id is not None: + self.backend_model = PeftModel.from_pretrained( + self.backend_model, + peft_model_id, + ) + else: + from transformers.integrations import HfDeepSpeedConfig + dschf = HfDeepSpeedConfig(ds_config) + peft_model_id = model_args.lora_model_path + # NOTE: Currently offload is not supported by llama + if self.hf_model_config.model_type == "llama" and model_args.use_ram_optimized_load: logger.warning( - "Failed to use RAM optimized load. Automatically" + "llama does not support RAM optimized load. Automatically" " use original load instead." ) - # Normal load + model_args.use_ram_optimized_load = False + + if model_args.use_ram_optimized_load and peft_model_id is None: + try: + # RAM-optimized load + self.backend_model = hf_auto_model.from_pretrained( + model_args.model_name_or_path, + config=self.hf_model_config, + device_map="auto", + offload_folder="offload", + offload_state_dict=True, + ) + except: + logger.warning( + "Failed to use RAM optimized load. Automatically" + " use original load instead." + ) + # Normal load + self.backend_model = hf_auto_model.from_pretrained( + model_args.model_name_or_path, + config=self.hf_model_config, + ) + else: + if peft_model_id is not None: + logger.warning( + "LoRA does not support RAM optimized load currently." + " Automatically use original load instead." + ) self.backend_model = hf_auto_model.from_pretrained( model_args.model_name_or_path, config=self.hf_model_config, ) - else: + + self.backend_model_full = self.backend_model if peft_model_id is not None: - logger.warning( - "LoRA does not support RAM optimized load currently." - " Automatically use original load instead." 
+ self.backend_model = PeftModel.from_pretrained( + self.backend_model, peft_model_id ) - self.backend_model = hf_auto_model.from_pretrained( - model_args.model_name_or_path, - config=self.hf_model_config, - ) - - self.backend_model_full = self.backend_model - if peft_model_id is not None: - self.backend_model = PeftModel.from_pretrained( - self.backend_model, peft_model_id - ) - if self.device == "gpu": - deepspeed.init_distributed() - self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0] - self.ds_engine.module.eval() - - self.tokenizer.padding_side = "left" # necessary for llama, gpt2 and other decoder models + if self.device == "gpu": + deepspeed.init_distributed() + self.ds_engine = deepspeed.initialize(model=self.backend_model, config_params=ds_config)[0] + self.ds_engine.module.eval() + + # backend model already initialized + else: + if self.backend_model.device == torch.device("cpu"): + self.backend_model.to(self.device) + else: + return + + + def __prepare_model_for_vllm_inference( + self, + model_args: ModelArguments, + vllm_gpu_memory_utilization: float, + vllm_tensor_parallel_size: int, + ): + self.backend_model_for_inference = LLM( + model=model_args.model_name_or_path, + tokenizer=model_args.model_name_or_path, + dtype=model_args.torch_dtype, + load_format="auto", + gpu_memory_utilization=vllm_gpu_memory_utilization, + tensor_parallel_size=vllm_tensor_parallel_size, + ) + + + def activate_model_for_inference( + self, + use_vllm: bool=False, + **kwargs, + ): + if self._activated: + logger.warning("You are trying to activate the model for inference, but it is already activated.") + return + + if use_vllm: + self.__prepare_model_for_vllm_inference( + model_args=self.model_args, + vllm_gpu_memory_utilization=kwargs.get("vllm_gpu_memory_utilization"), + vllm_tensor_parallel_size=kwargs.get("vllm_tensor_parallel_size"), + ) + else: + self.__prepare_model_for_inference( + model_args=self.model_args, + hf_auto_model=self.hf_auto_model, + use_accelerator=self.use_accelerator, + ds_config=self.ds_config, + ) + + self._activated = True + + + def deactivate_model_for_inference( + self, + use_vllm: bool=False, + ): + """Deactivate the model and release the resources. + + NOTE: Currently, VLLM doesn't have an official way to do this, and the + implementation below cannot release all gpu resources by our observation. + Thus this method is just a placeholder for future implementation. See: + [Github issue](https://github.com/vllm-project/vllm/issues/1908) + """ + if not self._activated: + logger.warning("You are trying to deactivate the model for inference, but it is already deactivated.") + return + + if use_vllm: + destroy_model_parallel() + del self.backend_model_for_inference.llm_engine.model_executor.driver_worker + del self.backend_model_for_inference + gc.collect() + torch.cuda.empty_cache() + else: + self.backend_model.to("cpu") + pass + + self._activated = False def get_max_length(self): diff --git a/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py new file mode 100644 index 000000000..74c3e7fc2 --- /dev/null +++ b/src/lmflow/pipeline/utils/memory_safe_vllm_inference.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. + +# Note that this is only a workaround, since vllm +# inference engine cannot release GPU memory properly by now. 
Please see this github +# [issue](https://github.com/vllm-project/vllm/issues/1908). + +import logging +import sys +import os +from typing import Dict + +from transformers import ( + HfArgumentParser +) + +from lmflow.datasets import Dataset +from lmflow.models.hf_decoder_model import HFDecoderModel +from lmflow.pipeline.vllm_inferencer import VLLMInferencer +from lmflow.args import ( + ModelArguments, + DatasetArguments, + AutoArguments, +) +from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG + + +logger = logging.getLogger(__name__) + + +def main(): + # Parses arguments + pipeline_name = "inferencer" + PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name) + + parser = HfArgumentParser(( + ModelArguments, + DatasetArguments, + PipelineArguments + )) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, pipeline_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, pipeline_args = parser.parse_args_into_dataclasses() + + dataset = Dataset(data_args) + model = HFDecoderModel(model_args) + inferencer = VLLMInferencer(model_args, pipeline_args) + + res = inferencer.inference( + model, + dataset, + release_gpu=False, + detokenize=pipeline_args.memory_safe_vllm_inference_detokenize, + ) + + # use this as a flag, stdout will be captured by the pipeline + print(MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/lmflow/pipeline/vllm_inferencer.py b/src/lmflow/pipeline/vllm_inferencer.py new file mode 100644 index 000000000..6d4520a70 --- /dev/null +++ b/src/lmflow/pipeline/vllm_inferencer.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. +import os +import sys +import signal +import json +from pathlib import Path +import logging +import subprocess +import importlib.resources as pkg_resources +from typing import List, Union, Optional +import time + +from vllm import SamplingParams +from transformers import AutoTokenizer + +from lmflow.datasets import Dataset +from lmflow.pipeline.base_pipeline import BasePipeline +from lmflow.models.hf_decoder_model import HFDecoderModel +from lmflow.args import ( + InferencerArguments, + ModelArguments, + DatasetArguments, +) +from lmflow.utils.common import make_shell_args_from_dataclass +from lmflow.utils.constants import MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG + + +logger = logging.getLogger(__name__) + + +class InferencerWithOffloading(BasePipeline): + def __init__( + self, + model_args: ModelArguments, + inferencer_args: InferencerArguments, + ): + self.model_args = model_args + self.inferencer_args = inferencer_args + self.eos_token_id = AutoTokenizer.from_pretrained(model_args.model_name_or_path).eos_token_id + + def inference(self): + raise NotImplementedError(".inference is not implemented") + + def save_inference_results(self): + raise NotImplementedError(".save_inference_results is not implemented") + + def load_inference_results(self): + raise NotImplementedError(".load_inference_results is not implemented") + + +class VLLMInferencer(InferencerWithOffloading): + def __init__( + self, + model_args: ModelArguments, + inferencer_args: InferencerArguments, + ): + assert inferencer_args.use_vllm, "The inferencer_args.use_vllm must be True." 
+        super().__init__(model_args, inferencer_args)
+        self.sampling_params = self.parse_to_sampling_params(inferencer_args)
+
+
+    def parse_to_sampling_params(
+        self,
+        inference_args: InferencerArguments,
+    ) -> SamplingParams:
+        return SamplingParams(
+            use_beam_search=inference_args.use_beam_search,
+            n=inference_args.num_output_sequences,
+            temperature=inference_args.temperature + 1e-6,
+            max_tokens=inference_args.max_new_tokens,
+            seed=inference_args.random_seed,
+            top_p=inference_args.top_p,
+            top_k=inference_args.top_k,
+            stop_token_ids=[self.eos_token_id] + inference_args.additional_stop_token_ids
+        )
+
+
+    def inference(
+        self,
+        model: HFDecoderModel,
+        dataset: Dataset,
+        detokenize: bool = True,
+        release_gpu: bool = False,
+        inference_args: Optional[InferencerArguments] = None,
+    ) -> Union[List[List[str]], List[List[List[int]]]]:
+        """Perform inference using the provided model and dataset, and save the inference
+        results if `save_results` is set to True in `inferencer_args`. Whether to apply the
+        chat template to the inputs is controlled by `inferencer_args.apply_chat_template`
+        (True by default).
+
+        Parameters
+        ----------
+        model : HFDecoderModel
+            LMFlow HFDecoderModel object
+        dataset : Dataset
+            LMFlow Dataset object
+        detokenize : bool, optional
+            Whether to decode after generation, by default True.
+        release_gpu : bool, optional
+            Whether to release gpu resources, by default False.
+            NOTE: The reason why `release_gpu` and `detokenize` are not in `inference_args` is
+            that the Inferencer may be used by other pipelines, and those pipelines may want to
+            control these two behaviors dynamically.
+        inference_args : InferencerArguments, optional
+            Arguments that override the default sampling parameters, by default None.
+
+        Returns
+        -------
+        Union[List[List[str]], List[List[List[int]]]]
+            When `detokenize = True`, return a list of list of strings. The inner list
+            contains inference_args.num_output_sequences samples for a single prompt
+            (i.e., `len(res[i]) = inference_args.num_output_sequences`). The outer list
+            contains the results for all prompts (i.e., `len(res) = len(dataset)`).
+
+            When `detokenize = False`, return a list of list of list of ints
+            (token ids, no decoding after generation).
+ """ + if inference_args: + logger.warning( + "Overriding the default inference arguments with the provided arguments in .inference()" + ) + sampling_params = self.parse_to_sampling_params(inference_args) + else: + sampling_params = self.sampling_params + + sampling_params.detokenize = detokenize + + model_input = model.prepare_inputs_for_inference( + dataset=dataset, + apply_chat_template=self.inferencer_args.apply_chat_template, + use_vllm=self.inferencer_args.use_vllm + ) + + outputs = model.inference( + inputs=model_input, + sampling_params=sampling_params, + release_gpu=release_gpu, + use_vllm=self.inferencer_args.use_vllm, + vllm_gpu_memory_utilization=self.inferencer_args.vllm_gpu_memory_utilization, + vllm_tensor_parallel_size=self.inferencer_args.vllm_tensor_parallel_size, + ) + + if self.inferencer_args.save_results: + self.save_inference_results(outputs, self.inferencer_args.results_path) + + return outputs + + + def save_inference_results( + self, + outputs: Union[List[List[str]], List[List[List[int]]]], + save_file_path: str, + ): + with open(save_file_path, "w") as f: + json.dump(outputs, f) + + logger.info(f"Inference results are saved to {save_file_path}.") + + + def load_inference_results( + self, + results_path: str, + ) -> Union[List[List[str]], List[List[List[int]]]]: + with open(results_path, "r") as f: + results = json.load(f) + + return results + + +class MemorySafeVLLMInferencer(VLLMInferencer): + def __init__( + self, + model_args: ModelArguments, + data_args: DatasetArguments, + inferencer_args: InferencerArguments, + ): + assert inferencer_args.save_results, "For MemorySafeVLLMInferencer, `save_results` must be True." + super().__init__(model_args, inferencer_args) + self.data_args = data_args + self.inferencer_file_path = pkg_resources.files("lmflow.pipeline.utils") / "memory_safe_vllm_inference.py" + + + def inference(self): + inferencer_args = make_shell_args_from_dataclass( + dataclass_objects=[ + self.model_args, + self.data_args, + self.inferencer_args, + ], + format="shell", + ) + cmd = "python " + str(self.inferencer_file_path) + " " + inferencer_args + + cli_res = subprocess.run( + args=cmd, + stdout=sys.stdout, + stderr=sys.stdout, + shell=True, + preexec_fn=os.setsid + ) + # wait for the subprocess to finish (kill cleanly, otherwise may leads to: + # > Fatal Python error: _enter_buffered_busy: could not acquire lock for <_io.BufferedWriter name=''> + # > at interpreter shutdown, possibly due to daemon threads + time.sleep(30) + logger.info(f"MemorySafeVLLMInference subprocess run finished, info at finish: {cli_res}") + + if cli_res.returncode != 0: + raise RuntimeError(f"Error during MemorySafeVLLMInference.") + else: + outputs = self.load_inference_results(self.inferencer_args.results_path) + logger.info("MemorySafeVLLMInference result captured.") + + return outputs \ No newline at end of file diff --git a/src/lmflow/utils/common.py b/src/lmflow/utils/common.py new file mode 100644 index 000000000..0b42fc027 --- /dev/null +++ b/src/lmflow/utils/common.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. 
+import logging +from dataclasses import dataclass, field, fields, Field, make_dataclass +from pathlib import Path +from typing import Optional, List, Union, Dict + + +logger = logging.getLogger(__name__) + + +def make_shell_args_from_dataclass( + dataclass_objects: List, + format: str="subprocess", + skip_default: bool=True, +) -> Union[str, List[str]]: + """Return a string or a list of strings that can be used as shell arguments. + + Parameters + ---------- + dataclass_objects : List + A list of dataclass objects. + format : str, optional + Return format, can be "shell" or "subprocess", by default "subprocess". + skip_default : bool, optional + Whether to skip attributes with default values, by default True. + + Returns + ------- + Union[str, List[str]] + """ + assert isinstance(dataclass_objects, list), "dataclass_objects should be a list of dataclass objects." + all_args = {} + for dataclass_object in dataclass_objects: + for k, v in dataclass_object.__dict__.items(): + if not v: + continue + if skip_default: + if dataclass_object.__dataclass_fields__[k].default == v: + continue + if k not in all_args: + all_args[k] = v + elif k in all_args: + if all_args[k] == v: + continue + else: + logger.warning(f"Found different values for the same key: {k}, using value: {v} instead.") + all_args[k] = v + + if format == "shell": + final_res = " ".join([f"--{k} {v}" for k, v in all_args.items()]) + elif format == "subprocess": + final_res = [] + for k, v in all_args.items(): + final_res.extend([f"--{k}", str(v)]) + else: + raise ValueError(f"Unknown format: {format}") + + return final_res + + +def create_copied_dataclass( + original_dataclass, + field_prefix: str, + class_prefix: str, + new_default: Dict=None +): + """Create a copied dataclass with new field names and default values. + + Parameters + ---------- + original_dataclass : dataclass + field_prefix : str + The prefix to add to the **field** names of the copied dataclass. + class_prefix : str + The prefix to add to the **class** name of the copied dataclass. + new_default : Dict, optional + The new default values for the copied dataclass. When None, the + default values of the original dataclass are used. + + Returns + ------- + dataclass + """ + original_fields = fields(original_dataclass) + new_default = new_default or {} + new_fields = [] + for field in original_fields: + new_field = ( + f"{field_prefix}{field.name}", + field.type, + Field( + default=new_default.get(f"{field_prefix}{field.name}", field.default), + default_factory=field.default_factory, + init=field.init, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + ) + ) + new_fields.append(new_field) + copied_dataclass = make_dataclass(f"{class_prefix}{original_dataclass.__name__}", new_fields) + return copied_dataclass + + +def remove_dataclass_attr_prefix(data_instance, prefix: str) -> Dict: + """Remove the prefix from the attribute names of a dataclass instance. + + Parameters + ---------- + data_instance : dataclass + prefix : str + The prefix to remove from the attribute names of the dataclass instance. 
+ + Returns + ------- + Dict + """ + new_attributes = {} + for field in fields(data_instance): + attr_name = field.name + attr_value = getattr(data_instance, attr_name) + new_attr_name = f"{attr_name[len(prefix):]}" + new_attributes[new_attr_name] = attr_value + + return new_attributes \ No newline at end of file diff --git a/src/lmflow/utils/constants.py b/src/lmflow/utils/constants.py index 8c5e34f94..5506eb55a 100644 --- a/src/lmflow/utils/constants.py +++ b/src/lmflow/utils/constants.py @@ -314,3 +314,6 @@ 'qwen2': ["q_proj", "v_proj"], 'internlm2': ["wqkv"], } + +# vllm inference +MEMORY_SAFE_VLLM_INFERENCE_FINISH_FLAG = "MEMORY_SAFE_VLLM_INFERENCE_DONE" \ No newline at end of file diff --git a/src/lmflow/utils/model.py b/src/lmflow/utils/model.py new file mode 100644 index 000000000..99f7cc37d --- /dev/null +++ b/src/lmflow/utils/model.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved. +import logging +from typing import Dict, Any, List, Tuple, Union + +from transformers import AutoTokenizer + +from lmflow.args import ModelArguments + + +logger = logging.getLogger(__name__) + + +def check_homogeneity(model_args_list: List[ModelArguments]) -> bool: + assert all(isinstance(model_args, ModelArguments) for model_args in model_args_list), \ + "model_args_list should be a list of ModelArguments objects." + assert len(model_args_list) > 1, "model_args_list should have at least two elements." + + tokenizer_names = [] + for model_args in model_args_list: + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=False) + tokenizer_names.append(tokenizer.__class__.__name__) + + return len(set(tokenizer_names)) == 1 \ No newline at end of file diff --git a/tests/pipeline/test_memory_safe_vllm_inferencer.py b/tests/pipeline/test_memory_safe_vllm_inferencer.py new file mode 100644 index 000000000..b2c2b2803 --- /dev/null +++ b/tests/pipeline/test_memory_safe_vllm_inferencer.py @@ -0,0 +1,85 @@ +# cannot use unittest, since memory safe vllm inference uses stdout, +# which has conflicts with unittest stdout. 
+import logging +import json + +from lmflow.args import DatasetArguments, ModelArguments, InferencerArguments +from lmflow.models.hf_decoder_model import HFDecoderModel +from lmflow.pipeline.vllm_inferencer import MemorySafeVLLMInferencer +from lmflow.datasets import Dataset + + +logger = logging.getLogger(__name__) + + +model_args = ModelArguments( + 'Qwen/Qwen2-0.5B', + torch_dtype='auto', +) +data_args = DatasetArguments( + './data/alpaca/test_conversation', + preprocessing_num_workers=4, +) +inferencer_args = InferencerArguments( + random_seed=42, + apply_chat_template=True, + num_output_sequences=2, + temperature=1.0, + max_new_tokens=1024, + save_results=True, + results_path='./data/mem_safe_vllm_res.json', + use_vllm=True, + memory_safe_vllm_inference_detokenize=False, + vllm_gpu_memory_utilization=0.95, + vllm_tensor_parallel_size=2, +) + + +class MemorySafeVLLMInferencerTest: + def test_init(self): + self.dataset = Dataset(data_args) + self.model = HFDecoderModel(model_args) + self.inferencer = MemorySafeVLLMInferencer( + model_args=model_args, + data_args=data_args, + inferencer_args=inferencer_args, + ) + self.status = [] + + def test_inference(self): + res = self.inferencer.inference() + test_res = all([ + isinstance(res, list), + isinstance(res[0], list), + isinstance(res[0][0], list), + isinstance(res[0][0][0], int), + ]) + self.status.append(test_res) + logger.warning(f"test_inference: {test_res}") + + def test_inference_detokenize(self): + inferencer_args.memory_safe_vllm_inference_detokenize = True + self.inferencer = MemorySafeVLLMInferencer( + model_args=model_args, + data_args=data_args, + inferencer_args=inferencer_args, + ) + res = self.inferencer.inference() + test_res = all([ + isinstance(res, list), + isinstance(res[0], list), + isinstance(res[0][0], str), + ]) + self.status.append(test_res) + logger.warning(f"test_inference_detokenize: {test_res}") + + def summary(self): + logger.warning(f"MemorySafeVLLMInferencerTest: {all(self.status)}") + + +if __name__ == "__main__": + test = MemorySafeVLLMInferencerTest() + test.test_init() + test.test_inference() + test.test_inference_detokenize() + test.summary() \ No newline at end of file
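
For reference, here is a minimal, standalone sketch of how MemorySafeVLLMInferencer.inference() composes its subprocess command via make_shell_args_from_dataclass (format="shell"). The DemoArgs dataclass and make_shell_args helper below are hypothetical simplifications written for illustration only; they are not part of this PR, which passes the real ModelArguments, DatasetArguments, and InferencerArguments instances and prepends "python <path to memory_safe_vllm_inference.py>".

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class DemoArgs:
    # Hypothetical stand-in for InferencerArguments; only a few fields for brevity.
    model_name_or_path: str = "facebook/opt-125m"
    use_vllm: bool = False
    results_path: Optional[str] = None
    num_output_sequences: int = 8


def make_shell_args(objs: List, skip_default: bool = True) -> str:
    # Mirrors the logic of make_shell_args_from_dataclass in src/lmflow/utils/common.py:
    # skip falsy values, skip values equal to the field default, join as "--key value".
    all_args = {}
    for obj in objs:
        for k, v in obj.__dict__.items():
            if not v:
                continue
            if skip_default and obj.__dataclass_fields__[k].default == v:
                continue
            all_args[k] = v
    return " ".join(f"--{k} {v}" for k, v in all_args.items())


args = DemoArgs(use_vllm=True, results_path="./res.json")
print("python memory_safe_vllm_inference.py " + make_shell_args([args]))
# Prints: python memory_safe_vllm_inference.py --use_vllm True --results_path ./res.json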