From c75a42793d8b2f5452fbd25ba7190ae5c7176a13 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:45:22 +0000 Subject: [PATCH 01/12] support vllm --- .../colossal_eval/models/__init__.py | 3 +- .../ColossalEval/colossal_eval/models/vllm.py | 536 ++++++++++++++++++ applications/ColossalEval/requirements.txt | 1 + 3 files changed, 539 insertions(+), 1 deletion(-) create mode 100644 applications/ColossalEval/colossal_eval/models/vllm.py diff --git a/applications/ColossalEval/colossal_eval/models/__init__.py b/applications/ColossalEval/colossal_eval/models/__init__.py index 8f6c9b414145..ec557571ca07 100644 --- a/applications/ColossalEval/colossal_eval/models/__init__.py +++ b/applications/ColossalEval/colossal_eval/models/__init__.py @@ -1,5 +1,6 @@ from .base import BaseModel from .chatglm import ChatGLM2Model, ChatGLMModel from .huggingface import HuggingFaceCausalLM, HuggingFaceModel +from .vllm import vLLMModel -__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model"] +__all__ = ["BaseModel", "HuggingFaceModel", "HuggingFaceCausalLM", "ChatGLMModel", "ChatGLM2Model", "vLLMModel"] diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py new file mode 100644 index 000000000000..cb2dc6861dee --- /dev/null +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -0,0 +1,536 @@ +import copy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 +from torch.utils.data import DataLoader +from tqdm import tqdm +from vllm import LLM, SamplingParams + +from colossalai.logging import DistributedLogger + +from .base import BaseModel + +IGNORE_INDEX = -100 + + +class vLLMModel(BaseModel): + """ + Model wrapper around vLLM models. + + Args: + path: The path to a vLLM model. + model_max_length: The maximum sequence length of the model. + tokenizer_path: The path to the tokenizer. + tokenizer_kwargs: Keyword arguments for the tokenizer. + model_kwargs: Keyword arguments for the model. + prompt_template: The model's prompt template. + batch_size: Batch size for inference. + logger: Logger for the model. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. + quantization: The method used to quantize the model weights + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + enforce_eager: Whether to enforce eager execution. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + path: str, + model_max_length: int = 2048, + tokenizer_path: Optional[str] = None, + tokenizer_kwargs: Dict = None, + model_kwargs: Dict = None, + prompt_template: Conversation = None, + batch_size: int = 1, + logger: DistributedLogger = None, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + quantization: Optional[str] = None, + gpu_memory_utilization: float = 0.5, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ): + super().__init__( + path=path, + model_max_length=model_max_length, + prompt_template=prompt_template, + batch_size=batch_size, + logger=logger, + ) + + self._load_model_and_tokenizer( + path=path, + model_kwargs=model_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + tokenizer_path=tokenizer_path if tokenizer_path else None, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + quantization=quantization, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + ) + + def _get_choices_indices(self, language: str): + """ + Get indices for each choice + + Some tokenizer will insert BOS if you don't specify add_special_tokens=False such as Llama-2. + The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360. + print(self.tokenizer("答案:A")) to see + print(self.tokenizer("Answer: A")) to see + + """ + + # A trick for get "all" tokens ids related to given choices. + self.indices_for_choices = [[] for _ in range(2)] + for choice in self.choices: + self.indices_for_choices[0].append( + self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] + ) + self.indices_for_choices[1].append( + self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1] + ) + + def _load_model_and_tokenizer( + self, + path: str, + model_kwargs: dict, + tokenizer_kwargs: dict, + tokenizer_path: Optional[str] = None, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + quantization: Optional[str] = None, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + ): + """ + Load model. + + Args: + path: The path to the model. + model_kwargs: Keyword arguments for the model. + tokenizer_kwargs: Keyword arguments for the tokenizer. + tokenizer_path: The path to the tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. + quantization: The method used to quantize the model weights + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + enforce_eager: Whether to enforce eager execution. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + disable_custom_all_reduce: See ParallelConfig + + """ + if "torch_dtype" in model_kwargs: + model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) + model_kwargs.pop("torch_dtype") + else: + model_kwargs.setdefault("dtype", torch.float16) + + if "trust_remote_code" in model_kwargs: + trust_remote_code = model_kwargs["trust_remote_code"] + model_kwargs.pop("trust_remote_code") + + if "trust_remote_code" in tokenizer_kwargs: + trust_remote_code = tokenizer_kwargs["trust_remote_code"] + tokenizer_kwargs.pop("trust_remote_code") + + self.model = LLM( + model=path, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + quantization=quantization, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + **model_kwargs, + **tokenizer_kwargs + ) + + self.tokenizer = self.model.get_tokenizer() + + if self.batch_size > 1: + self.tokenizer.padding_side = "left" + self.tokenizer.truncation_side = "left" + + if self.tokenizer.pad_token_id is None: + self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.") + if self.tokenizer.eos_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + elif hasattr(self.tokenizer, "eod_id"): + # Qwen has an eod token "<|endoftext|>". + self.tokenizer.pad_token_id = self.tokenizer.eod_id + + def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: + """ + Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 + + Args: + input_ids_list: A batch of input string. + labels: A batch of labels. + + Returns: + A list of loss and a list of label length. + + """ + batch_size = len(inputs) + sampling_kwargs = SamplingParams(logprobs=1) + outputs = self.model.generate(inputs, sampling_kwargs) + ce_loss = [] + + if labels is not None: + lens = [ + len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels + ] + else: + lens = [1] * batch_size + + for i in range(batch_size): + logprobs = outputs[i].outputs[0].logprobs + token_ids = outputs[i].outputs[0].token_ids + + logprobs_list = [ + logprobs[i][token_ids[i]] + for i in range(len(logprobs)) + ] + logprobs_list = [i.logprob for i in logprobs_list] + logprobs_list = np.array(logprobs_list) + + if lens is not None: + logprobs_list = logprobs_list[:lens[i]] + + loss = -logprobs_list.sum(axis=-1) / lens[i] + ce_loss.append(loss) + + batch_loss = np.array(ce_loss) + + return batch_loss, lens + + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: + """ + Truncate the input sequence to fit model_max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) + https://github.com/THUDM/LongBench/blob/main/pred.py#L16 + + Args: + inputs: A batch of input prompts. + max_new_tokens: Max new tokens for model to generate. + + Returns: + Truncated prompts. + + """ + + truncated_inputs = copy.deepcopy(inputs) + for i, input in enumerate(inputs): + tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0] + if len(tokenized_prompt) > self.model_max_length - max_new_tokens: + half = (self.model_max_length - max_new_tokens) // 2 + prompt = self.tokenizer.decode( + tokenized_prompt[:half], skip_special_tokens=True + ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) + truncated_inputs[i] = prompt + + return truncated_inputs + + def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + """ + Infer the given data. + This function will call self.generate() to get model outputs and use LogitsProcessor param to get specific logits. + + Args: + data: The data for inference. + inference_kwargs: Arguments for inference. + debug: Whether to display generated prompt for debugging. + + Returns: + Inference results. + + """ + calculate_loss = inference_kwargs["calculate_loss"] + classes = inference_kwargs["all_classes"] + language = inference_kwargs["language"] + pretrain = inference_kwargs["pretrain"] + max_new_tokens = inference_kwargs["max_new_tokens"] + few_shot_data = inference_kwargs.get("few_shot_data", None) + + # Some classification questions' options are texts not a single letter such as A, B, C and D. + # If the text length is greater than 1, we won't calculate loss over choices. + if classes is not None and any(len(c) > 1 for c in classes): + classes = None + + self.choices = classes + self.indices_for_choices = None + if self.choices: + # Get indices for each choice + self._get_choices_indices(language) + + self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} + + bar = tqdm( + range(len(data_loader)), + desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps", + disable=not is_rank_0(), + ) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + answers = [] + + for i, batch in enumerate(data_loader): + batch_prompt, batch_target = get_batch_prompt( + self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length + ) + + if is_rank_0() and debug and i == 0: + self.logger.info( + f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}" + ) + self.logger.info("-" * 120) + self.logger.info("An example prompt and prompt with target is:") + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0]) + self.logger.info("-" * 120) + self.logger.info(batch_prompt[0] + batch_target[0][0]) + + if not pretrain: + batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) + + if calculate_loss: + batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( + batch_prompt, batch_target, pretrain + ) + + probs = [] + if self.indices_for_choices: + scores = scores.to(torch.float32) + # If we have indices_for_choices(must be single-choice question), there will be only one target answer for one data sample. + # Otherwise this will violate the single-choice setting. + + if calculate_loss: + labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))] + + loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() + + probs = scores.numpy().tolist() + probs = [ + {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) + ] + + for j in range(len(batch)): + if not pretrain: + if isinstance(batch[j]["output"], list): + batch[j]["output"].append(batch_decodes[j].strip()) + else: + batch[j]["output"] = batch_decodes[j].strip() + + if isinstance(scores, torch.Tensor): + batch[j]["logits_over_choices"] = probs[j] + + if calculate_loss: + batch[j]["loss_over_choices"] = loss_over_choices[j] + + if calculate_loss: + batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + + # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. + # However, loss (which is per sample loss) suffices for most cases. + batch[j]["loss_sum"] = batch_losses[j] + batch[j]["token_num"] = batch_target_token_nums[j] + + if batch_bytes_nums: + batch[j]["byte_num"] = batch_bytes_nums[j] + answers.extend(batch) + + bar.update() + + return answers + + @torch.no_grad() + def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]: + """Generate results given a list of inputs and get logits of the first new token over choices. + + Args: + inputs: A list of strings. + max_new_tokens: Max new tokens for generation. + kwargs: Key arguments for generation + + Returns: + A list of generated strings and logits over choices. + + Note: + Currently the function only returns the logits of the first new token. + It is used for single choice question. + For multiple choices question, please avoid using the loss over choices. + You should set argument choices as None in self.inference(). + + """ + truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) + + generation_kwargs = kwargs.copy() + generation_kwargs.update({'max_tokens': max_new_tokens}) + logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) + + sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) + + outputs = self.model.generate(truncated_inputs, sampling_kwargs) + output_strs = [] + for output in outputs: + generated_text = output.outputs[0].text + output_strs.append(generated_text) + scores = logits_processor.get_target_logits() + + return output_strs, scores + + + @torch.no_grad() + def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: + """ + Calculate loss only on target tokens. + + Args: + batch: A batch of prompt without target answer. + batch_target: A batch of target answer. Sometimes one question can have multiple target answers. + + Returns: + Loss. + + """ + + # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss. + # We don't need to generate new tokens. + # Target answer's length is usually << model_max_length, but we still call it in case. + # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. + if not pretrain: + batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] + + # Get the number of target answers for different questions + batch_target_nums = [len(prompt_target) for prompt_target in batch_target] + + if pretrain: + batch = [] + bytes_list = [] + batch_prompt_pretrain = [] + for p, b in zip(batch_prompt, batch_target): + batch.append(p + b[0]) + + for input in batch: + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. + ratio = [16, 8, 4, 2, 1] + tokenized = None + for r in ratio: + tokenized = self.tokenizer( + [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + if tokenized.input_ids.size(1) >= self.model_max_length: + break + + string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) + batch_prompt_pretrain.append(string) + bytes_list.append(len(string.encode("utf-8"))) + + batch_prompt = copy.deepcopy(batch_prompt_pretrain) + batch_target = None + else: + batch_prompt_processed = [] + batch_target_processed = [] + for prompt, targets in zip(batch_prompt, batch_target): + for target in targets: + target_tokenized = self.tokenizer( + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) + max_new_tokens = target_tokenized["input_ids"][0].size(0) + prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] + batch_prompt_processed.append(prompt_with_correct_length) + batch_target_processed.append(target) + + batch_prompt = copy.deepcopy(batch_prompt_processed) + batch_target = copy.deepcopy(batch_target_processed) + bytes_list = None + + # Because of multiple target answers, the final batch size may be greater than self.batch_size. + # We will generate new batches. + losses = [] + target_token_nums = [] + + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target) + losses.extend(losses_per_batch) + target_token_nums.extend(target_token_num_per_batch) + + start_indice = 0 + losses_per_sample = [] + + target_token_nums_per_sample = [] + bytes_nums_per_sample = [] + for length in batch_target_nums: + losses_per_sample.append(losses[start_indice : start_indice + length]) + target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length]) + + if bytes_list: + bytes_nums_per_sample.append(bytes_list[start_indice : start_indice + length]) + + start_indice += length + + if bytes_list: + return losses_per_sample, target_token_nums_per_sample, bytes_nums_per_sample + + return losses_per_sample, target_token_nums_per_sample, None + +class GetTokenLogitsProcessor: + """ + LogitsProcessor to get specific logits + + Args: + indices_for_choices: token indices of required tokens + target_logits: store all the target logits + """ + def __init__(self, + indices_for_choices: List[List[int]], + ): + self.indices_for_choices = indices_for_choices, + self.target_logits = [] + + def __call__(self, + input_ids: torch.Tensor, + logits: torch.Tensor) -> torch.Tensor: + choice_scores = [] + + if not input_ids: + for option_indices in self.indices_for_choices[0]: + choice_scores.append(logits[option_indices].detach().cpu()) + + choice_scores = torch.max(torch.stack(choice_scores), dim=0)[0] + self.target_logits.append(choice_scores) + + return logits + + def get_target_logits(self) -> torch.Tensor: + return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) \ No newline at end of file diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt index c5b9bad549e2..423978814bf5 100644 --- a/applications/ColossalEval/requirements.txt +++ b/applications/ColossalEval/requirements.txt @@ -10,3 +10,4 @@ matplotlib pandas seaborn scikit-learn +vllm==0.5.5 \ No newline at end of file From d621d3c04f47519138ecb70254189dc526ce5ced Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:49:38 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../ColossalEval/colossal_eval/models/vllm.py | 128 +++++++++--------- applications/ColossalEval/requirements.txt | 2 +- 2 files changed, 64 insertions(+), 66 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index cb2dc6861dee..712ae300941e 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -33,13 +33,13 @@ class vLLMModel(BaseModel): quantization: The method used to quantize the model weights gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. disable_custom_all_reduce: See ParallelConfig """ - + def __init__( self, path: str, @@ -69,12 +69,12 @@ def __init__( batch_size=batch_size, logger=logger, ) - + self._load_model_and_tokenizer( - path=path, + path=path, model_kwargs=model_kwargs, tokenizer_kwargs=tokenizer_kwargs, - tokenizer_path=tokenizer_path if tokenizer_path else None, + tokenizer_path=tokenizer_path if tokenizer_path else None, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, quantization=quantization, @@ -85,7 +85,7 @@ def __init__( max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - ) + ) def _get_choices_indices(self, language: str): """ @@ -107,13 +107,13 @@ def _get_choices_indices(self, language: str): self.indices_for_choices[1].append( self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1] ) - + def _load_model_and_tokenizer( - self, - path: str, + self, + path: str, model_kwargs: dict, tokenizer_kwargs: dict, - tokenizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, trust_remote_code: bool = False, tensor_parallel_size: int = 1, quantization: Optional[str] = None, @@ -138,7 +138,7 @@ def _load_model_and_tokenizer( quantization: The method used to quantize the model weights gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. @@ -147,20 +147,20 @@ def _load_model_and_tokenizer( """ if "torch_dtype" in model_kwargs: model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) - model_kwargs.pop("torch_dtype") + model_kwargs.pop("torch_dtype") else: model_kwargs.setdefault("dtype", torch.float16) - + if "trust_remote_code" in model_kwargs: trust_remote_code = model_kwargs["trust_remote_code"] - model_kwargs.pop("trust_remote_code") - + model_kwargs.pop("trust_remote_code") + if "trust_remote_code" in tokenizer_kwargs: trust_remote_code = tokenizer_kwargs["trust_remote_code"] tokenizer_kwargs.pop("trust_remote_code") self.model = LLM( - model=path, + model=path, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, quantization=quantization, @@ -172,11 +172,11 @@ def _load_model_and_tokenizer( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **model_kwargs, - **tokenizer_kwargs - ) - + **tokenizer_kwargs, + ) + self.tokenizer = self.model.get_tokenizer() - + if self.batch_size > 1: self.tokenizer.padding_side = "left" self.tokenizer.truncation_side = "left" @@ -188,11 +188,11 @@ def _load_model_and_tokenizer( elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id - + def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: """ Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 - + Args: input_ids_list: A batch of input string. labels: A batch of labels. @@ -205,35 +205,30 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: sampling_kwargs = SamplingParams(logprobs=1) outputs = self.model.generate(inputs, sampling_kwargs) ce_loss = [] - + if labels is not None: - lens = [ - len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels - ] + lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] else: lens = [1] * batch_size - + for i in range(batch_size): logprobs = outputs[i].outputs[0].logprobs token_ids = outputs[i].outputs[0].token_ids - - logprobs_list = [ - logprobs[i][token_ids[i]] - for i in range(len(logprobs)) - ] + + logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))] logprobs_list = [i.logprob for i in logprobs_list] logprobs_list = np.array(logprobs_list) if lens is not None: - logprobs_list = logprobs_list[:lens[i]] + logprobs_list = logprobs_list[: lens[i]] loss = -logprobs_list.sum(axis=-1) / lens[i] ce_loss.append(loss) - + batch_loss = np.array(ce_loss) - + return batch_loss, lens - + def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: """ Truncate the input sequence to fit model_max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) @@ -392,23 +387,22 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str """ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) - + generation_kwargs = kwargs.copy() - generation_kwargs.update({'max_tokens': max_new_tokens}) + generation_kwargs.update({"max_tokens": max_new_tokens}) logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) - + sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) - + outputs = self.model.generate(truncated_inputs, sampling_kwargs) output_strs = [] for output in outputs: generated_text = output.outputs[0].text output_strs.append(generated_text) scores = logits_processor.get_target_logits() - + return output_strs, scores - - + @torch.no_grad() def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: """ @@ -432,23 +426,26 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - + if pretrain: batch = [] bytes_list = [] batch_prompt_pretrain = [] for p, b in zip(batch_prompt, batch_target): batch.append(p + b[0]) - + for input in batch: - # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. - # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. - # After all, the rest of the original string doesn't need to be tokenized at the first place. + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. ratio = [16, 8, 4, 2, 1] tokenized = None for r in ratio: tokenized = self.tokenizer( - [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt" + [input[0 : len(input) // r]], + truncation=True, + max_length=self.model_max_length, + return_tensors="pt", ) if tokenized.input_ids.size(1) >= self.model_max_length: break @@ -456,7 +453,7 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) batch_prompt_pretrain.append(string) bytes_list.append(len(string.encode("utf-8"))) - + batch_prompt = copy.deepcopy(batch_prompt_pretrain) batch_target = None else: @@ -465,13 +462,13 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr for prompt, targets in zip(batch_prompt, batch_target): for target in targets: target_tokenized = self.tokenizer( - [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" - ) + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) max_new_tokens = target_tokenized["input_ids"][0].size(0) prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] batch_prompt_processed.append(prompt_with_correct_length) batch_target_processed.append(target) - + batch_prompt = copy.deepcopy(batch_prompt_processed) batch_target = copy.deepcopy(batch_target_processed) bytes_list = None @@ -480,7 +477,7 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # We will generate new batches. losses = [] target_token_nums = [] - + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target) losses.extend(losses_per_batch) target_token_nums.extend(target_token_num_per_batch) @@ -504,25 +501,26 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr return losses_per_sample, target_token_nums_per_sample, None + class GetTokenLogitsProcessor: """ - LogitsProcessor to get specific logits + LogitsProcessor to get specific logits Args: indices_for_choices: token indices of required tokens target_logits: store all the target logits """ - def __init__(self, - indices_for_choices: List[List[int]], - ): - self.indices_for_choices = indices_for_choices, + + def __init__( + self, + indices_for_choices: List[List[int]], + ): + self.indices_for_choices = (indices_for_choices,) self.target_logits = [] - - def __call__(self, - input_ids: torch.Tensor, - logits: torch.Tensor) -> torch.Tensor: + + def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: choice_scores = [] - + if not input_ids: for option_indices in self.indices_for_choices[0]: choice_scores.append(logits[option_indices].detach().cpu()) @@ -533,4 +531,4 @@ def __call__(self, return logits def get_target_logits(self) -> torch.Tensor: - return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) \ No newline at end of file + return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) diff --git a/applications/ColossalEval/requirements.txt b/applications/ColossalEval/requirements.txt index 423978814bf5..f9985b49f9ed 100644 --- a/applications/ColossalEval/requirements.txt +++ b/applications/ColossalEval/requirements.txt @@ -10,4 +10,4 @@ matplotlib pandas seaborn scikit-learn -vllm==0.5.5 \ No newline at end of file +vllm==0.5.5 From f1ab33d0560519d47d4f166e4fa21a2a46d0ff80 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Wed, 11 Sep 2024 09:05:25 +0000 Subject: [PATCH 03/12] modify vllm and update readme --- applications/ColossalEval/README.md | 41 +++- .../ColossalEval/colossal_eval/models/vllm.py | 175 +++++++----------- 2 files changed, 101 insertions(+), 115 deletions(-) diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index 890b1fed3912..fee572882b8c 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -230,7 +230,7 @@ Example: In this step, you will configure your tokenizer and model arguments to infer on the given datasets. A config file consists of two parts. -1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel` and `ChatGLMModel2`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields. +1. Model config. In model config, you need to specify model name, model path, model class, tokenizer arguments and model arguments. For model class, currently we support `HuggingFaceModel`, `HuggingFaceCausalLM`, `ChatGLMModel`, `ChatGLMModel2` and `vLLMModel`. `HuggingFaceModel` is for models that can be loaded with `AutoModel` and `HuggingFaceCausalLM` is for models that can be loaded with `AutoModelForCausalLM`. `ChatGLMModel` and `ChatGLMModel2` are for ChatGLM and ChatGLM2 models respectively. `vLLMModel` is for models that can be loaded with vllm `LLM`. You can check all model classes in `colossal_eval/models/__init__.py`. If your model should set `trust_remote_code` as true, specify it in the `tokenizer_kwargs` and `model_kwargs` fields. 2. Dataset config. In dataset config, you need to specify dataset name, path and dataset class. Currently, we support zero-shot on dataset MMLU, CMMLU, AGIEval, GAOKAO-Bench, GSM8K and LongBench and few-shot on dataset MMLU, CMMLU AGIEval and GSM8K. If you want to enable few shot, set `few_shot` as true. You can check all model classes in `colossal_eval/dataset/__init__.py`. Once you have all config ready, the program will run inference on all the given datasets on all the given models. @@ -272,7 +272,42 @@ An example config using model class `HuggingFaceCausalLM` and dataset class `CMM } ``` -Currently, we support Hugging Face models. The `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong. +An example config using model class `HuggingFaceCausalLM` and dataset class `CMMLUDataset` can be: +```json +{ + "model": [ + { + "name": "model name", + "model_class": "vLLMModel", + "parameters": { + "path": "path to model", + "model_max_length": 2048, + "tokenizer_path": "", + "tokenizer_kwargs": { + "trust_remote_code": true + }, + "model_kwargs": { + "trust_remote_code": true + }, + "prompt_template": "plain", + "batch_size": 4 + } + } + ], + "dataset": [ + { + "name": "dataset name", + "dataset_class": "CMMLUDataset", + "debug": false, + "few_shot": true, + "path": "path to original dataset", + "save_path": "path to save converted dataset" + } + ] +} +``` + +Currently, we support Hugging Face models as well as vLLM models. For Hugging Face models, the `tokenizer_kwargs` is the arguments used in `AutoTokenizer.from_pretrained()`. The `model_kwargs` is the arguments used in `AutoModel.from_pretrained` or `AutoModelForCausalLM.from_pretrained()`. For vLLM models, `tokenizer_kwargs` and `model_kwargs` are loaded together using offline inference `LLM` class. `few_shot` will be set true if you want to enable few-shot prompting for the dataset. `debug` will be set true if you want to verify whether your prompt is right or wrong. > For GSM8K dataset, you can set additional flags `load_train` or `load_reference` for dataset configuration as true and during the inference process, the program will calculate loss summation over all tokens for each data sample. During the evaluation process, you can use metric `loss_over_all_tokens` to calculate the overall loss and use it for data leakage evaluation. @@ -287,7 +322,7 @@ torchrun --nproc_per_node=4 inference.py \ --inference_save_path "path to save inference results" ``` -You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size. +You should specify the path to config file in `config`. You can run the script without specifying `load_dataset` if you already save the converted dataset or otherwise set it to first load the original dataset and save the converted dataset. You should specify the path to save inference results in `inference_save_path`. If you want to use tensor parallel inference, specify the tensor parallel size in `--tp_size` and the process will automatically calculate data parallel size (currently not support for `vLLMModel`). ### Evaluation diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index cb2dc6861dee..7d4e3d66e7df 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -10,12 +10,12 @@ from colossalai.logging import DistributedLogger -from .base import BaseModel +from .huggingface import HuggingFaceModel IGNORE_INDEX = -100 -class vLLMModel(BaseModel): +class vLLMModel(HuggingFaceModel): """ Model wrapper around vLLM models. @@ -33,13 +33,13 @@ class vLLMModel(BaseModel): quantization: The method used to quantize the model weights gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. disable_custom_all_reduce: See ParallelConfig """ - + def __init__( self, path: str, @@ -69,12 +69,12 @@ def __init__( batch_size=batch_size, logger=logger, ) - + self._load_model_and_tokenizer( - path=path, + path=path, model_kwargs=model_kwargs, tokenizer_kwargs=tokenizer_kwargs, - tokenizer_path=tokenizer_path if tokenizer_path else None, + tokenizer_path=tokenizer_path if tokenizer_path else None, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, quantization=quantization, @@ -85,35 +85,14 @@ def __init__( max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, - ) - - def _get_choices_indices(self, language: str): - """ - Get indices for each choice - - Some tokenizer will insert BOS if you don't specify add_special_tokens=False such as Llama-2. - The indices for choices may be different given the context. For example, for Llama-2 tokenizer, for Chinese context like "答案:{choice}", indices for choices A, B, C and D are 29909, 29933, 29907 and 29928, for English context like "Answer: {choice}", indices for choices A, B, C and D are 319, 350, 315 and 360. - print(self.tokenizer("答案:A")) to see - print(self.tokenizer("Answer: A")) to see - - """ + ) - # A trick for get "all" tokens ids related to given choices. - self.indices_for_choices = [[] for _ in range(2)] - for choice in self.choices: - self.indices_for_choices[0].append( - self.tokenizer(f"Answer: {choice}", add_special_tokens=False).input_ids[-1] - ) - self.indices_for_choices[1].append( - self.tokenizer(f"答案:{choice}", add_special_tokens=False).input_ids[-1] - ) - def _load_model_and_tokenizer( - self, - path: str, + self, + path: str, model_kwargs: dict, tokenizer_kwargs: dict, - tokenizer_path: Optional[str] = None, + tokenizer_path: Optional[str] = None, trust_remote_code: bool = False, tensor_parallel_size: int = 1, quantization: Optional[str] = None, @@ -138,7 +117,7 @@ def _load_model_and_tokenizer( quantization: The method used to quantize the model weights gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. @@ -147,20 +126,20 @@ def _load_model_and_tokenizer( """ if "torch_dtype" in model_kwargs: model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) - model_kwargs.pop("torch_dtype") + model_kwargs.pop("torch_dtype") else: model_kwargs.setdefault("dtype", torch.float16) - + if "trust_remote_code" in model_kwargs: trust_remote_code = model_kwargs["trust_remote_code"] - model_kwargs.pop("trust_remote_code") - + model_kwargs.pop("trust_remote_code") + if "trust_remote_code" in tokenizer_kwargs: trust_remote_code = tokenizer_kwargs["trust_remote_code"] tokenizer_kwargs.pop("trust_remote_code") self.model = LLM( - model=path, + model=path, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, quantization=quantization, @@ -172,11 +151,11 @@ def _load_model_and_tokenizer( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **model_kwargs, - **tokenizer_kwargs - ) - + **tokenizer_kwargs, + ) + self.tokenizer = self.model.get_tokenizer() - + if self.batch_size > 1: self.tokenizer.padding_side = "left" self.tokenizer.truncation_side = "left" @@ -188,11 +167,11 @@ def _load_model_and_tokenizer( elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id - + def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: """ Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 - + Args: input_ids_list: A batch of input string. labels: A batch of labels. @@ -205,60 +184,29 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: sampling_kwargs = SamplingParams(logprobs=1) outputs = self.model.generate(inputs, sampling_kwargs) ce_loss = [] - + if labels is not None: - lens = [ - len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels - ] + lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] else: lens = [1] * batch_size - + for i in range(batch_size): logprobs = outputs[i].outputs[0].logprobs token_ids = outputs[i].outputs[0].token_ids - - logprobs_list = [ - logprobs[i][token_ids[i]] - for i in range(len(logprobs)) - ] + + logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))] logprobs_list = [i.logprob for i in logprobs_list] logprobs_list = np.array(logprobs_list) if lens is not None: - logprobs_list = logprobs_list[:lens[i]] + logprobs_list = logprobs_list[: lens[i]] loss = -logprobs_list.sum(axis=-1) / lens[i] ce_loss.append(loss) - - batch_loss = np.array(ce_loss) - - return batch_loss, lens - - def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]: - """ - Truncate the input sequence to fit model_max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) - https://github.com/THUDM/LongBench/blob/main/pred.py#L16 - - Args: - inputs: A batch of input prompts. - max_new_tokens: Max new tokens for model to generate. - - Returns: - Truncated prompts. - - """ - truncated_inputs = copy.deepcopy(inputs) - for i, input in enumerate(inputs): - tokenized_prompt = self.tokenizer(input, truncation=False, return_tensors="pt").input_ids[0] - if len(tokenized_prompt) > self.model_max_length - max_new_tokens: - half = (self.model_max_length - max_new_tokens) // 2 - prompt = self.tokenizer.decode( - tokenized_prompt[:half], skip_special_tokens=True - ) + self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) - truncated_inputs[i] = prompt + batch_loss = np.array(ce_loss) - return truncated_inputs + return batch_loss, lens def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: """ @@ -392,23 +340,22 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str """ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) - + generation_kwargs = kwargs.copy() - generation_kwargs.update({'max_tokens': max_new_tokens}) + generation_kwargs.update({"max_tokens": max_new_tokens}) logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) - + sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) - + outputs = self.model.generate(truncated_inputs, sampling_kwargs) output_strs = [] for output in outputs: generated_text = output.outputs[0].text output_strs.append(generated_text) scores = logits_processor.get_target_logits() - + return output_strs, scores - - + @torch.no_grad() def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: """ @@ -432,23 +379,26 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - + if pretrain: batch = [] bytes_list = [] batch_prompt_pretrain = [] for p, b in zip(batch_prompt, batch_target): batch.append(p + b[0]) - + for input in batch: - # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. - # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. - # After all, the rest of the original string doesn't need to be tokenized at the first place. + # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. + # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. + # After all, the rest of the original string doesn't need to be tokenized at the first place. ratio = [16, 8, 4, 2, 1] tokenized = None for r in ratio: tokenized = self.tokenizer( - [input[0 : len(input) // r]], truncation=True, max_length=self.model_max_length, return_tensors="pt" + [input[0 : len(input) // r]], + truncation=True, + max_length=self.model_max_length, + return_tensors="pt", ) if tokenized.input_ids.size(1) >= self.model_max_length: break @@ -456,7 +406,7 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr string = self.tokenizer.decode(tokenized.input_ids[0], skip_special_tokens=True) batch_prompt_pretrain.append(string) bytes_list.append(len(string.encode("utf-8"))) - + batch_prompt = copy.deepcopy(batch_prompt_pretrain) batch_target = None else: @@ -465,13 +415,13 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr for prompt, targets in zip(batch_prompt, batch_target): for target in targets: target_tokenized = self.tokenizer( - [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" - ) + [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" + ) max_new_tokens = target_tokenized["input_ids"][0].size(0) prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] batch_prompt_processed.append(prompt_with_correct_length) batch_target_processed.append(target) - + batch_prompt = copy.deepcopy(batch_prompt_processed) batch_target = copy.deepcopy(batch_target_processed) bytes_list = None @@ -480,7 +430,7 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # We will generate new batches. losses = [] target_token_nums = [] - + losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target) losses.extend(losses_per_batch) target_token_nums.extend(target_token_num_per_batch) @@ -504,25 +454,26 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr return losses_per_sample, target_token_nums_per_sample, None + class GetTokenLogitsProcessor: """ - LogitsProcessor to get specific logits + LogitsProcessor to get specific logits Args: indices_for_choices: token indices of required tokens target_logits: store all the target logits """ - def __init__(self, - indices_for_choices: List[List[int]], - ): - self.indices_for_choices = indices_for_choices, + + def __init__( + self, + indices_for_choices: List[List[int]], + ): + self.indices_for_choices = (indices_for_choices,) self.target_logits = [] - - def __call__(self, - input_ids: torch.Tensor, - logits: torch.Tensor) -> torch.Tensor: + + def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: choice_scores = [] - + if not input_ids: for option_indices in self.indices_for_choices[0]: choice_scores.append(logits[option_indices].detach().cpu()) @@ -533,4 +484,4 @@ def __call__(self, return logits def get_target_logits(self) -> torch.Tensor: - return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) \ No newline at end of file + return torch.stack(self.target_logits) if self.target_logits else torch.tensor([]) From 85fdd245e339e338162a0dec05a07f522bcf214d Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Thu, 12 Sep 2024 01:45:34 +0000 Subject: [PATCH 04/12] run pre-commit --- .../ColossalEval/colossal_eval/models/vllm.py | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index 090c6b81dd52..67a32b77b6fb 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -41,7 +41,6 @@ class vLLMModel(HuggingFaceModel): disable_custom_all_reduce: See ParallelConfig """ - def __init__( self, path: str, @@ -72,7 +71,6 @@ def __init__( logger=logger, ) - self._load_model_and_tokenizer( path=path, path=path, @@ -140,7 +138,6 @@ def _load_model_and_tokenizer( else: model_kwargs.setdefault("dtype", torch.float16) - if "trust_remote_code" in model_kwargs: trust_remote_code = model_kwargs["trust_remote_code"] model_kwargs.pop("trust_remote_code") @@ -168,12 +165,8 @@ def _load_model_and_tokenizer( **tokenizer_kwargs, ) - **tokenizer_kwargs, - ) - self.tokenizer = self.model.get_tokenizer() - if self.batch_size > 1: self.tokenizer.padding_side = "left" self.tokenizer.truncation_side = "left" @@ -186,7 +179,6 @@ def _load_model_and_tokenizer( # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id - def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: """ Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 @@ -205,14 +197,12 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: outputs = self.model.generate(inputs, sampling_kwargs) ce_loss = [] - if labels is not None: lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] else: lens = [1] * batch_size - for i in range(batch_size): logprobs = outputs[i].outputs[0].logprobs token_ids = outputs[i].outputs[0].token_ids @@ -230,10 +220,8 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: loss = -logprobs_list.sum(axis=-1) / lens[i] ce_loss.append(loss) - batch_loss = np.array(ce_loss) - return batch_loss, lens def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: @@ -369,16 +357,13 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str """ truncated_inputs = self._get_truncated_prompts(inputs, max_new_tokens) - generation_kwargs = kwargs.copy() generation_kwargs.update({"max_tokens": max_new_tokens}) generation_kwargs.update({"max_tokens": max_new_tokens}) logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) - sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) - outputs = self.model.generate(truncated_inputs, sampling_kwargs) output_strs = [] for output in outputs: @@ -386,10 +371,8 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str output_strs.append(generated_text) scores = logits_processor.get_target_logits() - return output_strs, scores - @torch.no_grad() def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: """ @@ -414,7 +397,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - if pretrain: batch = [] bytes_list = [] @@ -422,7 +404,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr for p, b in zip(batch_prompt, batch_target): batch.append(p + b[0]) - for input in batch: # Pretrain data tends to be very long, sometimes much larger than the model_max_length, we only tokenize 1/ratio of the data first to accelerate the tokenization process. # Once the length of the result is greater or equal to model_max_length, we stop iterating on ratios and use the result as input_ids and labels. @@ -438,10 +419,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr truncation=True, max_length=self.model_max_length, return_tensors="pt", - [input[0 : len(input) // r]], - truncation=True, - max_length=self.model_max_length, - return_tensors="pt", ) if tokenized.input_ids.size(1) >= self.model_max_length: break @@ -450,7 +427,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr batch_prompt_pretrain.append(string) bytes_list.append(len(string.encode("utf-8"))) - batch_prompt = copy.deepcopy(batch_prompt_pretrain) batch_target = None else: @@ -460,15 +436,12 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr for target in targets: target_tokenized = self.tokenizer( [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" - ) - [target], truncation=True, max_length=self.model_max_length, return_tensors="pt" ) max_new_tokens = target_tokenized["input_ids"][0].size(0) prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0] batch_prompt_processed.append(prompt_with_correct_length) batch_target_processed.append(target) - batch_prompt = copy.deepcopy(batch_prompt_processed) batch_target = copy.deepcopy(batch_target_processed) bytes_list = None @@ -478,7 +451,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr losses = [] target_token_nums = [] - losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_prompt, batch_target) losses.extend(losses_per_batch) target_token_nums.extend(target_token_num_per_batch) @@ -503,7 +475,6 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr return losses_per_sample, target_token_nums_per_sample, None - class GetTokenLogitsProcessor: """ LogitsProcessor to get specific logits @@ -527,12 +498,9 @@ def __init__( self.indices_for_choices = (indices_for_choices,) self.target_logits = [] - def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: - def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor: choice_scores = [] - if not input_ids: for option_indices in self.indices_for_choices[0]: choice_scores.append(logits[option_indices].detach().cpu()) From 3b3163633df3e80781f74990243e3fd49ae9f8ef Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Thu, 12 Sep 2024 06:39:19 +0000 Subject: [PATCH 05/12] remove dupilicated lines and refine code --- .../ColossalEval/colossal_eval/models/vllm.py | 33 ++++--------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index 67a32b77b6fb..d0f8d0fecd4d 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -34,7 +34,6 @@ class vLLMModel(HuggingFaceModel): gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. @@ -71,13 +70,11 @@ def __init__( logger=logger, ) - self._load_model_and_tokenizer( - path=path, + self._load_model( path=path, model_kwargs=model_kwargs, tokenizer_kwargs=tokenizer_kwargs, tokenizer_path=tokenizer_path if tokenizer_path else None, - tokenizer_path=tokenizer_path if tokenizer_path else None, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, quantization=quantization, @@ -90,15 +87,12 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, ) - def _load_model_and_tokenizer( - self, - path: str, + def _load_model( self, path: str, model_kwargs: dict, tokenizer_kwargs: dict, tokenizer_path: Optional[str] = None, - tokenizer_path: Optional[str] = None, trust_remote_code: bool = False, tensor_parallel_size: int = 1, quantization: Optional[str] = None, @@ -124,7 +118,6 @@ def _load_model_and_tokenizer( gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. swap_space: The size (GiB) of CPU memory per GPU to use as swap space. cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. - cpu_offload_gb: The size (GiB) of CPU memory to use for offloading the model weights. enforce_eager: Whether to enforce eager execution. max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. @@ -134,7 +127,6 @@ def _load_model_and_tokenizer( if "torch_dtype" in model_kwargs: model_kwargs["dtype"] = eval(model_kwargs["torch_dtype"]) model_kwargs.pop("torch_dtype") - model_kwargs.pop("torch_dtype") else: model_kwargs.setdefault("dtype", torch.float16) @@ -142,14 +134,11 @@ def _load_model_and_tokenizer( trust_remote_code = model_kwargs["trust_remote_code"] model_kwargs.pop("trust_remote_code") - model_kwargs.pop("trust_remote_code") - if "trust_remote_code" in tokenizer_kwargs: trust_remote_code = tokenizer_kwargs["trust_remote_code"] tokenizer_kwargs.pop("trust_remote_code") self.model = LLM( - model=path, model=path, trust_remote_code=trust_remote_code, tensor_parallel_size=tensor_parallel_size, @@ -178,12 +167,15 @@ def _load_model_and_tokenizer( elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id + else: + self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.") + raise ValueError("The tokenizer does not have a pad_token_id, eos_token, or eod_id. " + "Please set pad_token_id manually.") def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: """ Calculate loss on target tokens. Adapted from https://github.com/open-compass/opencompass/blob/c2bcd8725e615ec455bf5b7301f8d09962cd64e3/opencompass/models/vllm.py#L110 - Args: input_ids_list: A batch of input string. labels: A batch of labels. @@ -199,7 +191,6 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: if labels is not None: lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] - lens = [len(self.tokenizer.encode(label, add_special_tokens=False)) for label in labels] else: lens = [1] * batch_size @@ -207,15 +198,12 @@ def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: logprobs = outputs[i].outputs[0].logprobs token_ids = outputs[i].outputs[0].token_ids - logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))] - logprobs_list = [logprobs[i][token_ids[i]] for i in range(len(logprobs))] logprobs_list = [i.logprob for i in logprobs_list] logprobs_list = np.array(logprobs_list) if lens is not None: logprobs_list = logprobs_list[: lens[i]] - logprobs_list = logprobs_list[: lens[i]] loss = -logprobs_list.sum(axis=-1) / lens[i] ce_loss.append(loss) @@ -359,7 +347,6 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str generation_kwargs = kwargs.copy() generation_kwargs.update({"max_tokens": max_new_tokens}) - generation_kwargs.update({"max_tokens": max_new_tokens}) logits_processor = GetTokenLogitsProcessor(self.indices_for_choices) sampling_kwargs = SamplingParams(logits_processors=[logits_processor], **generation_kwargs) @@ -478,19 +465,11 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr class GetTokenLogitsProcessor: """ LogitsProcessor to get specific logits - LogitsProcessor to get specific logits Args: indices_for_choices: token indices of required tokens target_logits: store all the target logits """ - - def __init__( - self, - indices_for_choices: List[List[int]], - ): - self.indices_for_choices = (indices_for_choices,) - def __init__( self, indices_for_choices: List[List[int]], From e90443d5f62e9b1056dc49b78c45de95638be4cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 06:41:43 +0000 Subject: [PATCH 06/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- applications/ColossalEval/colossal_eval/models/vllm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index d0f8d0fecd4d..4aa377ec91aa 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -169,8 +169,10 @@ def _load_model( self.tokenizer.pad_token_id = self.tokenizer.eod_id else: self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.") - raise ValueError("The tokenizer does not have a pad_token_id, eos_token, or eod_id. " - "Please set pad_token_id manually.") + raise ValueError( + "The tokenizer does not have a pad_token_id, eos_token, or eod_id. " + "Please set pad_token_id manually." + ) def _calculate_loss(self, inputs: List[str], labels: List[str]) -> Tuple[List]: """ @@ -470,6 +472,7 @@ class GetTokenLogitsProcessor: indices_for_choices: token indices of required tokens target_logits: store all the target logits """ + def __init__( self, indices_for_choices: List[List[int]], From 740617f3aea5dd9311baf27d105a5d2a01fc9fb8 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:36:38 +0000 Subject: [PATCH 07/12] update param name --- .../colossal_eval/models/chatglm.py | 4 ++-- .../colossal_eval/models/huggingface.py | 18 +++++++++--------- .../ColossalEval/colossal_eval/models/vllm.py | 14 +++++++------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/chatglm.py b/applications/ColossalEval/colossal_eval/models/chatglm.py index 9c70c0d2a1ad..4a48f4c0ed3e 100644 --- a/applications/ColossalEval/colossal_eval/models/chatglm.py +++ b/applications/ColossalEval/colossal_eval/models/chatglm.py @@ -28,7 +28,7 @@ def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List @torch.no_grad() def get_loss( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False ) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -225,7 +225,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str @torch.no_grad() def get_loss( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool = False ) -> List[List[float]]: """ Calculate loss only on target tokens. diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index e91743525f0e..af5f66098058 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -245,7 +245,7 @@ def _get_input_ids_and_labels_pretrain(self, batch_prompt: List[str]) -> Tuple[L return input_ids_list, labels_list, bytes_list def _get_input_ids_and_labels( - self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool ) -> Tuple[List[torch.LongTensor]]: """ Get input_ids and labels for the given data. @@ -258,7 +258,7 @@ def _get_input_ids_and_labels( Input_ids and labels for the given batch. """ - if pretrain: + if calculate_overall_loss: batch = [] # Concatenate prompt and target answers. # You should decide the concatenation character in the corresponding dataset script in dataset folder. For example, in line 119 dataset/gsm.py, the concatenation character is space. @@ -342,7 +342,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d calculate_loss = inference_kwargs["calculate_loss"] classes = inference_kwargs["all_classes"] language = inference_kwargs["language"] - pretrain = inference_kwargs["pretrain"] + calculate_overall_loss = inference_kwargs["pretrain"] max_new_tokens = inference_kwargs["max_new_tokens"] few_shot_data = inference_kwargs.get("few_shot_data", None) @@ -384,12 +384,12 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d self.logger.info("-" * 120) self.logger.info(batch_prompt[0] + batch_target[0][0]) - if not pretrain: + if not calculate_overall_loss: batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) if calculate_loss: batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( - batch_prompt, batch_target, pretrain + batch_prompt, batch_target, calculate_overall_loss ) probs = [] @@ -409,7 +409,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d ] for j in range(len(batch)): - if not pretrain: + if not calculate_overall_loss: if isinstance(batch[j]["output"], list): batch[j]["output"].append(batch_decodes[j].strip()) else: @@ -496,7 +496,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return decoded_sequences, scores @torch.no_grad() - def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: + def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -513,13 +513,13 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # We don't need to generate new tokens. # Target answer's length is usually << model_max_length, but we still call it in case. # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. - if not pretrain: + if not calculate_overall_loss: batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, pretrain) + input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, calculate_overall_loss) # Because of multiple target answers, the final batch size may be greater than self.batch_size. # We will generate new batches. diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index d0f8d0fecd4d..c39bf80c1d01 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -229,7 +229,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d calculate_loss = inference_kwargs["calculate_loss"] classes = inference_kwargs["all_classes"] language = inference_kwargs["language"] - pretrain = inference_kwargs["pretrain"] + calculate_overall_loss = inference_kwargs["pretrain"] max_new_tokens = inference_kwargs["max_new_tokens"] few_shot_data = inference_kwargs.get("few_shot_data", None) @@ -271,12 +271,12 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d self.logger.info("-" * 120) self.logger.info(batch_prompt[0] + batch_target[0][0]) - if not pretrain: + if not calculate_overall_loss: batch_decodes, scores = self.generate(batch_prompt, max_new_tokens) if calculate_loss: batch_losses, batch_target_token_nums, batch_bytes_nums = self.get_loss( - batch_prompt, batch_target, pretrain + batch_prompt, batch_target, calculate_overall_loss ) probs = [] @@ -296,7 +296,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d ] for j in range(len(batch)): - if not pretrain: + if not calculate_overall_loss: if isinstance(batch[j]["output"], list): batch[j]["output"].append(batch_decodes[j].strip()) else: @@ -361,7 +361,7 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return output_strs, scores @torch.no_grad() - def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool) -> List[List[float]]: + def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -378,13 +378,13 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], pretr # We don't need to generate new tokens. # Target answer's length is usually << model_max_length, but we still call it in case. # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first to reserve some space for target answer's tokens. - if not pretrain: + if not calculate_overall_loss: batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target] # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - if pretrain: + if calculate_overall_loss: batch = [] bytes_list = [] batch_prompt_pretrain = [] From 24c4c03d0e890684334810a521f19691f660b5bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:41:07 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../ColossalEval/colossal_eval/models/huggingface.py | 8 ++++++-- applications/ColossalEval/colossal_eval/models/vllm.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index af5f66098058..68c65948c88b 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -496,7 +496,9 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return decoded_sequences, scores @torch.no_grad() - def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool) -> List[List[float]]: + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool + ) -> List[List[float]]: """ Calculate loss only on target tokens. @@ -519,7 +521,9 @@ def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], calcu # Get the number of target answers for different questions batch_target_nums = [len(prompt_target) for prompt_target in batch_target] - input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels(batch_prompt, batch_target, calculate_overall_loss) + input_ids_list, labels_list, bytes_list = self._get_input_ids_and_labels( + batch_prompt, batch_target, calculate_overall_loss + ) # Because of multiple target answers, the final batch size may be greater than self.batch_size. # We will generate new batches. diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index 002ea1331bb7..9ee02501790b 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -363,7 +363,9 @@ def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str return output_strs, scores @torch.no_grad() - def get_loss(self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool) -> List[List[float]]: + def get_loss( + self, batch_prompt: List[str], batch_target: List[List[str]], calculate_overall_loss: bool + ) -> List[List[float]]: """ Calculate loss only on target tokens. From 39c9c232451e556348770cbd42e90282b288be21 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Tue, 17 Sep 2024 06:56:43 +0000 Subject: [PATCH 09/12] refine code --- .../ColossalEval/colossal_eval/dataset/agieval.py | 2 +- applications/ColossalEval/colossal_eval/dataset/ceval.py | 2 +- applications/ColossalEval/colossal_eval/dataset/cmmlu.py | 2 +- .../ColossalEval/colossal_eval/dataset/colossalai.py | 2 +- .../ColossalEval/colossal_eval/dataset/cvalues.py | 2 +- .../ColossalEval/colossal_eval/dataset/gaokaobench.py | 2 +- applications/ColossalEval/colossal_eval/dataset/gsm.py | 4 ++-- .../ColossalEval/colossal_eval/dataset/longbench.py | 2 +- applications/ColossalEval/colossal_eval/dataset/mmlu.py | 2 +- .../ColossalEval/colossal_eval/dataset/mtbench.py | 2 +- .../ColossalEval/colossal_eval/dataset/safetybench_en.py | 2 +- .../ColossalEval/colossal_eval/dataset/safetybench_zh.py | 2 +- .../ColossalEval/colossal_eval/models/huggingface.py | 8 +++++++- applications/ColossalEval/colossal_eval/models/vllm.py | 2 +- 14 files changed, 21 insertions(+), 15 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index c1cfe37d7599..07597048d7f9 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -47,7 +47,7 @@ "calculate_loss": True, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 1023d1e23c1f..b15dd93afc87 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -70,7 +70,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 05752c2486fa..402a2d4c8eab 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -81,7 +81,7 @@ "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py index 0337454fa788..266eaef3f486 100644 --- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -12,7 +12,7 @@ "calculate_loss": False, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 256, } diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py index 4023a4c76322..f5b81f90ed3f 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py +++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py @@ -15,7 +15,7 @@ "calculate_loss": False, "all_classes": ["A", "B"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index 44ccea9cfa2c..533e9b4bfa52 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -36,7 +36,7 @@ "calculate_loss": True, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/gsm.py b/applications/ColossalEval/colossal_eval/dataset/gsm.py index 775c5843ff79..a639201053ef 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gsm.py +++ b/applications/ColossalEval/colossal_eval/dataset/gsm.py @@ -72,7 +72,7 @@ "calculate_loss": True, "all_classes": None, "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 256, } @@ -114,7 +114,7 @@ def load( dataset[split][subject]["inference_kwargs"] = copy.deepcopy(default_inference_kwargs) if forward_only: - dataset[split][subject]["inference_kwargs"]["pretrain"] = True + dataset[split][subject]["inference_kwargs"]["calculate_overall_loss"] = True if split == "test" and few_shot: dataset[split][subject]["inference_kwargs"]["few_shot_data"] = get_few_shot_data() diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py index eb61efaa0d7c..e663e5e108e6 100644 --- a/applications/ColossalEval/colossal_eval/dataset/longbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -60,7 +60,7 @@ "calculate_loss": True, "all_classes": None, "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index e9465c91b3ce..5e3ff6af6ef3 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -11,7 +11,7 @@ "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index ef474ec4ca23..abec8ebfb038 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -14,7 +14,7 @@ "calculate_loss": False, "all_classes": None, "language": "English", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 1024, "turns": 2, } diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py index 8056c3dfd8bf..494bb0993ccf 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py @@ -28,7 +28,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py index f5f17e64c991..8c41664c02c8 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py @@ -28,7 +28,7 @@ "calculate_loss": False, "all_classes": ["A", "B", "C", "D"], "language": LANGUAGE, - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32, } diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index af5f66098058..b3a0faf9d92f 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -105,6 +105,12 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw elif hasattr(self.tokenizer, "eod_id"): # Qwen has an eod token "<|endoftext|>". self.tokenizer.pad_token_id = self.tokenizer.eod_id + else: + self.logger.error("Neither eos_token nor eod_id is available for setting pad_token_id.") + raise ValueError( + "The tokenizer does not have a pad_token_id, eos_token, or eod_id. " + "Please set pad_token_id manually." + ) def _load_model( self, path: str, model_kwargs: dict, peft_path: Optional[str] = None, shard_config: ShardConfig = None @@ -342,7 +348,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d calculate_loss = inference_kwargs["calculate_loss"] classes = inference_kwargs["all_classes"] language = inference_kwargs["language"] - calculate_overall_loss = inference_kwargs["pretrain"] + calculate_overall_loss = inference_kwargs["calculate_overall_loss"] max_new_tokens = inference_kwargs["max_new_tokens"] few_shot_data = inference_kwargs.get("few_shot_data", None) diff --git a/applications/ColossalEval/colossal_eval/models/vllm.py b/applications/ColossalEval/colossal_eval/models/vllm.py index 002ea1331bb7..a798dbe6c45e 100644 --- a/applications/ColossalEval/colossal_eval/models/vllm.py +++ b/applications/ColossalEval/colossal_eval/models/vllm.py @@ -231,7 +231,7 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d calculate_loss = inference_kwargs["calculate_loss"] classes = inference_kwargs["all_classes"] language = inference_kwargs["language"] - calculate_overall_loss = inference_kwargs["pretrain"] + calculate_overall_loss = inference_kwargs["calculate_overall_loss"] max_new_tokens = inference_kwargs["max_new_tokens"] few_shot_data = inference_kwargs.get("few_shot_data", None) From 32fc3f71e2a7eaf06e725a416e7c9cca6daeea96 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Tue, 17 Sep 2024 08:03:32 +0000 Subject: [PATCH 10/12] update readme --- applications/ColossalEval/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index 0586f31dbf94..1429b263adfe 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -154,7 +154,7 @@ inference_kwargs = { "calculate_loss": True, "all_classes": ["A", "B", "C", "D"], "language": "Chinese", - "pretrain": False, + "calculate_overall_loss": False, "max_new_tokens": 32 } ``` @@ -163,7 +163,7 @@ The `inference_kwargs` currently contains 5 fields: - `calculate_loss` (bool, compulsory): Whether the loss on target tokens will be calculated - `all_classes` (Optional[list], compulsory): Whether the subcategory is a single-choice question. Specify all available options in a list or otherwise None. - `language` (str, compulsory): The language for the subcategory. -- `pretrain` (bool, compulsory): Whether the dataset is a pretrain dataset or not. It is usually used for calculate perplexity when you want to evaluate a model with extended context length. +- `calculate_overall_loss` (bool, compulsory): Whether to calculate the overall loss of sentences or not if the dataset is a pretrain dataset. It is usually used for calculate perplexity when you want to evaluate a model with extended context length. - `max_new_tokens` (int, compulsory): The number of new tokens to generate during inference. For example, for dataset MMLU, each subcategory consists of single-choice questions with options A, B, C and D by default and we can assign value `["A", "B", "C", "D"]` to key`all_classes`. For dataset C-Eval, target answers aren't provided in the test split so `calculate_loss` should be set as False. However, other dataset such as GAOKAO-bench contains different formats of questions and lacks some keys or metadata which can reveal what type (single-choice or multi-choice) of questions it is. Before assigning inference arguments, we first parse the dataset to decide which type of questions the subcategory belongs to and set the inference arguments accordingly. From 3b815f626e221424c35a38953aeefc20950bcad4 Mon Sep 17 00:00:00 2001 From: Camille7777 <44392324+Camille7777@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:41:59 +0000 Subject: [PATCH 11/12] refine code --- applications/ColossalEval/README.md | 4 ---- .../ColossalEval/examples/dataset_evaluation/inference.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/applications/ColossalEval/README.md b/applications/ColossalEval/README.md index 1429b263adfe..bc5394a69a44 100644 --- a/applications/ColossalEval/README.md +++ b/applications/ColossalEval/README.md @@ -565,10 +565,6 @@ class CustomizedModel(BaseModel): Once you have successfully added your own model, you can specify your model class in your inference config. -## To do - -- [ ] Add visualization code for evaluation results on public dataset -- [ ] Improve the way to label target tokens ## Citations diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index c651970ee37c..943f610089d9 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -69,7 +69,7 @@ def rm_and_merge( os.remove(directory) except Exception as e: print(e) - print(len(answers["data"])) + all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers From 56bc1863fbf9b6cb022152388eb0098ff71e37dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 08:43:10 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../ColossalEval/examples/dataset_evaluation/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 943f610089d9..1d3f13745474 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -69,7 +69,7 @@ def rm_and_merge( os.remove(directory) except Exception as e: print(e) - + all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers