From ead4bb0cd1156736a2c31f558064785a71bbf935 Mon Sep 17 00:00:00 2001 From: Michelle Date: Tue, 16 Jul 2024 07:25:26 +0000 Subject: [PATCH 01/10] support auto distributed data loader --- .../colossal_eval/dataset/agieval.py | 2 +- .../colossal_eval/dataset/base.py | 18 ++++++-- .../colossal_eval/dataset/ceval.py | 2 +- .../colossal_eval/dataset/cmmlu.py | 2 +- .../colossal_eval/dataset/colossalai.py | 2 +- .../colossal_eval/dataset/cvalues.py | 2 +- .../colossal_eval/dataset/gaokaobench.py | 2 +- .../colossal_eval/dataset/longbench.py | 2 +- .../colossal_eval/dataset/mmlu.py | 2 +- .../colossal_eval/dataset/mtbench.py | 6 +-- .../colossal_eval/dataset/safetybench_en.py | 2 +- .../colossal_eval/dataset/safetybench_zh.py | 2 +- .../colossal_eval/models/huggingface.py | 44 ++++++++++--------- .../colossal_eval/utils/conversation.py | 5 +-- .../examples/dataset_evaluation/inference.py | 39 +++++++++------- 15 files changed, 76 insertions(+), 56 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index d5f2302494e8..43d4cc222647 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -198,7 +198,7 @@ class AGIEvalDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"test": {}} diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py index 531313d7e3c0..210be4920dda 100644 --- a/applications/ColossalEval/colossal_eval/dataset/base.py +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -1,6 +1,8 @@ from abc import abstractstaticmethod +from torch.utils.data import Dataset from colossal_eval.utils import jdump +from colossalai.logging import DistributedLogger class BaseDataset: @@ -12,13 +14,23 @@ class BaseDataset: logger: Logger for the dataset. """ - def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False): - self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference) + def __init__(self, path, logger, *args, **kwargs): + self.dataset = self.load(path, logger, *args, **kwargs) def save(self, save_path): """Save the converted dataset""" jdump(self.dataset, save_path) @abstractstaticmethod - def load(path, logger): + def load(path, logger: DistributedLogger, *args, **kwargs): """Load the original dataset and convert it into the inference dataset""" + +class DistributedDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 915f4d9b0850..3357b1131403 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -91,7 +91,7 @@ class CEvalDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 477280663218..8025a7d98cca 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -102,7 +102,7 @@ class CMMLUDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py index 54ea478ae5d6..0337454fa788 100644 --- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} data = jload(path) data_per_category = get_data_per_category(data) diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py index 30e802a028c8..4023a4c76322 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py +++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py @@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl") data_list = [] diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index cda6276bfe05..73bbf4fbd856 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -70,7 +70,7 @@ class GaoKaoBenchDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, *args, **kwargs ) -> List[Dict]: dataset = {"test": {}} for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py index 9ea5e3c7d77f..eb61efaa0d7c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/longbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = os.listdir(path) diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index dcda68e8f5ac..44daf90444c3 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -32,7 +32,7 @@ class MMLUDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index 03141556788f..ef474ec4ca23 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset): This dataset class will convert the original dataset into the inference dataset. """ - def __init__(self, path, logger, few_shot): + def __init__(self, path, logger: DistributedLogger, *args, **kwargs): self.multiturn = True - self.dataset = self.load(path, logger, few_shot) + self.dataset = self.load(path, logger, *args, **kwargs) @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": defaultdict(dict)} file_path = os.path.join(path, "question.jsonl") diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py index e77a3da34060..8056c3dfd8bf 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py @@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py index 3eca808bbc5b..f5f17e64c991 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py @@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 23c399ccedbd..102b87a03fbf 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -4,6 +4,7 @@ import numpy as np import torch +from torch.utils.data import DataLoader from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 from peft import PeftModel from tqdm import tqdm @@ -325,7 +326,7 @@ def _get_input_ids_and_labels( return input_ids_list, labels_list, None - def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: """ Infer the given data. This function will call self.generate() to get model outputs and also self.model() to get logits. @@ -359,26 +360,28 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} - turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1 + for sample_data in data_loader: + break + turn = 0 if not isinstance(sample_data[0]["output"], list) else len(sample_data[0]["output"]) + 1 turn_desc = "" if turn == 0 else f"-turn{turn}" bar = tqdm( - range(math.ceil(len(data) / self.batch_size)), - desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps", + range(len(data_loader)), + desc=f"{sample_data[0]['dataset']}-{sample_data[0]['category']}{turn_desc} Inference steps", disable=not is_rank_0(), ) loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - answers = copy.deepcopy(data) - for i in range(0, len(data), self.batch_size): - batch = data[i : i + self.batch_size] + answers = [] + + for i, batch in enumerate(data_loader): batch_prompt, batch_target = get_batch_prompt( - self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length + self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length ) if is_rank_0() and debug and i == 0: self.logger.info( - f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}" + f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}" ) self.logger.info("-" * 120) self.logger.info("An example prompt and prompt with target is:") @@ -402,7 +405,7 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b # Otherwise this will violate the single-choice setting. if calculate_loss: - labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))] + labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))] loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() @@ -411,29 +414,30 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) ] - for j in range(len(batch_prompt)): + for j in range(len(batch)): if not pretrain: - if isinstance(answers[i + j]["output"], list): - answers[i + j]["output"].append(batch_decodes[j].strip()) + if isinstance(batch[j]["output"], list): + batch[j]["output"].append(batch_decodes[j].strip()) else: - answers[i + j]["output"] = batch_decodes[j].strip() + batch[j]["output"] = batch_decodes[j].strip() if isinstance(scores, torch.Tensor): - answers[i + j]["logits_over_choices"] = probs[j] + batch[j]["logits_over_choices"] = probs[j] if calculate_loss: - answers[i + j]["loss_over_choices"] = loss_over_choices[j] + batch[j]["loss_over_choices"] = loss_over_choices[j] if calculate_loss: - answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. # However, loss (which is per sample loss) suffices for most cases. - answers[i + j]["loss_sum"] = batch_losses[j] - answers[i + j]["token_num"] = batch_target_token_nums[j] + batch[j]["loss_sum"] = batch_losses[j] + batch[j]["token_num"] = batch_target_token_nums[j] if batch_bytes_nums: - answers[i + j]["byte_num"] = batch_bytes_nums[j] + batch[j]["byte_num"] = batch_bytes_nums[j] + answers.extend(batch) bar.update() diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 330083aa6a61..3898517ccab8 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -124,7 +124,7 @@ def dict(self): def get_few_shot_prefix( - conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int + few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int ) -> str: """ Get few shot prefix. @@ -157,7 +157,6 @@ def get_batch_prompt( batch: List[Dict], few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], - language: Optional[str], model_max_length: Optional[int], ) -> Tuple[List[Dict], List[Dict]]: """ @@ -192,7 +191,7 @@ def get_batch_prompt( else: raise Exception("When using few-shot, target answer should be a string.") - few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) + few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens) conv.append_message(conv.roles[0], few_shot_prefix + query_text) conv.append_message(conv.roles[1], None) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index a7307635d333..6a4592dfa9a0 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -4,6 +4,8 @@ from typing import Dict, List import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler +from colossal_eval.dataset.base import DistributedDataset from colossal_eval import dataset, models, utils import colossalai @@ -13,6 +15,7 @@ from colossalai.shardformer import ShardConfig logger = get_dist_logger() +os.environ["TOKENIZERS_PARALLELISM"] = "false" def rm_and_merge( @@ -35,6 +38,8 @@ def rm_and_merge( """ for model_name in model_names: + dataset_cat_num_mapping = utils.jload(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) + for dataset_name, categories in dataset_names.items(): all_answers_with_dataset_class = {} all_answers_with_dataset_class["dataset_class"] = dataset_classes[dataset_name] @@ -66,6 +71,8 @@ def rm_and_merge( except Exception as e: print(e) + total_num = dataset_cat_num_mapping[dataset_name][category] + answers["data"] = answers["data"][:total_num] all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers @@ -75,6 +82,7 @@ def rm_and_merge( all_answers_with_dataset_class, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"), ) + os.remove(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) logger.info(f"Save inference results of model {model_name} for all dataset.") logger.info(f"Save inference results of all models for all dataset.") @@ -118,6 +126,7 @@ def main(args): debug_args = {} few_shot_args = {} multiturn_args = {} + dataset_cat_num_mapping = {} config = utils.jload(args.config) @@ -183,6 +192,7 @@ def main(args): model_name = model_parameter["name"] model_class = eval(f"models.{model_parameter['model_class']}") paramerters = model_parameter["parameters"] + batch_size = paramerters["batch_size"] paramerters.update({"logger": logger}) paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) paramerters.update({"shard_config": shard_config}) @@ -192,35 +202,26 @@ def main(args): raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") for dataset_name, split_data in inference_data.items(): - start = 0 + cat_num_mapping = {} prev_questions = None for category, category_data in split_data.items(): num_turn = category_data["inference_kwargs"].get("turns", 1) + cat_num_mapping[category] = len(category_data["data"]) if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") answers_to_dump = copy.deepcopy(category_data) - partition_size = len(category_data["data"]) // dp_size - redundant = len(category_data["data"]) % dp_size - - # Ensure that the amount of data for inference is as consistent as possible across different processes. - lengths = [partition_size for _ in range(dp_size)] - for j in range(redundant): - lengths[(j + start) % dp_size] += 1 - - start = (start + redundant) % dp_size - for turn in range(num_turn): if turn == 0: - questions = category_data["data"][ - sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank] - ] + dist_dataset = DistributedDataset(category_data["data"]) else: - questions = prev_questions - + dist_dataset = DistributedDataset(prev_questions) + + sampler = DistributedSampler(dist_dataset, num_replicas=world_size, rank=rank, shuffle=False) + questions_loader = DataLoader(dist_dataset, batch_size=batch_size, sampler=sampler, num_workers=8, pin_memory=True, collate_fn=lambda x: x) answers_per_rank = model_.inference( - questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + data_loader=questions_loader, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] ) prev_questions = answers_per_rank @@ -236,11 +237,15 @@ def main(args): ), ) + dataset_cat_num_mapping[dataset_name] = cat_num_mapping + logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ accelerator.empty_cache() + utils.jdump(dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json")) + dist.barrier() if rank == 0: model_names = [model_parameter["name"] for model_parameter in model_parameters] From f32698f9f87db28b1b5387bae0d4e29da523f6f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 07:37:25 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../colossal_eval/dataset/agieval.py | 4 +--- .../colossal_eval/dataset/base.py | 8 ++++--- .../colossal_eval/dataset/ceval.py | 4 +--- .../colossal_eval/dataset/cmmlu.py | 4 +--- .../colossal_eval/dataset/gaokaobench.py | 4 +--- .../colossal_eval/dataset/mmlu.py | 4 +--- .../colossal_eval/models/huggingface.py | 3 +-- .../colossal_eval/utils/conversation.py | 4 +--- .../examples/dataset_evaluation/inference.py | 23 ++++++++++++++----- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index 43d4cc222647..c1cfe37d7599 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = glob.glob(os.path.join(path, "*.jsonl")) diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py index 210be4920dda..a29f56fd1998 100644 --- a/applications/ColossalEval/colossal_eval/dataset/base.py +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -1,7 +1,8 @@ from abc import abstractstaticmethod -from torch.utils.data import Dataset from colossal_eval.utils import jdump +from torch.utils.data import Dataset + from colossalai.logging import DistributedLogger @@ -25,12 +26,13 @@ def save(self, save_path): def load(path, logger: DistributedLogger, *args, **kwargs): """Load the original dataset and convert it into the inference dataset""" + class DistributedDataset(Dataset): def __init__(self, data): self.data = data - + def __len__(self): return len(self.data) - + def __getitem__(self, idx): return self.data[idx] diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 3357b1131403..1023d1e23c1f 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 8025a7d98cca..05752c2486fa 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index 73bbf4fbd856..44ccea9cfa2c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: files = os.listdir(os.path.join(path, "data", category)) diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index 44daf90444c3..e9465c91b3ce 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 102b87a03fbf..a8b57421d926 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -1,12 +1,11 @@ import copy -import math from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch -from torch.utils.data import DataLoader from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 from peft import PeftModel +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 3898517ccab8..4cbc5326c79b 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -123,9 +123,7 @@ def dict(self): } -def get_few_shot_prefix( - few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int -) -> str: +def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str: """ Get few shot prefix. diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 6a4592dfa9a0..b3449ac9d459 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -4,9 +4,9 @@ from typing import Dict, List import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler -from colossal_eval.dataset.base import DistributedDataset from colossal_eval import dataset, models, utils +from colossal_eval.dataset.base import DistributedDataset +from torch.utils.data import DataLoader, DistributedSampler import colossalai from colossalai.accelerator import get_accelerator @@ -217,11 +217,20 @@ def main(args): dist_dataset = DistributedDataset(category_data["data"]) else: dist_dataset = DistributedDataset(prev_questions) - + sampler = DistributedSampler(dist_dataset, num_replicas=world_size, rank=rank, shuffle=False) - questions_loader = DataLoader(dist_dataset, batch_size=batch_size, sampler=sampler, num_workers=8, pin_memory=True, collate_fn=lambda x: x) + questions_loader = DataLoader( + dist_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=8, + pin_memory=True, + collate_fn=lambda x: x, + ) answers_per_rank = model_.inference( - data_loader=questions_loader, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + data_loader=questions_loader, + inference_kwargs=category_data["inference_kwargs"], + debug=debug_args[dataset_name], ) prev_questions = answers_per_rank @@ -244,7 +253,9 @@ def main(args): del model_ accelerator.empty_cache() - utils.jdump(dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json")) + utils.jdump( + dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json") + ) dist.barrier() if rank == 0: From 613de866db4431f990feba3263e57bd567e156c6 Mon Sep 17 00:00:00 2001 From: Michelle Date: Tue, 16 Jul 2024 07:25:26 +0000 Subject: [PATCH 03/10] support auto distributed data loader --- .../colossal_eval/dataset/agieval.py | 2 +- .../colossal_eval/dataset/base.py | 18 ++++++-- .../colossal_eval/dataset/ceval.py | 2 +- .../colossal_eval/dataset/cmmlu.py | 2 +- .../colossal_eval/dataset/colossalai.py | 2 +- .../colossal_eval/dataset/cvalues.py | 2 +- .../colossal_eval/dataset/gaokaobench.py | 2 +- .../colossal_eval/dataset/longbench.py | 2 +- .../colossal_eval/dataset/mmlu.py | 2 +- .../colossal_eval/dataset/mtbench.py | 6 +-- .../colossal_eval/dataset/safetybench_en.py | 2 +- .../colossal_eval/dataset/safetybench_zh.py | 2 +- .../colossal_eval/models/huggingface.py | 44 ++++++++++--------- .../colossal_eval/utils/conversation.py | 5 +-- .../examples/dataset_evaluation/inference.py | 39 +++++++++------- 15 files changed, 76 insertions(+), 56 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index d5f2302494e8..43d4cc222647 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -198,7 +198,7 @@ class AGIEvalDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"test": {}} diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py index 531313d7e3c0..210be4920dda 100644 --- a/applications/ColossalEval/colossal_eval/dataset/base.py +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -1,6 +1,8 @@ from abc import abstractstaticmethod +from torch.utils.data import Dataset from colossal_eval.utils import jdump +from colossalai.logging import DistributedLogger class BaseDataset: @@ -12,13 +14,23 @@ class BaseDataset: logger: Logger for the dataset. """ - def __init__(self, path, logger, few_shot, forward_only=False, load_train=False, load_reference=False): - self.dataset = self.load(path, logger, few_shot, forward_only, load_train, load_reference) + def __init__(self, path, logger, *args, **kwargs): + self.dataset = self.load(path, logger, *args, **kwargs) def save(self, save_path): """Save the converted dataset""" jdump(self.dataset, save_path) @abstractstaticmethod - def load(path, logger): + def load(path, logger: DistributedLogger, *args, **kwargs): """Load the original dataset and convert it into the inference dataset""" + +class DistributedDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 915f4d9b0850..3357b1131403 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -91,7 +91,7 @@ class CEvalDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 477280663218..8025a7d98cca 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -102,7 +102,7 @@ class CMMLUDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/colossalai.py b/applications/ColossalEval/colossal_eval/dataset/colossalai.py index 54ea478ae5d6..0337454fa788 100644 --- a/applications/ColossalEval/colossal_eval/dataset/colossalai.py +++ b/applications/ColossalEval/colossal_eval/dataset/colossalai.py @@ -37,7 +37,7 @@ class ColossalDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} data = jload(path) data_per_category = get_data_per_category(data) diff --git a/applications/ColossalEval/colossal_eval/dataset/cvalues.py b/applications/ColossalEval/colossal_eval/dataset/cvalues.py index 30e802a028c8..4023a4c76322 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cvalues.py +++ b/applications/ColossalEval/colossal_eval/dataset/cvalues.py @@ -28,7 +28,7 @@ class CValuesDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} file_path = os.path.join(path, "cvalues_responsibility_mc.jsonl") data_list = [] diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index cda6276bfe05..73bbf4fbd856 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -70,7 +70,7 @@ class GaoKaoBenchDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, *args, **kwargs ) -> List[Dict]: dataset = {"test": {}} for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/longbench.py b/applications/ColossalEval/colossal_eval/dataset/longbench.py index 9ea5e3c7d77f..eb61efaa0d7c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/longbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/longbench.py @@ -77,7 +77,7 @@ class LongBenchDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = os.listdir(path) diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index dcda68e8f5ac..44daf90444c3 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -32,7 +32,7 @@ class MMLUDataset(BaseDataset): @staticmethod def load( - path: str, logger: DistributedLogger, few_shot: bool, forward_only: bool, load_train: bool, load_reference: bool + path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs ) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py index 03141556788f..ef474ec4ca23 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mtbench.py +++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py @@ -27,12 +27,12 @@ class MTBenchDataset(BaseDataset): This dataset class will convert the original dataset into the inference dataset. """ - def __init__(self, path, logger, few_shot): + def __init__(self, path, logger: DistributedLogger, *args, **kwargs): self.multiturn = True - self.dataset = self.load(path, logger, few_shot) + self.dataset = self.load(path, logger, *args, **kwargs) @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": defaultdict(dict)} file_path = os.path.join(path, "question.jsonl") diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py index e77a3da34060..8056c3dfd8bf 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_en.py @@ -130,7 +130,7 @@ class SafetyBenchENDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py index 3eca808bbc5b..f5f17e64c991 100644 --- a/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py +++ b/applications/ColossalEval/colossal_eval/dataset/safetybench_zh.py @@ -130,7 +130,7 @@ class SafetyBenchZHDataset(BaseDataset): """ @staticmethod - def load(path: str, logger: DistributedLogger, few_shot: bool) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} data_files = [os.path.join(path, file_name) for file_name in FILES] for file_path in data_files: diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 23c399ccedbd..102b87a03fbf 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -4,6 +4,7 @@ import numpy as np import torch +from torch.utils.data import DataLoader from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 from peft import PeftModel from tqdm import tqdm @@ -325,7 +326,7 @@ def _get_input_ids_and_labels( return input_ids_list, labels_list, None - def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: + def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], debug: bool = False) -> List[Dict]: """ Infer the given data. This function will call self.generate() to get model outputs and also self.model() to get logits. @@ -359,26 +360,28 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} - turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1 + for sample_data in data_loader: + break + turn = 0 if not isinstance(sample_data[0]["output"], list) else len(sample_data[0]["output"]) + 1 turn_desc = "" if turn == 0 else f"-turn{turn}" bar = tqdm( - range(math.ceil(len(data) / self.batch_size)), - desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps", + range(len(data_loader)), + desc=f"{sample_data[0]['dataset']}-{sample_data[0]['category']}{turn_desc} Inference steps", disable=not is_rank_0(), ) loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - answers = copy.deepcopy(data) - for i in range(0, len(data), self.batch_size): - batch = data[i : i + self.batch_size] + answers = [] + + for i, batch in enumerate(data_loader): batch_prompt, batch_target = get_batch_prompt( - self.prompt_template, batch, few_shot_data, self.tokenizer, language, self.model_max_length + self.prompt_template, batch, few_shot_data, self.tokenizer, self.model_max_length ) if is_rank_0() and debug and i == 0: self.logger.info( - f"Inference arguments for dataset {data[0]['dataset']} category {data[0]['category']} is:\n{inference_kwargs}" + f"Inference arguments for dataset {batch[0]['dataset']} category {batch[0]['category']} is:\n{inference_kwargs}" ) self.logger.info("-" * 120) self.logger.info("An example prompt and prompt with target is:") @@ -402,7 +405,7 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b # Otherwise this will violate the single-choice setting. if calculate_loss: - labels = [self.str_label_map[answers[i + j]["target"]] for j in range(len(batch_decodes))] + labels = [self.str_label_map[batch[j]["target"]] for j in range(len(batch))] loss_over_choices = loss_fct(scores, torch.tensor(labels, dtype=torch.long)).numpy().tolist() @@ -411,29 +414,30 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b {choice: probs[i][self.str_label_map[choice]] for choice in self.choices} for i in range(len(probs)) ] - for j in range(len(batch_prompt)): + for j in range(len(batch)): if not pretrain: - if isinstance(answers[i + j]["output"], list): - answers[i + j]["output"].append(batch_decodes[j].strip()) + if isinstance(batch[j]["output"], list): + batch[j]["output"].append(batch_decodes[j].strip()) else: - answers[i + j]["output"] = batch_decodes[j].strip() + batch[j]["output"] = batch_decodes[j].strip() if isinstance(scores, torch.Tensor): - answers[i + j]["logits_over_choices"] = probs[j] + batch[j]["logits_over_choices"] = probs[j] if calculate_loss: - answers[i + j]["loss_over_choices"] = loss_over_choices[j] + batch[j]["loss_over_choices"] = loss_over_choices[j] if calculate_loss: - answers[i + j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() + batch[j]["loss"] = (np.array(batch_losses[j]) / np.array(batch_target_token_nums[j])).tolist() # loss_sum is specially used for pertrain dataset for calculating per-byte-perplexity. # However, loss (which is per sample loss) suffices for most cases. - answers[i + j]["loss_sum"] = batch_losses[j] - answers[i + j]["token_num"] = batch_target_token_nums[j] + batch[j]["loss_sum"] = batch_losses[j] + batch[j]["token_num"] = batch_target_token_nums[j] if batch_bytes_nums: - answers[i + j]["byte_num"] = batch_bytes_nums[j] + batch[j]["byte_num"] = batch_bytes_nums[j] + answers.extend(batch) bar.update() diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 330083aa6a61..3898517ccab8 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -124,7 +124,7 @@ def dict(self): def get_few_shot_prefix( - conv: Conversation, few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], language: str, max_tokens: int + few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int ) -> str: """ Get few shot prefix. @@ -157,7 +157,6 @@ def get_batch_prompt( batch: List[Dict], few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], - language: Optional[str], model_max_length: Optional[int], ) -> Tuple[List[Dict], List[Dict]]: """ @@ -192,7 +191,7 @@ def get_batch_prompt( else: raise Exception("When using few-shot, target answer should be a string.") - few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens) + few_shot_prefix = get_few_shot_prefix(few_shot_data, tokenizer, max_tokens) conv.append_message(conv.roles[0], few_shot_prefix + query_text) conv.append_message(conv.roles[1], None) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index a7307635d333..6a4592dfa9a0 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -4,6 +4,8 @@ from typing import Dict, List import torch.distributed as dist +from torch.utils.data import DataLoader, DistributedSampler +from colossal_eval.dataset.base import DistributedDataset from colossal_eval import dataset, models, utils import colossalai @@ -13,6 +15,7 @@ from colossalai.shardformer import ShardConfig logger = get_dist_logger() +os.environ["TOKENIZERS_PARALLELISM"] = "false" def rm_and_merge( @@ -35,6 +38,8 @@ def rm_and_merge( """ for model_name in model_names: + dataset_cat_num_mapping = utils.jload(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) + for dataset_name, categories in dataset_names.items(): all_answers_with_dataset_class = {} all_answers_with_dataset_class["dataset_class"] = dataset_classes[dataset_name] @@ -66,6 +71,8 @@ def rm_and_merge( except Exception as e: print(e) + total_num = dataset_cat_num_mapping[dataset_name][category] + answers["data"] = answers["data"][:total_num] all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers @@ -75,6 +82,7 @@ def rm_and_merge( all_answers_with_dataset_class, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"), ) + os.remove(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) logger.info(f"Save inference results of model {model_name} for all dataset.") logger.info(f"Save inference results of all models for all dataset.") @@ -118,6 +126,7 @@ def main(args): debug_args = {} few_shot_args = {} multiturn_args = {} + dataset_cat_num_mapping = {} config = utils.jload(args.config) @@ -183,6 +192,7 @@ def main(args): model_name = model_parameter["name"] model_class = eval(f"models.{model_parameter['model_class']}") paramerters = model_parameter["parameters"] + batch_size = paramerters["batch_size"] paramerters.update({"logger": logger}) paramerters.update({"prompt_template": utils.prompt_templates[paramerters["prompt_template"]]}) paramerters.update({"shard_config": shard_config}) @@ -192,35 +202,26 @@ def main(args): raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") for dataset_name, split_data in inference_data.items(): - start = 0 + cat_num_mapping = {} prev_questions = None for category, category_data in split_data.items(): num_turn = category_data["inference_kwargs"].get("turns", 1) + cat_num_mapping[category] = len(category_data["data"]) if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") answers_to_dump = copy.deepcopy(category_data) - partition_size = len(category_data["data"]) // dp_size - redundant = len(category_data["data"]) % dp_size - - # Ensure that the amount of data for inference is as consistent as possible across different processes. - lengths = [partition_size for _ in range(dp_size)] - for j in range(redundant): - lengths[(j + start) % dp_size] += 1 - - start = (start + redundant) % dp_size - for turn in range(num_turn): if turn == 0: - questions = category_data["data"][ - sum(lengths[0:dp_rank]) : sum(lengths[0:dp_rank]) + lengths[dp_rank] - ] + dist_dataset = DistributedDataset(category_data["data"]) else: - questions = prev_questions - + dist_dataset = DistributedDataset(prev_questions) + + sampler = DistributedSampler(dist_dataset, num_replicas=world_size, rank=rank, shuffle=False) + questions_loader = DataLoader(dist_dataset, batch_size=batch_size, sampler=sampler, num_workers=8, pin_memory=True, collate_fn=lambda x: x) answers_per_rank = model_.inference( - questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + data_loader=questions_loader, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] ) prev_questions = answers_per_rank @@ -236,11 +237,15 @@ def main(args): ), ) + dataset_cat_num_mapping[dataset_name] = cat_num_mapping + logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ accelerator.empty_cache() + utils.jdump(dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json")) + dist.barrier() if rank == 0: model_names = [model_parameter["name"] for model_parameter in model_parameters] From e2ba76ba5186b7724d3c3b4c2913273f95ef73f7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 07:37:25 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../colossal_eval/dataset/agieval.py | 4 +--- .../colossal_eval/dataset/base.py | 8 ++++--- .../colossal_eval/dataset/ceval.py | 4 +--- .../colossal_eval/dataset/cmmlu.py | 4 +--- .../colossal_eval/dataset/gaokaobench.py | 4 +--- .../colossal_eval/dataset/mmlu.py | 4 +--- .../colossal_eval/models/huggingface.py | 3 +-- .../colossal_eval/utils/conversation.py | 4 +--- .../examples/dataset_evaluation/inference.py | 23 ++++++++++++++----- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/dataset/agieval.py b/applications/ColossalEval/colossal_eval/dataset/agieval.py index 43d4cc222647..c1cfe37d7599 100644 --- a/applications/ColossalEval/colossal_eval/dataset/agieval.py +++ b/applications/ColossalEval/colossal_eval/dataset/agieval.py @@ -197,9 +197,7 @@ class AGIEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} files = glob.glob(os.path.join(path, "*.jsonl")) diff --git a/applications/ColossalEval/colossal_eval/dataset/base.py b/applications/ColossalEval/colossal_eval/dataset/base.py index 210be4920dda..a29f56fd1998 100644 --- a/applications/ColossalEval/colossal_eval/dataset/base.py +++ b/applications/ColossalEval/colossal_eval/dataset/base.py @@ -1,7 +1,8 @@ from abc import abstractstaticmethod -from torch.utils.data import Dataset from colossal_eval.utils import jdump +from torch.utils.data import Dataset + from colossalai.logging import DistributedLogger @@ -25,12 +26,13 @@ def save(self, save_path): def load(path, logger: DistributedLogger, *args, **kwargs): """Load the original dataset and convert it into the inference dataset""" + class DistributedDataset(Dataset): def __init__(self, data): self.data = data - + def __len__(self): return len(self.data) - + def __getitem__(self, idx): return self.data[idx] diff --git a/applications/ColossalEval/colossal_eval/dataset/ceval.py b/applications/ColossalEval/colossal_eval/dataset/ceval.py index 3357b1131403..1023d1e23c1f 100644 --- a/applications/ColossalEval/colossal_eval/dataset/ceval.py +++ b/applications/ColossalEval/colossal_eval/dataset/ceval.py @@ -90,9 +90,7 @@ class CEvalDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py index 8025a7d98cca..05752c2486fa 100644 --- a/applications/ColossalEval/colossal_eval/dataset/cmmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/cmmlu.py @@ -101,9 +101,7 @@ class CMMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py index 73bbf4fbd856..44ccea9cfa2c 100644 --- a/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py +++ b/applications/ColossalEval/colossal_eval/dataset/gaokaobench.py @@ -69,9 +69,7 @@ class GaoKaoBenchDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, *args, **kwargs) -> List[Dict]: dataset = {"test": {}} for category in ["Fill-in-the-blank_Questions", "Multiple-choice_Questions", "Open-ended_Questions"]: files = os.listdir(os.path.join(path, "data", category)) diff --git a/applications/ColossalEval/colossal_eval/dataset/mmlu.py b/applications/ColossalEval/colossal_eval/dataset/mmlu.py index 44daf90444c3..e9465c91b3ce 100644 --- a/applications/ColossalEval/colossal_eval/dataset/mmlu.py +++ b/applications/ColossalEval/colossal_eval/dataset/mmlu.py @@ -31,9 +31,7 @@ class MMLUDataset(BaseDataset): """ @staticmethod - def load( - path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs - ) -> List[Dict]: + def load(path: str, logger: DistributedLogger, few_shot: bool, *args, **kwargs) -> List[Dict]: dataset = {"dev": {}, "test": {}} for split in ["dev", "test"]: files = os.listdir(os.path.join(path, split)) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 102b87a03fbf..a8b57421d926 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -1,12 +1,11 @@ import copy -import math from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch -from torch.utils.data import DataLoader from colossal_eval.utils import Conversation, get_batch_prompt, is_rank_0 from peft import PeftModel +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 3898517ccab8..4cbc5326c79b 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -123,9 +123,7 @@ def dict(self): } -def get_few_shot_prefix( - few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int -) -> str: +def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokenizer], max_tokens: int) -> str: """ Get few shot prefix. diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 6a4592dfa9a0..b3449ac9d459 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -4,9 +4,9 @@ from typing import Dict, List import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler -from colossal_eval.dataset.base import DistributedDataset from colossal_eval import dataset, models, utils +from colossal_eval.dataset.base import DistributedDataset +from torch.utils.data import DataLoader, DistributedSampler import colossalai from colossalai.accelerator import get_accelerator @@ -217,11 +217,20 @@ def main(args): dist_dataset = DistributedDataset(category_data["data"]) else: dist_dataset = DistributedDataset(prev_questions) - + sampler = DistributedSampler(dist_dataset, num_replicas=world_size, rank=rank, shuffle=False) - questions_loader = DataLoader(dist_dataset, batch_size=batch_size, sampler=sampler, num_workers=8, pin_memory=True, collate_fn=lambda x: x) + questions_loader = DataLoader( + dist_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=8, + pin_memory=True, + collate_fn=lambda x: x, + ) answers_per_rank = model_.inference( - data_loader=questions_loader, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name] + data_loader=questions_loader, + inference_kwargs=category_data["inference_kwargs"], + debug=debug_args[dataset_name], ) prev_questions = answers_per_rank @@ -244,7 +253,9 @@ def main(args): del model_ accelerator.empty_cache() - utils.jdump(dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json")) + utils.jdump( + dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json") + ) dist.barrier() if rank == 0: From 2d3ae95d39becfc1aa70f362391f460522c79033 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 30 Jul 2024 08:55:21 +0000 Subject: [PATCH 05/10] fix tp error --- .../examples/dataset_evaluation/inference.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index b3449ac9d459..2efaf20b611e 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -110,13 +110,19 @@ def main(args): pg_mesh = ProcessGroupMesh(dp_size, args.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS) + pg_mesh.get_group_along_axis(DP_AXIS) coordinates = pg_mesh._coord dp_rank = coordinates[DP_AXIS] tp_rank = coordinates[TP_AXIS] shard_config = ( - ShardConfig(tensor_parallel_process_group=tp_group, enable_tensor_parallelism=args.tp_size > 1) + ShardConfig( + tensor_parallel_process_group=tp_group, + enable_tensor_parallelism=args.tp_size > 1, + parallel_output=False, + enable_all_optimization=True, + ) if args.tp_size > 1 else None ) @@ -218,7 +224,12 @@ def main(args): else: dist_dataset = DistributedDataset(prev_questions) - sampler = DistributedSampler(dist_dataset, num_replicas=world_size, rank=rank, shuffle=False) + sampler = DistributedSampler( + dist_dataset, + num_replicas=pg_mesh.size(DP_AXIS), + rank=pg_mesh.coordinate(DP_AXIS), + shuffle=False, + ) questions_loader = DataLoader( dist_dataset, batch_size=batch_size, From 3d7d25416385f818e393e617ffbd8198457dfc6b Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 30 Jul 2024 09:53:08 +0000 Subject: [PATCH 06/10] remove unused parameters --- .../ColossalEval/colossal_eval/models/huggingface.py | 7 +------ .../ColossalEval/examples/dataset_evaluation/inference.py | 3 +++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index a8b57421d926..48a7efed649d 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -359,14 +359,9 @@ def inference(self, data_loader: DataLoader, inference_kwargs: Dict[str, Any], d self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)} - for sample_data in data_loader: - break - turn = 0 if not isinstance(sample_data[0]["output"], list) else len(sample_data[0]["output"]) + 1 - turn_desc = "" if turn == 0 else f"-turn{turn}" - bar = tqdm( range(len(data_loader)), - desc=f"{sample_data[0]['dataset']}-{sample_data[0]['category']}{turn_desc} Inference steps", + desc=f"{inference_kwargs['dataset']}-{inference_kwargs['category']} Inference steps", disable=not is_rank_0(), ) loss_fct = torch.nn.CrossEntropyLoss(reduction="none") diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 2efaf20b611e..b41835105168 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -238,6 +238,9 @@ def main(args): pin_memory=True, collate_fn=lambda x: x, ) + category_data["inference_kwargs"]["dataset"] = dataset_name + category_data["inference_kwargs"]["category"] = category + answers_per_rank = model_.inference( data_loader=questions_loader, inference_kwargs=category_data["inference_kwargs"], From 88f5efde14ba9e9b738710683b9d5ae8c0527833 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 30 Jul 2024 11:16:01 +0000 Subject: [PATCH 07/10] remove unused --- applications/ColossalEval/colossal_eval/models/huggingface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py index 48a7efed649d..e91743525f0e 100644 --- a/applications/ColossalEval/colossal_eval/models/huggingface.py +++ b/applications/ColossalEval/colossal_eval/models/huggingface.py @@ -130,7 +130,7 @@ def _load_model( if shard_config is not None: self.model = AutoModel.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) - self.model, sharded_parameters = shard_former.optimize(self.model) + self.model, _ = shard_former.optimize(self.model) self.model.to(get_current_device()) if peft_path is not None: @@ -598,7 +598,7 @@ def _load_model( if shard_config is not None: self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs) shard_former = ShardFormer(shard_config) - self.model, sharded_parameters = shard_former.optimize(self.model) + self.model, _ = shard_former.optimize(self.model) self.model.to(get_current_device()) if peft_path is not None: From 5b3736841544161d5968fae1d62fa65a838d6db0 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Tue, 30 Jul 2024 11:21:55 +0000 Subject: [PATCH 08/10] update inference --- .../ColossalEval/examples/dataset_evaluation/inference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index b41835105168..6ba950ee19cc 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -110,7 +110,6 @@ def main(args): pg_mesh = ProcessGroupMesh(dp_size, args.tp_size) tp_group = pg_mesh.get_group_along_axis(TP_AXIS) - pg_mesh.get_group_along_axis(DP_AXIS) coordinates = pg_mesh._coord dp_rank = coordinates[DP_AXIS] From 79413227115fa32f0c04cdc68eeb87e37982bee7 Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 1 Aug 2024 06:27:15 +0000 Subject: [PATCH 09/10] update docs --- .../ColossalEval/colossal_eval/utils/conversation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py index 4cbc5326c79b..c0445e84ec76 100644 --- a/applications/ColossalEval/colossal_eval/utils/conversation.py +++ b/applications/ColossalEval/colossal_eval/utils/conversation.py @@ -128,8 +128,8 @@ def get_few_shot_prefix(few_shot_data: List[str], tokenizer: Optional[AutoTokeni Get few shot prefix. Args: - conv: Conversation template. - few_shot_examples: Few shot examples to generate few shot prompt prefix. + few_shot_data: Few shot examples to generate few shot prompt prefix. + tokenizer: tokenizer used to tokenize data. Returns: Few shot prompt prefix. @@ -164,6 +164,7 @@ def get_batch_prompt( conv: Conversation template. batch: Batch data to generate prompt from. few_shot_data: Few shot data to generate few shot prompt prefix. + tokenizer: tokenizer used to tokenize data. Returns: Tuple containg batch prompt and target. From ae728fe08bfda35fd1137f9f5e961c3afa03e3dd Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 1 Aug 2024 10:16:01 +0000 Subject: [PATCH 10/10] update inference --- .../examples/dataset_evaluation/inference.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py index 6ba950ee19cc..c651970ee37c 100644 --- a/applications/ColossalEval/examples/dataset_evaluation/inference.py +++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py @@ -38,8 +38,6 @@ def rm_and_merge( """ for model_name in model_names: - dataset_cat_num_mapping = utils.jload(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) - for dataset_name, categories in dataset_names.items(): all_answers_with_dataset_class = {} all_answers_with_dataset_class["dataset_class"] = dataset_classes[dataset_name] @@ -59,7 +57,8 @@ def rm_and_merge( ) else: rank_answers = utils.jload(directory) - answers["data"].extend(rank_answers["data"]) + deduplidate_answers = [x for x in rank_answers["data"] if x not in answers["data"]] + answers["data"].extend(deduplidate_answers) answers["inference_kwargs"] = rank_answers["inference_kwargs"] for r in range(dp_size): @@ -70,9 +69,7 @@ def rm_and_merge( os.remove(directory) except Exception as e: print(e) - - total_num = dataset_cat_num_mapping[dataset_name][category] - answers["data"] = answers["data"][:total_num] + print(len(answers["data"])) all_answers[category] = answers all_answers_with_dataset_class["inference_results"] = all_answers @@ -82,7 +79,6 @@ def rm_and_merge( all_answers_with_dataset_class, os.path.join(save_path, model_name, f"{dataset_name}_inference_results.json"), ) - os.remove(os.path.join(save_path, model_name, "dataset_cat_num_mapping.json")) logger.info(f"Save inference results of model {model_name} for all dataset.") logger.info(f"Save inference results of all models for all dataset.") @@ -131,7 +127,6 @@ def main(args): debug_args = {} few_shot_args = {} multiturn_args = {} - dataset_cat_num_mapping = {} config = utils.jload(args.config) @@ -207,11 +202,9 @@ def main(args): raise ValueError(f"Model class {model_parameter['model_class']} is not a subclass of BaseModel.") for dataset_name, split_data in inference_data.items(): - cat_num_mapping = {} prev_questions = None for category, category_data in split_data.items(): num_turn = category_data["inference_kwargs"].get("turns", 1) - cat_num_mapping[category] = len(category_data["data"]) if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None: raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!") @@ -259,17 +252,11 @@ def main(args): ), ) - dataset_cat_num_mapping[dataset_name] = cat_num_mapping - logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB") del model_ accelerator.empty_cache() - utils.jdump( - dataset_cat_num_mapping, os.path.join(args.inference_save_path, model_name, "dataset_cat_num_mapping.json") - ) - dist.barrier() if rank == 0: model_names = [model_parameter["name"] for model_parameter in model_parameters]