diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 000000000000..6cc2f20c0959
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1 @@
+* @hpcaitech/colossalai-qa
diff --git a/.github/workflows/release_test_pypi_before_merge.yml b/.github/workflows/release_test_pypi_before_merge.yml
index 49c626265175..284ab4d1afb0 100644
--- a/.github/workflows/release_test_pypi_before_merge.yml
+++ b/.github/workflows/release_test_pypi_before_merge.yml
@@ -27,7 +27,9 @@ jobs:
echo $new_version > ./version.txt
echo "version=$new_version" >> $GITHUB_OUTPUT
- - run: python setup.py sdist build
+ - run: |
+ pip install --upgrade pip
+ python setup.py sdist build
# publish to PyPI if executed on the main branch
- name: Publish package to PyPI
diff --git a/.github/workflows/run_colossalqa_unit_tests.yml b/.github/workflows/run_colossalqa_unit_tests.yml
new file mode 100644
index 000000000000..763db277289f
--- /dev/null
+++ b/.github/workflows/run_colossalqa_unit_tests.yml
@@ -0,0 +1,54 @@
+name: Run colossalqa unit tests
+
+on:
+ pull_request:
+ types: [synchronize, opened, reopened]
+ paths:
+ - 'applications/ColossalQA/colossalqa/**'
+ - 'applications/ColossalQA/requirements.txt'
+ - 'applications/ColossalQA/setup.py'
+ - 'applications/ColossalQA/tests/**'
+ - 'applications/ColossalQA/pytest.ini'
+
+jobs:
+ tests:
+ name: Run colossalqa unit tests
+ if: |
+ github.event.pull_request.draft == false &&
+ github.base_ref == 'main' &&
+ github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
+ runs-on: [self-hosted, gpu]
+ container:
+ image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
+ volumes:
+ - /data/scratch/test_data_colossalqa:/data/scratch/test_data_colossalqa
+ - /data/scratch/llama-tiny:/data/scratch/llama-tiny
+ options: --gpus all --rm
+ timeout-minutes: 30
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - name: Checkout ColossalAI
+ uses: actions/checkout@v2
+
+ - name: Install colossalqa
+ run: |
+ cd applications/ColossalQA
+ pip install -e .
+
+ - name: Execute Unit Testing
+ run: |
+ cd applications/ColossalQA
+ pytest tests/
+ env:
+ NCCL_SHM_DISABLE: 1
+ MAX_JOBS: 8
+ ZH_MODEL_PATH: bigscience/bloom-560m
+ ZH_MODEL_NAME: bloom
+ EN_MODEL_PATH: bigscience/bloom-560m
+ EN_MODEL_NAME: bloom
+ TEST_DATA_PATH_EN: /data/scratch/test_data_colossalqa/companies.txt
+ TEST_DATA_PATH_ZH: /data/scratch/test_data_colossalqa/companies_zh.txt
+ TEST_DOCUMENT_LOADER_DATA_PATH: /data/scratch/test_data_colossalqa/tests/*
+ SQL_FILE_PATH: /data/scratch/test_data_colossalqa/sql_file_path
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 2f1c34298a50..a4f30caa4cd1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,7 +1,3 @@
-[submodule "inference"]
- path = inference
- url = https://github.com/hpcaitech/EnergonAI.git
- branch = main
[submodule "examples/tutorial/fastfold/FastFold"]
path = examples/tutorial/fastfold/FastFold
url = https://github.com/hpcaitech/FastFold
diff --git a/LICENSE b/LICENSE
index b3eb43520a6f..bacb03e72246 100644
--- a/LICENSE
+++ b/LICENSE
@@ -527,3 +527,28 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved.
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+
+
+ ---------------- LICENSE FOR LangChain TEAM ----------------
+
+ The MIT License
+
+ Copyright (c) Harrison Chase
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in
+ all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ THE SOFTWARE.
diff --git a/applications/Chat/examples/community/peft/train_peft_prompts.py b/applications/Chat/examples/community/peft/train_peft_prompts.py
index 99a024f1463c..1dd9ffcdf1cd 100644
--- a/applications/Chat/examples/community/peft/train_peft_prompts.py
+++ b/applications/Chat/examples/community/peft/train_peft_prompts.py
@@ -118,7 +118,7 @@ def main(args):
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
diff --git a/applications/Chat/examples/community/peft/train_peft_sft.py b/applications/Chat/examples/community/peft/train_peft_sft.py
index 3bbef7208374..6d395deadd0e 100644
--- a/applications/Chat/examples/community/peft/train_peft_sft.py
+++ b/applications/Chat/examples/community/peft/train_peft_sft.py
@@ -68,7 +68,7 @@ def train(args):
padding_side="right",
use_fast=False,
)
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
diff --git a/applications/Chat/examples/inference.py b/applications/Chat/examples/inference.py
index 62e06bf7b3bb..9df8649d9c61 100644
--- a/applications/Chat/examples/inference.py
+++ b/applications/Chat/examples/inference.py
@@ -39,7 +39,7 @@ def eval(args):
tokenizer.pad_token = tokenizer.eos_token
elif args.model == "llama":
tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index 8868e278d85e..40e06043ab57 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -125,7 +125,7 @@ def main(args):
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
diff --git a/applications/Chat/examples/train_reward_model.py b/applications/Chat/examples/train_reward_model.py
index df6e8b6bdc26..fcdd29b2954b 100644
--- a/applications/Chat/examples/train_reward_model.py
+++ b/applications/Chat/examples/train_reward_model.py
@@ -72,7 +72,7 @@ def train(args):
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
else:
raise ValueError(f'Unsupported model "{args.model}"')
diff --git a/applications/Chat/examples/train_sft.py b/applications/Chat/examples/train_sft.py
index 66d08da30120..d00c04809a2d 100644
--- a/applications/Chat/examples/train_sft.py
+++ b/applications/Chat/examples/train_sft.py
@@ -75,7 +75,7 @@ def train(args):
tokenizer = LlamaTokenizer.from_pretrained(
"hf-internal-testing/llama-tokenizer" if args.tokenizer is None else args.tokenizer
)
- tokenizer.eos_token = "<\s>"
+        tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.unk_token
elif args.model == "chatglm":
tokenizer = ChatGLMTokenizer.from_pretrained(
diff --git a/applications/ColossalEval/colossal_eval/dataset/__init__.py b/applications/ColossalEval/colossal_eval/dataset/__init__.py
index 4ea173198f5a..5b029e2673b1 100644
--- a/applications/ColossalEval/colossal_eval/dataset/__init__.py
+++ b/applications/ColossalEval/colossal_eval/dataset/__init__.py
@@ -6,6 +6,7 @@
from .gaokaobench import GaoKaoBenchDataset
from .longbench import LongBenchDataset
from .mmlu import MMLUDataset
+from .mtbench import MTBenchDataset
__all__ = [
"AGIEvalDataset",
@@ -16,4 +17,5 @@
"LongBenchDataset",
"MMLUDataset",
"ColossalDataset",
+ "MTBenchDataset",
]
diff --git a/applications/ColossalEval/colossal_eval/dataset/mtbench.py b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
new file mode 100644
index 000000000000..9e74a4d826e3
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/dataset/mtbench.py
@@ -0,0 +1,72 @@
+import copy
+import json
+import os
+from collections import defaultdict
+from typing import Dict, List
+
+from colossal_eval.utils import get_json_list
+
+from colossalai.logging import DistributedLogger
+
+from .base import BaseDataset
+
+default_inference_kwargs = {
+ "calculate_loss": False,
+ "all_classes": None,
+ "language": "English",
+ "pretrain": False,
+ "max_new_tokens": 1024,
+ "turns": 2,
+}
+
+
+class MTBenchDataset(BaseDataset):
+ """
+ Dataset class for mt_bench dataset.
+ Data source: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl
+ This dataset class will convert the original dataset into the inference dataset.
+ """
+
+ def __init__(self, path, logger, few_shot):
+ self.multiturn = True
+ self.dataset = self.load(path, logger, few_shot)
+
+ @staticmethod
+    def load(path: str, logger: DistributedLogger, few_shot: bool) -> Dict:
+ dataset = {"test": defaultdict(dict)}
+
+ file_path = os.path.join(path, "question.jsonl")
+ ref_path = os.path.join(path, "reference_answer/gpt-4.jsonl")
+
+ reference = defaultdict(list)
+ ref_origin = get_json_list(ref_path)
+ for ref in ref_origin:
+ reference[ref["question_id"]] = ref["choices"][0]["turns"]
+
+ with open(file_path, "r", encoding="utf-8") as file:
+ for line in file:
+ question = json.loads(line)
+ category = question["category"]
+ turn_number = len(question["turns"])
+ data_point = {
+ "id": question["question_id"],
+ "dataset": "mtbench",
+ "split": "test",
+ "category": category,
+ "instruction": question["turns"],
+ "input": "",
+ "output": [],
+ "target": [""] * turn_number
+ if question["question_id"] not in reference
+ else reference[question["question_id"]],
+ }
+
+ if category in dataset["test"]:
+ dataset["test"][category]["data"].append(data_point)
+ else:
+ dataset["test"][category] = {
+ "data": [data_point],
+ "inference_kwargs": copy.deepcopy(default_inference_kwargs),
+ }
+
+ return dataset
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
index c70988707a15..57ccd1aa6a1e 100644
--- a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/dataset_evaluator.py
@@ -1,12 +1,15 @@
+import os
from typing import Dict, List
import colossal_eval.evaluate.dataset_evaluator.metrics as metric_helper
+import colossal_eval.evaluate.dataset_evaluator.gpt_judge as gpt_helper
import numpy as np
import tqdm
+from colossal_eval.utils import jdump
LabelBasedMetrics = ["first_token_accuracy", "matthews_correlation"]
LossBasedMetrics = ["perplexity", "ppl_score", "ppl_score_over_choices", "per_byte_perplexity", "per_byte_ppl_score"]
CombinedMetrics = ["combined_single_choice_accuracy"]
+GPTMetrics = ["mtbench_single_judge"]
OtherMetrics = [
"f1_score",
"f1_zh_score",
@@ -29,8 +32,9 @@ class DatasetEvaluator(object):
"""
- def __init__(self):
- pass
+ def __init__(self, config_path: str, save_path: str):
+ self.config_path = config_path
+ self.save_path = save_path
def _calculate_label_metrics(self, metric: str, category: str):
"""Calculate label-based metrics."""
@@ -60,6 +64,11 @@ def _calculate_label_metrics(self, metric: str, category: str):
sample["output"], ref, all_classes=self.data[category]["inference_kwargs"]["all_classes"]
),
)
+
+ score = max(
+ score,
+ metric_helper.accuracy_by_options(sample["input"], sample["output"], ref),
+ )
softmaxs.append(references[i] if score == 1 else -1)
else:
softmaxs.append(np.argmax(np.array(list(sample["softmax_over_choices"].values()))))
@@ -151,6 +160,24 @@ def _calculate_other_metrics(self, metric: str, category: str):
self.evaluation_results[metric][category] = (total_score, len(self.data[category]["data"]))
self.evaluation_results[metric]["ALL"] += total_score * weight
+ def _calculate_gpt_metrics(self, metric: str, category: str):
+ """Calculate gpt metrics."""
+ weight = len(self.data[category]["data"]) / self.metric_total_length[metric]
+
+ metric_method = eval("gpt_helper." + metric)
+
+ judgements, avg_ratings = metric_method(self.data[category]["data"], self.config_path)
+ self.judgements[category] = judgements
+
+ self.evaluation_results[metric][category] = (np.mean(avg_ratings), len(self.data[category]["data"]))
+ self.evaluation_results[metric]["ALL"] += np.mean(avg_ratings) * weight
+
+ for i in range(avg_ratings.shape[0]):
+ if f"{metric}_{i+1}" not in self.evaluation_results:
+ self.evaluation_results[f"{metric}_{i+1}"] = {cat: 0 for cat in (["ALL"] + self.categories)}
+ self.evaluation_results[f"{metric}_{i+1}"][category] = (avg_ratings[i], len(self.data[category]["data"]))
+ self.evaluation_results[f"{metric}_{i+1}"]["ALL"] += avg_ratings[i] * weight
+
def _calculate_loss_metrics(self, metric: str, category: str):
"""Calculate perplexity."""
if metric == "perplexity":
@@ -212,10 +239,20 @@ def _evaluate(self):
for category in self.suggested_categories[metric]:
self._calculate_combined_metrics(metric, category)
pbar.update(1)
+ elif metric in GPTMetrics:
+ for category in self.suggested_categories[metric]:
+ self._calculate_gpt_metrics(metric, category)
+ pbar.update(1)
elif metric in OtherMetrics:
for category in self.suggested_categories[metric]:
self._calculate_other_metrics(metric, category)
pbar.update(1)
+ else:
+ raise Exception(f"{metric} not supported.")
+
+ if self.judgements:
+ judgement_path = os.path.join(self.save_path, f"{self.model_name}_judgements.json")
+ jdump(self.judgements, judgement_path)
return self.evaluation_results
@@ -235,6 +272,7 @@ def get_evaluation_results(self, data: List[Dict], dataset_name: str, model_name
self.model_name = model_name
self.categories = list(data.keys())
self.metrics = metrics
+ self.judgements = {}
self.evaluation_results = {
metric: {category: 0 for category in (["ALL"] + self.categories)} for metric in self.metrics
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py
new file mode 100644
index 000000000000..cd41dd7fdff0
--- /dev/null
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/gpt_judge.py
@@ -0,0 +1,151 @@
+# Code adapted from https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge
+
+import ast
+import concurrent.futures
+import copy
+import json
+import os
+import re
+import time
+from typing import Any, Dict, List
+
+import numpy as np
+import openai
+import tqdm
+
+MODEL = "gpt-4"
+
+API_MAX_RETRY = 16
+API_RETRY_SLEEP = 10
+API_ERROR_OUTPUT = "$ERROR$"
+
+NEED_REF_CATS = ["math", "reasoning", "coding"]
+
+one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
+one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")
+
+
+def load_mt_prompts(prompt_file: str):
+ prompts = {}
+ with open(prompt_file) as fin:
+ for line in fin:
+ line = json.loads(line)
+ prompts[line["name"]] = line
+ return prompts
+
+
+def get_mt_prompt(prompts: Dict[str, str], multiturn: bool, math: bool):
+ if math and multiturn:
+ return prompts["single-math-v1-multi-turn"]
+ elif math and not multiturn:
+ return prompts["single-math-v1"]
+ elif not math and multiturn:
+ return prompts["single-v1-multi-turn"]
+ elif not math and not multiturn:
+ return prompts["single-v1"]
+
+
+def chat_compeletion_openai(messages: List[Dict], temperature: float = 0.0, max_tokens: int = 2048):
+ output = API_ERROR_OUTPUT
+ model = MODEL
+ for _ in range(API_MAX_RETRY):
+ try:
+ response = openai.ChatCompletion.create(
+ model=model,
+ messages=messages,
+ n=1,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ output = response["choices"][0]["message"]["content"]
+ break
+ except openai.error.OpenAIError as e:
+ print(type(e), e)
+ time.sleep(API_RETRY_SLEEP)
+
+ return output
+
+
+def get_mtbench_judgements(question: Dict[str, Any], prompts: Dict[str, str]):
+ id = question["id"]
+ judgement = {"id": id, "judgements": [], "ratings": []}
+ category = question["category"]
+ math = category in NEED_REF_CATS
+ turn_number = len(question["instruction"])
+
+ for num in range(turn_number):
+ assert (len(question["target"]) >= 1 and math) or not math
+ kwargs = {}
+ if num >= 1:
+ prompt = get_mt_prompt(prompts, multiturn=True, math=math)
+ if len(question["target"]) >= 1 and math:
+ kwargs = {f"ref_answer_{i+1}": question["target"][i] for i in range(len(question["target"]))}
+ user_prompt = prompt["prompt_template"].format(
+ question_1=question["instruction"][0],
+ question_2=question["instruction"][1],
+ answer_1=question["output"][0],
+ answer_2=question["output"][1],
+ **kwargs,
+ )
+ else:
+ prompt = get_mt_prompt(prompts, multiturn=False, math=math)
+ if len(question["target"]) >= 1 and math:
+ kwargs = {"ref_answer_1": question["target"][0]}
+ user_prompt = prompt["prompt_template"].format(
+ question=question["instruction"][0],
+ answer=question["output"][0],
+ **kwargs,
+ )
+
+ rating = -1
+ sys_prompt = prompt["system_prompt"]
+ messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_prompt}]
+
+ judgement_str = chat_compeletion_openai(messages, temperature=0.0, max_tokens=2048)
+ match = re.search(one_score_pattern, judgement_str)
+ if not match:
+ match = re.search(one_score_pattern_backup, judgement_str)
+ if match:
+ rating = ast.literal_eval(match.groups()[0])
+ else:
+ rating = -1
+
+ judgement["judgements"].append(judgement_str)
+ judgement["ratings"].append(rating)
+
+ return judgement
+
+
+def mtbench_single_judge(data: List[Dict], config_path: str):
+ judgements = []
+
+ prompt_dir = os.path.dirname(config_path)
+ prompts = load_mt_prompts(os.path.join(prompt_dir, "mtbench_judge_prompts.jsonl"))
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ futures = []
+ for i, question in enumerate(data):
+ future = executor.submit(get_mtbench_judgements, question, prompts)
+ futures.append(future)
+
+ for future in tqdm.tqdm(
+ concurrent.futures.as_completed(futures),
+ desc=f"MTBench single judge for {data[0]['category']}",
+ total=len(futures),
+ ):
+ judgements.append(future.result())
+
+ judgements.sort(key=lambda x: x["id"])
+
+ judgements_by_id = {j["id"]: j for j in judgements}
+
+ data_to_dump = copy.deepcopy(data)
+
+ for d in data_to_dump:
+ id = d["id"]
+ d["judgements"] = judgements_by_id[id]["judgements"]
+ d["ratings"] = judgements_by_id[id]["ratings"]
+
+ avg_ratings = np.mean([j["ratings"] for j in judgements], axis=0)
+
+ return data_to_dump, avg_ratings
diff --git a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
index 914465478dec..eae35bb9bb85 100644
--- a/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
+++ b/applications/ColossalEval/colossal_eval/evaluate/dataset_evaluator/metrics.py
@@ -185,6 +185,7 @@
"ppl_score_over_choices": ["ALL"],
"ppl_score": ["ALL"],
},
+ "mtbench": {"mtbench_single_judge": ["ALL"]},
}
@@ -443,6 +444,20 @@ def multi_choice_accuracy(prediction, reference, **kwargs):
return score
+def accuracy_by_options(question, prediction, reference):
+ pattern = r"[A-Z]\. [^\n]+"
+ options = re.findall(pattern, question)
+ answer = prediction.split("\n\n")[0]
+
+ for option in options:
+ choice, content = option.split(". ", 1)
+
+ if choice == reference and content == answer:
+ return 1
+
+ return 0
+
+
def combined_single_choice_accuracy(prediction, reference, **kwargs):
return single_choice_accuracy(prediction, reference, **kwargs)
diff --git a/applications/ColossalEval/colossal_eval/models/huggingface.py b/applications/ColossalEval/colossal_eval/models/huggingface.py
index 9f785a6aa9d1..693e021533bc 100644
--- a/applications/ColossalEval/colossal_eval/models/huggingface.py
+++ b/applications/ColossalEval/colossal_eval/models/huggingface.py
@@ -96,7 +96,7 @@ def _load_tokenizer(self, path: str, tokenizer_path: Optional[str], tokenizer_kw
self.logger.warning("pad_token_id is not set for the tokenizer. " "Using eos_token_id as pad_token_id.")
if self.tokenizer.eos_token:
self.tokenizer.pad_token = self.tokenizer.eos_token
- elif self.tokenizer.eod_id:
+ elif hasattr(self.tokenizer, "eod_id"):
# Qwen has an eod token "<|endoftext|>".
self.tokenizer.pad_token_id = self.tokenizer.eod_id
@@ -333,9 +333,12 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b
self.str_label_map = {choice: idx for idx, choice in enumerate(self.choices)}
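+        # Multiturn datasets store "output" as a list of previous answers; derive the
+        # index of the turn currently being generated for the progress-bar label.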
+ turn = 0 if not isinstance(data[0]["output"], list) else len(data[0]["output"]) + 1
+ turn_desc = "" if turn == 0 else f"-turn{turn}"
+
bar = tqdm(
range(math.ceil(len(data) / self.batch_size)),
- desc=f"{data[0]['dataset']}-{data[0]['category']} Inference steps",
+ desc=f"{data[0]['dataset']}-{data[0]['category']}{turn_desc} Inference steps",
disable=not is_rank_0(),
)
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@@ -384,7 +387,10 @@ def inference(self, data: List[Dict], inference_kwargs: Dict[str, Any], debug: b
for j in range(len(batch_prompt)):
if not pretrain:
- answers[i + j]["output"] = batch_decodes[j].strip()
+ if isinstance(answers[i + j]["output"], list):
+ answers[i + j]["output"].append(batch_decodes[j].strip())
+ else:
+ answers[i + j]["output"] = batch_decodes[j].strip()
if isinstance(scores, torch.Tensor):
answers[i + j]["softmax_over_choices"] = probs[j]
diff --git a/applications/ColossalEval/colossal_eval/utils/conversation.py b/applications/ColossalEval/colossal_eval/utils/conversation.py
index 6c096a8523c0..54ea212466d4 100644
--- a/applications/ColossalEval/colossal_eval/utils/conversation.py
+++ b/applications/ColossalEval/colossal_eval/utils/conversation.py
@@ -171,6 +171,9 @@ def get_batch_prompt(
for b in batch:
few_shot_prefix = ""
if few_shot_data is not None:
+            assert not isinstance(
+                b["instruction"], list
+            ), f"When performing few-shot, {b['dataset']} shouldn't be a multiturn dataset."
# For few-shot, only need input. Otherwise use instruction (in AGIEval).
query_text = b["input"] if b.get("input", "") != "" else b["instruction"]
@@ -181,11 +184,24 @@ def get_batch_prompt(
raise Exception("When using few-shot, target answer should be a string.")
few_shot_prefix = get_few_shot_prefix(conv, few_shot_data, tokenizer, language, max_tokens)
- else:
- query_text = b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
- conv.append_message(conv.roles[0], few_shot_prefix + query_text)
- conv.append_message(conv.roles[1], None)
+ conv.append_message(conv.roles[0], few_shot_prefix + query_text)
+ conv.append_message(conv.roles[1], None)
+ else:
+ if not isinstance(b["instruction"], list):
+ query_text = (
+ b["instruction"] + "\n\n" + b["input"] if b.get("input", "") != "" else b["instruction"]
+ )
+ conv.append_message(conv.roles[0], query_text)
+ conv.append_message(conv.roles[1], None)
+ else:
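+                # Multiturn dataset: replay the finished turns, then append the next
+                # unanswered question for the model to respond to.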
+ assert len(b["instruction"]) >= len(b["output"]) + 1
+ cur_turns = len(b["output"])
+ for turn in range(cur_turns):
+ conv.append_message(conv.roles[0], b["instruction"][turn])
+ conv.append_message(conv.roles[1], b["output"][turn])
+ conv.append_message(conv.roles[0], b["instruction"][cur_turns])
+ conv.append_message(conv.roles[1], None)
batch_prompt.append(conv.get_prompt())
diff --git a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
index ec81cf0cef71..5724c6e40693 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/eval_dataset.py
@@ -11,7 +11,7 @@ def main(args):
evaluation_results = {dataset["name"]: {} for dataset in config["dataset"]}
evaluation_results_table = {dataset["name"]: {} for dataset in config["dataset"]}
- evaluator = DatasetEvaluator()
+ evaluator = DatasetEvaluator(args.config, args.evaluation_results_save_path)
for dataset_parameter in config["dataset"]:
dataset_name = dataset_parameter["name"]
@@ -26,6 +26,8 @@ def main(args):
results = evaluator.get_evaluation_results(data, dataset_name, model_name, metrics)
for metric, score in results.items():
+ if metric not in results_metric_model:
+ results_metric_model[metric] = {model["name"]: None for model in config["model"]}
results_metric_model[metric][model_name] = score["ALL"]
evaluation_results[dataset_name][model_name] = results
diff --git a/applications/ColossalEval/examples/dataset_evaluation/inference.py b/applications/ColossalEval/examples/dataset_evaluation/inference.py
index 657fc33bf1ef..b3579424ae1c 100644
--- a/applications/ColossalEval/examples/dataset_evaluation/inference.py
+++ b/applications/ColossalEval/examples/dataset_evaluation/inference.py
@@ -71,6 +71,7 @@ def main(args):
inference_data = {}
debug_args = {}
few_shot_args = {}
+ multiturn_args = {}
config = utils.jload(args.config)
@@ -102,6 +103,13 @@ def main(args):
dataset_ = dataset_class(path, logger, dataset_parameter["few_shot"])
dataset_.save(save_path)
+
+ if hasattr(dataset_, "multiturn") and dataset_.multiturn:
+ multiturn_args[dataset_name] = True
+ logger.info(f"{dataset_parameter['dataset_class']} is a multiturn dataset.")
+ else:
+ multiturn_args[dataset_name] = False
+
inference_data[dataset_name] = dataset_.dataset["test"]
for model_parameter in model_parameters:
@@ -117,7 +125,10 @@ def main(args):
for dataset_name, split_data in inference_data.items():
start = 0
+ prev_questions = None
for category, category_data in split_data.items():
+ num_turn = category_data["inference_kwargs"].get("turns", 1)
+
if few_shot_args[dataset_name] and category_data["inference_kwargs"].get("few_shot_data", None) is None:
raise Exception(f"Dataset {dataset_name} doesn't have few-shot data for category {category}!")
@@ -132,11 +143,16 @@ def main(args):
start = (start + redundant) % world_size
- questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
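+            # Run inference turn by turn: after the first turn, feed the previous
+            # turn's answers back in as context for the next question.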
+ for turn in range(num_turn):
+ if turn == 0:
+ questions = category_data["data"][sum(lengths[0:rank]) : sum(lengths[0:rank]) + lengths[rank]]
+ else:
+ questions = prev_questions
- answers_per_rank = model_.inference(
- questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
- )
+ answers_per_rank = model_.inference(
+ questions, inference_kwargs=category_data["inference_kwargs"], debug=debug_args[dataset_name]
+ )
+ prev_questions = answers_per_rank
answers_to_dump["data"] = answers_per_rank
diff --git a/applications/ColossalQA/.gitignore b/applications/ColossalQA/.gitignore
new file mode 100644
index 000000000000..5f5e159a22fc
--- /dev/null
+++ b/applications/ColossalQA/.gitignore
@@ -0,0 +1,152 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+docs/.build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# IDE
+.idea/
+.vscode/
+
+# macos
+*.DS_Store
+#data/
+
+docs/.build
+
+# pytorch checkpoint
+*.pt
+
+# sql
+*.db
+
+# wandb log
+example/wandb/
+example/ui/gradio/
+example/vector_db_for_test
+examples/awesome-chatgpt-prompts/
diff --git a/applications/ColossalQA/README.md b/applications/ColossalQA/README.md
new file mode 100644
index 000000000000..d9ffe5beb605
--- /dev/null
+++ b/applications/ColossalQA/README.md
@@ -0,0 +1,258 @@
+# ColossalQA - LangChain-based Document Retrieval Conversation System
+
+## Table of Contents
+
+- [Table of Contents](#table-of-contents)
+- [Overall Implementation](#overall-implementation)
+- [Install](#install)
+- [How to Use](#how-to-use)
+- Examples
+ - [A Simple Web UI Demo](examples/webui_demo/README.md)
+ - [Local Chinese Retrieval QA + Chat](examples/retrieval_conversation_zh.py)
+ - [Local English Retrieval QA + Chat](examples/retrieval_conversation_en.py)
+ - [Local Bi-lingual Retrieval QA + Chat](examples/retrieval_conversation_universal.py)
+  - [Experimental AI Agent Based on ChatGPT + Chat](examples/conversation_agent_chatgpt.py)
+- Use cases
+ - [English customer service chatbot](examples/retrieval_conversation_en_customer_service.py)
+ - [Chinese customer service intent classification](examples/retrieval_intent_classification_zh_customer_service.py)
+
+**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.**
+
+## Overall Implementation
+
+### High-level Design
+
+Fig.1. Design of the document retrieval conversation system
+
+
+Retrieval-based Question Answering (QA) is a crucial application of natural language processing that aims to answer user queries with the most relevant information drawn from a corpus of text documents. Vector stores, which represent documents and queries as vectors in a high-dimensional space, have gained popularity for their effectiveness in retrieval QA tasks.
+
+#### Step 1: Collect Data
+
+A successful retrieval QA system starts with high-quality data. You need a collection of text documents relevant to your application, and you may also need to design how your data will be presented to the language model.
+
+#### Step 2: Split Data
+
+Document data is usually too long to fit into the prompt due to the context-length limitation of LLMs, so supporting documents need to be split into short chunks before constructing the vector store. In this demo, we use a neural text splitter for better performance. A minimal sketch of this step is shown below.
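+
+The sketch assumes LangChain's generic `RecursiveCharacterTextSplitter`; the neural text splitter used in the demo is a drop-in replacement, and the chunk sizes here are illustrative only.
+```python
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+documents = [Document(page_content="A long supporting document ... " * 100)]
+# Split long documents into overlapping chunks that fit the LLM's context.
+splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+chunks = splitter.split_documents(documents)
+```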
+
+#### Step 3: Construct Vector Stores
+
+Choose an embedding function and embed your text chunks into high-dimensional vectors. Once you have vectors for your documents, you need to create a vector store. The vector store should efficiently index and retrieve documents based on vector similarity. In this demo, we use [Chroma](https://python.langchain.com/docs/integrations/vectorstores/chroma) and incrementally update the indexes of the vector store. Through incremental update, one can update and maintain a vector store without recalculating every embedding.
+You are free to choose any vector store from the variety of [vector stores](https://python.langchain.com/docs/integrations/vectorstores/) supported by LangChain. However, incremental update only works with LangChain vector stores that support the following (a minimal sketch follows the list):
+- Document addition by id (add_documents method with ids argument)
+- Delete by id (delete method with ids argument)
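+
+The sketch below continues from the chunks produced in Step 2; the embedding model, persist directory, and ids are illustrative assumptions rather than fixed choices of this project.
+```python
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import Chroma
+
+embedding = HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
+vectorstore = Chroma(embedding_function=embedding, persist_directory="./vector_db")
+
+# Incremental update: add and delete by id so unchanged embeddings are reused.
+ids = [f"doc-{i}" for i in range(len(chunks))]
+vectorstore.add_documents(chunks, ids=ids)
+vectorstore.delete(ids=["doc-0"])  # drop a stale chunk by id
+```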
+
+#### Step 4: Retrieve Relevant Text
+
+Upon querying, we first run reference resolution on the user's input; the goal of this step is to resolve ambiguous references in the query such as "this company" or "him". We then embed the query with the same embedding function and query the vector store to retrieve the top-k most similar documents, as sketched below.
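+
+Continuing the sketch, retrieval itself is a top-k similarity query (reference resolution is handled by the demo's own chains and is omitted here):
+```python
+query = "What products does this company sell?"
+# Retrieve the k chunks most similar to the resolved query.
+docs = vectorstore.similarity_search(query, k=4)
+```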
+
+#### Step 5: Format Prompt
+
+The prompt carries essential information, including the task description, conversation history, retrieved documents, and the user's query, for the LLM to generate a response. Please refer to this [README](./colossalqa/prompt/README.md) for more details.
+
+#### Step 6: Inference
+
+Pass the prompt to the LLM with additional generation arguments to get the agent's response. You can control generation with arguments such as temperature, top_k, top_p, and max_new_tokens, and you can define when to stop by passing stop substrings to the retrieval QA chain. A sketch of Steps 5-6 follows.
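+
+A minimal sketch using the `RetrievalQA` chain shipped in this repo (`colossalqa.chain.retrieval_qa.base`), which forwards generation arguments such as `stop`, `temperature`, and `max_new_tokens` from the chain inputs to the LLM; the query, stop string, and values are illustrative.
+```python
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+
+qa_chain = RetrievalQA.from_chain_type(
+    llm=llm,  # any LLM loaded as shown in the "Supported Language Models" section
+    chain_type="stuff",
+    retriever=vectorstore.as_retriever(),
+)
+result = qa_chain(
+    {
+        "query": "What products does this company sell?",
+        "stop": ["Human:"],  # stop generation at the next user turn
+        "temperature": 0.6,
+        "max_new_tokens": 256,
+    }
+)
+print(result["result"])
+```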
+
+#### Step 7: Update Memory
+
+We designed a memory module that automatically summarizes over-length conversations to fit the max context length of the LLM. In this step, we update the memory with the newly generated response. To fit into the context length of a given LLM, we summarize the over-length part of the historical conversation and present the rest in a round-based conversation format. Fig.2. shows how the memory is updated. Please refer to this [README](./colossalqa/prompt/README.md) for the dialogue format. A minimal sketch of the memory module follows Fig.2.
+
+Fig.2. Design of the memory module
+
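+A minimal sketch of the summarization memory shipped in this repo (`colossalqa.chain.memory.summary`), assuming `llm` is loaded as in the next section; the example messages are illustrative:
+```python
+from colossalqa.chain.memory.summary import ConversationSummaryMemory
+
+memory = ConversationSummaryMemory(llm=llm)
+# Each saved round is folded into a running summary generated by the LLM.
+memory.save_context({"question": "Who founded the company?"}, {"output": "It was founded by ..."})
+print(memory.load_memory_variables({})["history"])  # the running summary
+```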
+
+### Supported Language Models (LLMs) and Embedding Models
+
+Our platform accommodates two kinds of LLMs: API-accessible and locally run models. For the API-style LLMs, we support ChatGPT, Pangu, and models deployed through the vLLM API Server. For locally operated LLMs, we are compatible with any language model that can be initiated using [`transformers.AutoModel.from_pretrained`](https://huggingface.co/transformers/v3.0.2/model_doc/auto.html#transformers.AutoModel.from_pretrained). However, due to the dependence of retrieval-based QA on the language model's abilities in zero-shot learning, instruction following, and logical reasoning, smaller models are typically not advised. In our local demo, we utilize ChatGLM2 for Chinese and LLaMa2 for English. Modifying the base LLM requires corresponding adjustments to the prompts.
+
+Here is some sample code for loading different types of LLMs.
+
+```python
+# For locally-run LLM
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+api = ColossalAPI('chatglm2', 'path_to_chatglm2_checkpoint')
+llm = ColossalLLM(n=1, api=api)
+
+# For LLMs running on the vLLM API Server
+from colossalqa.local.llm import VllmAPI, VllmLLM
+vllm_api = VllmAPI("Your_vLLM_Host", "Your_vLLM_Port")
+llm = VllmLLM(n=1, api=vllm_api)
+
+# For ChatGPT LLM
+from langchain.llms import OpenAI
+llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
+
+# For Pangu LLM
+# set up your authentication info
+import os
+from colossalqa.local.pangu_llm import Pangu
+os.environ["URL"] = ""
+os.environ["URLNAME"] = ""
+os.environ["PASSWORD"] = ""
+os.environ["DOMAIN_NAME"] = ""
+
+llm = Pangu(id=1)
+llm.set_auth_config()
+```
+
+Regarding embedding models, we support all models that can be loaded via ["langchain.embeddings.HuggingFaceEmbeddings"](https://api.python.langchain.com/en/latest/embeddings/langchain.embeddings.huggingface.HuggingFaceEmbeddings.html). The default embedding model used in this demo is ["moka-ai/m3e-base"](https://huggingface.co/moka-ai/m3e-base), which enables consistent text similarity computations in both Chinese and English. A loading sketch is shown below.
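+
+A minimal loading sketch; any sentence-embedding checkpoint on the Hugging Face Hub can be substituted for the model name:
+```python
+from langchain.embeddings import HuggingFaceEmbeddings
+
+embedding = HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
+vector = embedding.embed_query("你好, how can I help you?")  # one space for both languages
+print(len(vector))  # embedding dimension
+```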
+
+In the future, supported LLMs will also include models running on the Colossal-AI inference and serving framework.
+
+## Install
+
+Install colossalqa
+```bash
+# python==3.8.17
+cd ColossalAI/applications/ColossalQA
+pip install -e .
+```
+
+To use the vLLM for providing LLM services via an API, please consult the official guide [here](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#api-server) to start the API server. It's important to set up a new virtual environment for installing vLLM, as there are currently some dependency conflicts between vLLM and ColossalQA when installed on the same machine.
+
+## How to Use
+
+### Collect Your Data
+
+For the ChatGPT-based agent, we support document retrieval and simple SQL search.
+If you want to run the demo locally, we provide a document-retrieval-based conversation system built upon LangChain. It accepts a wide range of documents. After collecting your data, put it under one folder.
+
+Read the comments under ./colossalqa/data_loader for more details regarding supported data formats.
+
+### Run The Script
+
+We provide a simple Web UI demo of ColossalQA, enabling you to upload your files as a knowledge base and interact with them through a chat interface in your browser. More details can be found [here](examples/webui_demo/README.md).
+
+
+We also provide scripts for a Chinese document-retrieval-based conversation system, an English one, a bi-lingual one, and an experimental AI agent with document retrieval and SQL query functionality. The bi-lingual system is a high-level wrapper for the other two classes. We write different scripts for different languages because retrieval QA requires different embedding models, LLMs, and prompts for each language setting. For now, we use LLaMA2 for English retrieval QA and ChatGLM2 for Chinese retrieval QA for better performance.
+
+To run the bi-lingual script:
+```bash
+python retrieval_conversation_universal.py \
+ --en_model_path /path/to/Llama-2-7b-hf \
+ --zh_model_path /path/to/chatglm2-6b \
+ --zh_model_name chatglm2 \
+ --en_model_name llama \
+ --sql_file_path /path/to/any/folder
+```
+
+To run retrieval_conversation_en.py:
+```bash
+python retrieval_conversation_en.py \
+ --model_path /path/to/Llama-2-7b-hf \
+ --model_name llama \
+ --sql_file_path /path/to/any/folder
+```
+
+To run retrieval_conversation_zh.py:
+```bash
+python retrieval_conversation_zh.py \
+ --model_path /path/to/chatglm2-6b \
+ --model_name chatglm2 \
+ --sql_file_path /path/to/any/folder
+```
+
+To run retrieval_conversation_chatgpt.py:
+```bash
+python retrieval_conversation_chatgpt.py \
+ --open_ai_key_path /path/to/plain/text/openai/key/file \
+ --sql_file_path /path/to/any/folder
+```
+
+To run conversation_agent_chatgpt.py:
+```bash
+python conversation_agent_chatgpt.py \
+ --open_ai_key_path /path/to/plain/text/openai/key/file
+```
+
+After running the script, you will be asked to provide the path to your data during execution. You can also pass a glob path to load multiple files at once; please read this [guide](https://docs.python.org/3/library/glob.html) on how to define a glob path. Follow the instructions and provide all files for your retrieval conversation system, then type "ESC" to finish loading documents. If csv files are provided, please use "," as the delimiter and "\"" as the quotation mark. For json and jsonl files, the default format is:
+```
+{
+ "data":[
+ {"content":"XXX"},
+ {"content":"XXX"}
+ ...
+ ]
+}
+```
+For other formats, please refer to [this document](https://python.langchain.com/docs/modules/data_connection/document_loaders/json) on how to define a schema for data loading. There are no other formatting constraints for document-type files. For table-type files, we use pandas; please refer to [Pandas-Input/Output](https://pandas.pydata.org/pandas-docs/stable/reference/io.html) for file format details.
+
+We also support another key-value mode that utilizes a user-defined key to calculate the embeddings in the vector store. If a query matches a specific key, the value corresponding to that key will be used to generate the prompt. For instance, in the document below, "My coupon isn't working." will be employed during indexing, whereas "Question: My coupon isn't working.\nAnswer: We apologize for ... apply it to?" will appear in the final prompt. This format is typically useful when the task involves carrying on a conversation with readily accessible conversation data, such as customer service or question answering.
+```python
+Document(page_content="My coupon isn't working.", metadata={'is_key_value_mapping': True, 'seq_num': 36, 'source': 'XXX.json', 'value': "Question: My coupon isn't working.\nAnswer:We apologize for the inconvenience. Can you please provide the coupon code and the product name or SKU you're trying to apply it to?"})
+```
+
+For now, we only support the key-value mode for json data files. You can run the script retrieval_conversation_en_customer_service.py with the following command.
+
+```bash
+python retrieval_conversation_en_customer_service.py \
+ --model_path /path/to/Llama-2-7b-hf \
+ --model_name llama \
+ --sql_file_path /path/to/any/folder
+```
+
+## The Plan
+
+- [x] Build document retrieval QA tool
+- [x] Add memory
+- [x] Add demo for AI agent with SQL query
+- [x] Add custom retriever for fast construction and retrieval (with incremental update)
+
+## Reference
+
+```bibtex
+@software{Chase_LangChain_2022,
+  author = {Chase, Harrison},
+  month = oct,
+  title = {{LangChain}},
+  url = {https://github.com/hwchase17/langchain},
+  year = {2022}
+}
+```
+```bibtex
+@inproceedings{DBLP:conf/asru/ZhangCLLW21,
+ author = {Qinglin Zhang and
+ Qian Chen and
+ Yali Li and
+ Jiaqing Liu and
+ Wen Wang},
+ title = {Sequence Model with Self-Adaptive Sliding Window for Efficient Spoken
+ Document Segmentation},
+ booktitle = {{IEEE} Automatic Speech Recognition and Understanding Workshop, {ASRU}
+ 2021, Cartagena, Colombia, December 13-17, 2021},
+ pages = {411--418},
+ publisher = {{IEEE}},
+ year = {2021},
+ url = {https://doi.org/10.1109/ASRU51503.2021.9688078},
+ doi = {10.1109/ASRU51503.2021.9688078},
+ timestamp = {Wed, 09 Feb 2022 09:03:04 +0100},
+ biburl = {https://dblp.org/rec/conf/asru/ZhangCLLW21.bib},
+ bibsource = {dblp computer science bibliography, https://dblp.org}
+}
+```
+```bibtex
+@misc{touvron2023llama,
+ title={Llama 2: Open Foundation and Fine-Tuned Chat Models},
+ author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
+ year={2023},
+ eprint={2307.09288},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
+```
+```bibtex
+@article{zeng2022glm,
+ title={Glm-130b: An open bilingual pre-trained model},
+ author={Zeng, Aohan and Liu, Xiao and Du, Zhengxiao and Wang, Zihan and Lai, Hanyu and Ding, Ming and Yang, Zhuoyi and Xu, Yifan and Zheng, Wendi and Xia, Xiao and others},
+ journal={arXiv preprint arXiv:2210.02414},
+ year={2022}
+}
+```
+```bibtex
+@inproceedings{du2022glm,
+ title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
+ author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
+ booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
+ pages={320--335},
+ year={2022}
+}
+```
diff --git a/colossalai/inference/pipeline/modeling/__init__.py b/applications/ColossalQA/colossalqa/__init__.py
similarity index 100%
rename from colossalai/inference/pipeline/modeling/__init__.py
rename to applications/ColossalQA/colossalqa/__init__.py
diff --git a/examples/inference/serving/test_ci.sh b/applications/ColossalQA/colossalqa/chain/__init__.py
similarity index 100%
rename from examples/inference/serving/test_ci.sh
rename to applications/ColossalQA/colossalqa/chain/__init__.py
diff --git a/applications/ColossalQA/colossalqa/chain/memory/__init__.py b/applications/ColossalQA/colossalqa/chain/memory/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalQA/colossalqa/chain/memory/summary.py b/applications/ColossalQA/colossalqa/chain/memory/summary.py
new file mode 100644
index 000000000000..1d63bbc4a47e
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/chain/memory/summary.py
@@ -0,0 +1,103 @@
+"""
+Custom SummarizerMixin base class and ConversationSummaryMemory class
+
+Modified from Original Source
+
+This code is based on LangChain AI's langchain, which can be found at
+https://github.com/langchain-ai/langchain
+The original code is licensed under the MIT license.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Type
+
+from langchain.chains.llm import LLMChain
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.memory.prompt import SUMMARY_PROMPT
+from langchain.pydantic_v1 import BaseModel, root_validator
+from langchain.schema import BaseChatMessageHistory, BasePromptTemplate
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string
+
+
+class SummarizerMixin(BaseModel):
+ """
+ Mixin for summarizer.
+ """
+
+ human_prefix: str = "Human"
+ ai_prefix: str = "Assistant"
+ llm: BaseLanguageModel
+ prompt: BasePromptTemplate = SUMMARY_PROMPT
+ summary_message_cls: Type[BaseMessage] = SystemMessage
+ llm_kwargs: Dict = {}
+
+ def predict_new_summary(self, messages: List[BaseMessage], existing_summary: str, stop: List = []) -> str:
+ """
+ Recursively summarize a conversation by generating a new summary using
+ the last round of conversation and the existing summary.
+ """
+ new_lines = get_buffer_string(
+ messages,
+ human_prefix=self.human_prefix,
+ ai_prefix=self.ai_prefix,
+ )
+
+ chain = LLMChain(llm=self.llm, prompt=self.prompt, llm_kwargs=self.llm_kwargs)
+ return chain.predict(summary=existing_summary, new_lines=new_lines, stop=stop)
+
+
+class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
+ """Conversation summarizer to chat memory."""
+
+ buffer: str = ""
+ memory_key: str = "history"
+
+ @classmethod
+ def from_messages(
+ cls,
+ llm: BaseLanguageModel,
+ chat_memory: BaseChatMessageHistory,
+ summarize_step: int = 2,
+ **kwargs: Any,
+ ) -> ConversationSummaryMemory:
+ obj = cls(llm=llm, chat_memory=chat_memory, **kwargs)
+ for i in range(0, len(obj.chat_memory.messages), summarize_step):
+ obj.buffer = obj.predict_new_summary(obj.chat_memory.messages[i : i + summarize_step], obj.buffer)
+ return obj
+
+ @property
+ def memory_variables(self) -> List[str]:
+ """Will always return list of memory variables."""
+ return [self.memory_key]
+
+ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+ """Return history buffer."""
+ if self.return_messages:
+ buffer: Any = [self.summary_message_cls(content=self.buffer)]
+ else:
+ buffer = self.buffer
+ return {self.memory_key: buffer}
+
+ @root_validator()
+ def validate_prompt_input_variables(cls, values: Dict) -> Dict:
+ """Validate that prompt input variables are consistent."""
+ prompt_variables = values["prompt"].input_variables
+ expected_keys = {"summary", "new_lines"}
+ if expected_keys != set(prompt_variables):
+ raise ValueError(
+ "Got unexpected prompt input variables. The prompt expects "
+ f"{prompt_variables}, but it should have {expected_keys}."
+ )
+ return values
+
+ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+ """Save context from this conversation to buffer."""
+ super().save_context(inputs, outputs)
+ self.buffer = self.predict_new_summary(self.chat_memory.messages[-2:], self.buffer)
+
+ def clear(self) -> None:
+ """Clear memory contents."""
+ super().clear()
+ self.buffer = ""
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/__init__.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
new file mode 100644
index 000000000000..e80befdaccfa
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/base.py
@@ -0,0 +1,214 @@
+"""
+Chain for question-answering against a vector database.
+
+Modified from Original Source
+
+This code is based on LangChain AI's langchain, which can be found at
+https://github.com/langchain-ai/langchain
+The original code is licensed under the MIT license.
+"""
+from __future__ import annotations
+
+import copy
+import inspect
+from typing import Any, Dict, List, Optional
+
+from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
+from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
+from langchain.callbacks.manager import AsyncCallbackManagerForChainRun, CallbackManagerForChainRun, Callbacks
+from langchain.chains.llm import LLMChain
+from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
+from langchain.chains.retrieval_qa.base import BaseRetrievalQA
+from langchain.prompts import PromptTemplate
+from langchain.pydantic_v1 import Field
+from langchain.schema import BaseRetriever, Document
+from langchain.schema.language_model import BaseLanguageModel
+
+class CustomBaseRetrievalQA(BaseRetrievalQA):
+ """Base class for question-answering chains."""
+
+ @classmethod
+ def from_llm(
+ cls,
+ llm: BaseLanguageModel,
+ prompt: Optional[PromptTemplate] = None,
+ callbacks: Callbacks = None,
+ **kwargs: Any,
+ ) -> BaseRetrievalQA:
+ """Initialize from LLM."""
+ llm_kwargs = kwargs.pop("llm_kwargs", {})
+ _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
+ llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks, llm_kwargs=llm_kwargs)
+ document_prompt = kwargs.get(
+ "document_prompt", PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")
+ )
+ combine_documents_chain = CustomStuffDocumentsChain(
+ llm_chain=llm_chain,
+ document_variable_name="context",
+ document_prompt=document_prompt,
+ callbacks=callbacks,
+ )
+
+ return cls(
+ combine_documents_chain=combine_documents_chain,
+ callbacks=callbacks,
+ **kwargs,
+ )
+
+ @classmethod
+ def from_chain_type(
+ cls,
+ llm: BaseLanguageModel,
+ chain_type: str = "stuff",
+ chain_type_kwargs: Optional[dict] = None,
+ **kwargs: Any,
+ ) -> BaseRetrievalQA:
+ """Load chain from chain type."""
+ llm_kwargs = kwargs.pop("llm_kwargs", {})
+ _chain_type_kwargs = chain_type_kwargs or {}
+ combine_documents_chain = load_qa_chain(llm, chain_type=chain_type, **_chain_type_kwargs, llm_kwargs=llm_kwargs)
+ return cls(combine_documents_chain=combine_documents_chain, **kwargs)
+
+ def _call(
+ self,
+ inputs: Dict[str, Any],
+ run_manager: Optional[CallbackManagerForChainRun] = None,
+ ) -> Dict[str, Any]:
+ """Run get_relevant_text and llm on input query.
+
+ If chain has 'return_source_documents' as 'True', returns
+ the retrieved documents as well under the key 'source_documents'.
+
+ Example:
+ .. code-block:: python
+
+ res = indexqa({'query': 'This is my query'})
+ answer, docs = res['result'], res['source_documents']
+ """
+ _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+ question = inputs[self.input_key]
+ accepts_run_manager = "run_manager" in inspect.signature(self._get_docs).parameters
+ if accepts_run_manager:
+ docs = self._get_docs(question, run_manager=_run_manager)
+ else:
+ docs = self._get_docs(question) # type: ignore[call-arg]
+
+ kwargs = {
+ k: v
+ for k, v in inputs.items()
+ if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
+ }
+ answers = []
+ if self.combine_documents_chain.memory is not None:
+ buffered_history_backup, summarized_history_temp_backup = copy.deepcopy(
+ self.combine_documents_chain.memory.buffered_history
+ ), copy.deepcopy(self.combine_documents_chain.memory.summarized_history_temp)
+ else:
+ buffered_history_backup = None
+ summarized_history_temp_backup = None
+
+ answer = self.combine_documents_chain.run(
+ input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
+ )
+ if summarized_history_temp_backup is not None and buffered_history_backup is not None:
+ (
+ self.combine_documents_chain.memory.buffered_history,
+ self.combine_documents_chain.memory.summarized_history_temp,
+ ) = copy.deepcopy(buffered_history_backup), copy.deepcopy(summarized_history_temp_backup)
+
+ # if rejection_trigger_keywords is not given, return the response from LLM directly
+ rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
+ answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) else None
+ if answer is None:
+ answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+ if self.combine_documents_chain.memory is not None:
+ self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
+
+ if self.return_source_documents:
+ return {self.output_key: answer, "source_documents": docs}
+ else:
+ return {self.output_key: answer}
+
+ async def _acall(
+ self,
+ inputs: Dict[str, Any],
+ run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+ ) -> Dict[str, Any]:
+ """Run get_relevant_text and llm on input query.
+
+ If chain has 'return_source_documents' as 'True', returns
+ the retrieved documents as well under the key 'source_documents'.
+
+ Example:
+ .. code-block:: python
+
+ res = indexqa({'query': 'This is my query'})
+ answer, docs = res['result'], res['source_documents']
+ """
+ _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
+ question = inputs[self.input_key]
+ accepts_run_manager = "run_manager" in inspect.signature(self._aget_docs).parameters
+ if accepts_run_manager:
+ docs = await self._aget_docs(question, run_manager=_run_manager)
+ else:
+ docs = await self._aget_docs(question) # type: ignore[call-arg]
+ kwargs = {
+ k: v
+ for k, v in inputs.items()
+ if k in ["stop", "temperature", "top_k", "top_p", "max_new_tokens", "doc_prefix"]
+ }
+ answer = await self.combine_documents_chain.arun(
+ input_documents=docs, question=question, callbacks=_run_manager.get_child(), **kwargs
+ )
+ # if rejection_trigger_keywords is not given, return the response from LLM directly
+ rejection_trigger_keywrods = inputs.get('rejection_trigger_keywrods', [])
+ answer = answer if all([rej not in answer for rej in rejection_trigger_keywrods]) or len(rejection_trigger_keywrods)==0 else None
+ if answer is None:
+ answer = inputs.get('rejection_answer', "抱歉,根据提供的信息无法回答该问题。")
+        if self.combine_documents_chain.memory is not None:
+            self.combine_documents_chain.memory.save_context({"question": question}, {"output": answer})
+
+ if self.return_source_documents:
+ return {self.output_key: answer, "source_documents": docs}
+ else:
+ return {self.output_key: answer}
+
+
+class RetrievalQA(CustomBaseRetrievalQA):
+ """Chain for question-answering against an index.
+
+ Example:
+ .. code-block:: python
+
+ from langchain.llms import OpenAI
+ from langchain.chains import RetrievalQA
+ from langchain.faiss import FAISS
+ from langchain.vectorstores.base import VectorStoreRetriever
+ retriever = VectorStoreRetriever(vectorstore=FAISS(...))
+ retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)
+
+ """
+
+ retriever: BaseRetriever = Field(exclude=True)
+
+ def _get_docs(
+ self,
+ question: str,
+ *,
+ run_manager: CallbackManagerForChainRun,
+ ) -> List[Document]:
+ """Get docs."""
+ return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())
+
+ async def _aget_docs(
+ self,
+ question: str,
+ *,
+ run_manager: AsyncCallbackManagerForChainRun,
+ ) -> List[Document]:
+ """Get docs."""
+ return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())
+
+ @property
+ def _chain_type(self) -> str:
+ """Return the chain type."""
+ return "retrieval_qa"
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py
new file mode 100644
index 000000000000..a2b1f81e34b9
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/load_chain.py
@@ -0,0 +1,87 @@
+"""
+Load question answering chains.
+For now, only the "stuff" chain is modified.
+
+Modified from Original Source
+
+This code is based on LangChain AI's langchain, which can be found at
+https://github.com/langchain-ai/langchain
+The original code is licensed under the MIT license.
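+
+Usage (a minimal sketch; `llm`, `my_prompt` and `docs` are illustrative placeholders):
+
+chain = load_qa_chain(llm, chain_type="stuff", prompt=my_prompt)
+answer = chain.run(input_documents=docs, question="What is ColossalAI?")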
+"""
+import copy
+from typing import Any, Mapping, Optional, Protocol
+
+from colossalqa.chain.retrieval_qa.stuff import CustomStuffDocumentsChain
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.llm import LLMChain
+from langchain.chains.question_answering import stuff_prompt
+from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.prompt_template import BasePromptTemplate
+
+
+class LoadingCallable(Protocol):
+ """Interface for loading the combine documents chain."""
+
+ def __call__(self, llm: BaseLanguageModel, **kwargs: Any) -> BaseCombineDocumentsChain:
+ """Callable to load the combine documents chain."""
+
+
+def _load_stuff_chain(
+ llm: BaseLanguageModel,
+ prompt: Optional[BasePromptTemplate] = None,
+ document_variable_name: str = "context",
+ verbose: Optional[bool] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ callbacks: Callbacks = None,
+ **kwargs: Any,
+) -> CustomStuffDocumentsChain:
+ _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
+ if "llm_kwargs" in kwargs:
+ llm_kwargs = copy.deepcopy(kwargs["llm_kwargs"])
+ del kwargs["llm_kwargs"]
+ else:
+ llm_kwargs = {}
+ llm_chain = LLMChain(
+ llm=llm,
+ prompt=_prompt,
+ verbose=verbose,
+ callback_manager=callback_manager,
+ callbacks=callbacks,
+ llm_kwargs=llm_kwargs,
+ )
+ return CustomStuffDocumentsChain(
+ llm_chain=llm_chain,
+ document_variable_name=document_variable_name,
+ verbose=verbose,
+ callback_manager=callback_manager,
+ callbacks=callbacks,
+ **kwargs,
+ )
+
+
+def load_qa_chain(
+ llm: BaseLanguageModel,
+ chain_type: str = "stuff",
+ verbose: Optional[bool] = None,
+ callback_manager: Optional[BaseCallbackManager] = None,
+ **kwargs: Any,
+) -> BaseCombineDocumentsChain:
+ """Load question answering chain.
+
+ Args:
+ llm: Language Model to use in the chain.
+ chain_type: Type of document combining chain to use. Currently only "stuff"
+ is supported.
+ verbose: Whether chains should be run in verbose mode or not. Note that this
+ applies to all chains that make up the final chain.
+ callback_manager: Callback manager to use for the chain.
+
+ Returns:
+ A chain to use for question answering.
+ """
+ loader_mapping: Mapping[str, LoadingCallable] = {"stuff": _load_stuff_chain}
+ if chain_type not in loader_mapping:
+ raise ValueError(f"Got unsupported chain type: {chain_type}. " f"Should be one of {loader_mapping.keys()}")
+ return loader_mapping[chain_type](llm, verbose=verbose, callback_manager=callback_manager, **kwargs)
diff --git a/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py
new file mode 100644
index 000000000000..bf7ad0ffce28
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/chain/retrieval_qa/stuff.py
@@ -0,0 +1,91 @@
+"""
+Chain that combines documents by stuffing into context
+
+Modified from Original Source
+
+This code is based on LangChain AI's langchain, which can be found at
+https://github.com/langchain-ai/langchain
+The original code is licensed under the MIT license.
+"""
+import copy
+from typing import Any, List
+
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
+from langchain.docstore.document import Document
+from langchain.schema import format_document
+
+
+class CustomStuffDocumentsChain(StuffDocumentsChain):
+ """Chain that combines documents by stuffing into context.
+
+ This chain takes a list of documents and first combines them into a single string.
+ It does this by formatting each document into a string with the `document_prompt`
+ and then joining them together with `document_separator`. It then adds that new
+ string to the inputs with the variable name set by `document_variable_name`.
+ Those inputs are then passed to the `llm_chain`.
+
+ Example:
+ .. code-block:: python
+
+ from langchain.chains import StuffDocumentsChain, LLMChain
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import OpenAI
+
+ # This controls how each document will be formatted. Specifically,
+ # it will be passed to `format_document` - see that function for more
+ # details.
+ document_prompt = PromptTemplate(
+ input_variables=["page_content"],
+ template="{page_content}"
+ )
+ document_variable_name = "context"
+ llm = OpenAI()
+ # The prompt here should take as an input variable the
+ # `document_variable_name`
+ prompt = PromptTemplate.from_template(
+ "Summarize this content: {context}"
+ )
+ llm_chain = LLMChain(llm=llm, prompt=prompt)
+ chain = StuffDocumentsChain(
+ llm_chain=llm_chain,
+ document_prompt=document_prompt,
+ document_variable_name=document_variable_name
+ )
+ """
+
+ def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
+ """Construct inputs from kwargs and docs.
+
+ Format and then join all the documents together into one input with the name
+ `self.document_variable_name`. Then pluck any additional variables
+ from **kwargs.
+
+ Args:
+ docs: List of documents to format and then join into single input
+ **kwargs: additional inputs to chain, will pluck any other required
+ arguments from here.
+
+ Returns:
+ dictionary of inputs to LLMChain
+ """
+ # Format each document according to the prompt
+
+ # If a document is in key-value format, i.e. has 'is_key_value_mapping'=True and a 'value' field in its metadata,
+ # use the value to replace the page content
+ doc_prefix = kwargs.get("doc_prefix", "Supporting Document")
+ docs_ = []
+ for doc_id, doc in enumerate(docs):
+ doc_ = copy.deepcopy(doc)
+ if doc_.metadata.get("is_key_value_mapping", False) and "value" in doc_.metadata:
+ doc_.page_content = str(doc_.metadata["value"])
+ prefix = doc_prefix + str(doc_id)
+ doc_.page_content = str(prefix + ":" + (" " if doc_.page_content[0] != " " else "") + doc_.page_content)
+ docs_.append(doc_)
+
+ doc_strings = [format_document(doc, self.document_prompt) for doc in docs_]
+ arg_list = ["stop", "temperature", "top_k", "top_p", "max_new_tokens"]
+ arg_list.extend(self.llm_chain.prompt.input_variables)
+ # Join the documents together to put them in the prompt.
+ inputs = {k: v for k, v in kwargs.items() if k in arg_list}
+ inputs[self.document_variable_name] = self.document_separator.join(doc_strings)
+ return inputs
diff --git a/applications/ColossalQA/colossalqa/data_loader/__init__.py b/applications/ColossalQA/colossalqa/data_loader/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalQA/colossalqa/data_loader/document_loader.py b/applications/ColossalQA/colossalqa/data_loader/document_loader.py
new file mode 100644
index 000000000000..0fe1e4d1a00c
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/data_loader/document_loader.py
@@ -0,0 +1,128 @@
+"""
+Class for loading document type data
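+
+Usage (a minimal sketch; the file paths and names are illustrative):
+
+loader = DocumentLoader([["docs/*.txt", "docs"], ["paper.pdf", "paper"]])
+documents = loader.all_data  # a flat list of langchain Document objects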
+"""
+
+import glob
+from typing import List
+
+from colossalqa.mylogging import get_logger
+from langchain.document_loaders import (
+ JSONLoader,
+ PyPDFLoader,
+ TextLoader,
+ UnstructuredHTMLLoader,
+ UnstructuredMarkdownLoader,
+)
+from langchain.document_loaders.csv_loader import CSVLoader
+
+logger = get_logger()
+
+SUPPORTED_DATA_FORMAT = [".csv", ".json", ".html", ".md", ".pdf", ".txt", ".jsonl"]
+
+
+class DocumentLoader:
+ """
+ Load documents from different files into list of langchain Documents
+ """
+
+ def __init__(self, files: List, **kwargs) -> None:
+ """
+ Args:
+ files: list of files (list[file path, name])
+ **kwargs: keyword type arguments, useful for certain document types
+ """
+ self.data = {}
+ self.kwargs = kwargs
+
+ for item in files:
+ path = item[0] if isinstance(item, list) else item
+ logger.info(f"Loading data from {path}")
+ self.load_data(path)
+ logger.info("Data loaded")
+
+ self.all_data = []
+ for key in self.data:
+ if isinstance(self.data[key], list):
+ for item in self.data[key]:
+ if isinstance(item, list):
+ self.all_data.extend(item)
+ else:
+ self.all_data.append(item)
+
+ def load_data(self, path: str) -> None:
+ """
+ Load data. Please refer to https://python.langchain.com/docs/modules/data_connection/document_loaders/
+ for specific format requirements.
+ Args:
+ path: path to a file
+ To load files with a glob path, here are some examples:
+ Load all files from a directory: folder1/folder2/*
+ Load all PDF files from a directory: folder1/folder2/*.pdf
+ """
+ files = []
+
+ # Handle glob expression
+ try:
+ files = glob.glob(path)
+ except Exception as e:
+ logger.error(e)
+ if len(files) == 0:
+ raise ValueError("Unsupported file/directory format. For directories, please use glob expression")
+ elif len(files) == 1:
+ path = files[0]
+ else:
+ for file in files:
+ self.load_data(file)
+ return
+
+ # Load data if the path is a file
+ logger.info(f"load {path}", verbose=True)
+ if path.endswith(".csv"):
+ # Load csv
+ loader = CSVLoader(file_path=path, encoding="utf8")
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith(".txt"):
+ # Load txt
+ loader = TextLoader(path, encoding="utf8")
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith("html"):
+ # Load html
+ loader = UnstructuredHTMLLoader(path, encoding="utf8")
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith("json"):
+ # Load json
+ loader = JSONLoader(
+ file_path=path,
+ jq_schema=self.kwargs.get("jq_schema", ".data[]"),
+ content_key=self.kwargs.get("content_key", "content"),
+ metadata_func=self.kwargs.get("metadata_func", None),
+ )
+
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith("jsonl"):
+ # Load jsonl
+ loader = JSONLoader(
+ file_path=path, jq_schema=self.kwargs.get("jq_schema", ".data[].content"), json_lines=True
+ )
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith(".md"):
+ # Load markdown
+ loader = UnstructuredMarkdownLoader(path)
+ data = loader.load()
+ self.data[path] = data
+ elif path.endswith(".pdf"):
+ # Load pdf
+ loader = PyPDFLoader(path)
+ data = loader.load_and_split()
+ self.data[path] = data
+ else:
+ if "." in path.split("/")[-1]:
+ raise ValueError(f"Unsupported file format {path}. Supported formats: {SUPPORTED_DATA_FORMAT}")
+ else:
+ # May be a directory; we strictly follow the glob path and will not load files in subdirectories
+ pass
diff --git a/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py
new file mode 100644
index 000000000000..cad48254498e
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/data_loader/table_dataloader.py
@@ -0,0 +1,119 @@
+'''
+Class for loading table-type data. Please refer to Pandas Input/Output for file format details.
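+
+Usage (a minimal sketch; the file name and table name are illustrative):
+
+loader = TableLoader([['data.csv', 'sales']], sql_path='sqlite:///mydatabase.db')
+sql_path = loader.get_sql_path()  # hand this to a SQL chain or agent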
+'''
+
+
+import glob
+import os
+from typing import List
+
+import pandas as pd
+from colossalqa.mylogging import get_logger
+from colossalqa.utils import drop_table
+from sqlalchemy import create_engine
+
+logger = get_logger()
+
+SUPPORTED_DATA_FORMAT = ['.csv', '.xlsx', '.xls', '.json', '.html', '.h5', '.hdf5', '.parquet', '.feather', '.dta']
+
+class TableLoader:
+ '''
+ Load tables from different files and serve a sql database for database operations
+ '''
+ def __init__(self, files: List, sql_path: str = 'sqlite:///mydatabase.db', verbose: bool = False, **kwargs) -> None:
+ '''
+ Args:
+ files: list of files (list[file path, name])
+ sql_path: how to serve the sql database
+ **kwargs: keyword type arguments, useful for certain document types
+ '''
+ self.data = {}
+ self.verbose = verbose
+ self.sql_path = sql_path
+ self.kwargs = kwargs
+ self.sql_engine = create_engine(self.sql_path)
+ drop_table(self.sql_engine)
+
+ self.sql_engine = create_engine(self.sql_path)
+ for item in files:
+ path = item[0]
+ dataset_name = item[1]
+ if not os.path.exists(path):
+ raise FileNotFoundError(f"{path} doesn't exist")
+ if not any([path.endswith(i) for i in SUPPORTED_DATA_FORMAT]):
+ raise TypeError(f"{path} not supported. Supported type {SUPPORTED_DATA_FORMAT}")
+
+ logger.info("loading data", verbose=self.verbose)
+ self.load_data(path)
+ logger.info("data loaded", verbose=self.verbose)
+ self.to_sql(path, dataset_name)
+
+ def load_data(self, path):
+ '''
+ Load data and serve the data as sql database.
+ Data must be in pandas format
+ '''
+ files = []
+ # Handle glob expression
+ try:
+ files = glob.glob(path)
+ except Exception as e:
+ logger.error(e)
+ if len(files) == 0:
+ raise ValueError("Unsupported file/directory format. For directories, please use a glob expression")
+ elif len(files) == 1:
+ path = files[0]
+ else:
+ for file in files:
+ self.load_data(file)
+ return
+
+ if path.endswith('.csv'):
+ # Load csv
+ self.data[path] = pd.read_csv(path)
+ elif path.endswith('.xlsx') or path.endswith('.xls'):
+ # Load excel
+ self.data[path] = pd.read_excel(path) # You can adjust the sheet_name as needed
+ elif path.endswith('.json'):
+ # Load json
+ self.data[path] = pd.read_json(path)
+ elif path.endswith('.html'):
+ # Load html
+ html_tables = pd.read_html(path)
+ # Choose the desired table from the list of DataFrame objects
+ self.data[path] = html_tables[0] # You may need to adjust this index
+ elif path.endswith('.h5') or path.endswith('.hdf5'):
+ # Load h5
+ self.data[path] = pd.read_hdf(path, key=self.kwargs.get('key', 'data')) # You can adjust the key as needed
+ elif path.endswith('.parquet'):
+ # Load parquet
+ self.data[path] = pd.read_parquet(path, engine='fastparquet')
+ elif path.endswith('.feather'):
+ # Load feather
+ self.data[path] = pd.read_feather(path)
+ elif path.endswith('.dta'):
+ # Load dta
+ self.data[path] = pd.read_stata(path)
+ else:
+ raise ValueError("Unsupported file format")
+
+ def to_sql(self, path, table_name):
+ '''
+ Serve the data as sql database.
+ '''
+ self.data[path].to_sql(table_name, con=self.sql_engine, if_exists='replace', index=False)
+ logger.info(f"Loaded to Sqlite3\nPath: {path}", verbose=self.verbose)
+ return self.sql_path
+
+ def get_sql_path(self):
+ return self.sql_path
+
+ def __del__(self):
+ if self.sql_engine:
+ drop_table(self.sql_engine)
+ self.sql_engine.dispose()
+ del self.data
+ del self.sql_engine
diff --git a/applications/ColossalQA/colossalqa/local/__init__.py b/applications/ColossalQA/colossalqa/local/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
new file mode 100644
index 000000000000..62aead66c54b
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/local/colossalcloud_llm.py
@@ -0,0 +1,125 @@
+"""
+LLM wrapper for LLMs running on ColossalCloud Platform
+
+Usage:
+
+os.environ['URL'] = ""
+os.environ['HOST'] = ""
+
+gen_config = {
+ 'max_new_tokens': 100,
+ # 'top_k': 2,
+ 'top_p': 0.9,
+ 'temperature': 0.5,
+ 'repetition_penalty': 2,
+ }
+
+llm = ColossalCloudLLM(n=1)
+llm.set_auth_config()
+resp = llm(prompt='What do you call a three-ton kangaroo?', **gen_config)
+print(resp) # super-heavyweight awesome-natured yawning Australian creature!
+
+"""
+import json
+from typing import Any, List, Mapping, Optional
+
+import requests
+from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
+
+
+class ColossalCloudLLM(LLM):
+ """
+ A custom LLM class that integrates LLMs running on the ColossalCloud Platform
+
+ """
+ n: int
+ gen_config: dict = None
+ auth_config: dict = None
+ valid_gen_para: list = ['max_new_tokens', 'top_k',
+ 'top_p', 'temperature', 'repetition_penalty']
+
+ def __init__(self, gen_config=None, **kwargs):
+ """
+ Args:
+ gen_config: config for generation,
+ max_new_tokens: 50 by default
+ top_k: (1, vocab_size)
+ top_p: (0, 1) if not None
+ temperature: (0, inf) if not None
+ repetition_penalty: (1, inf) if not None
+ """
+ super(ColossalCloudLLM, self).__init__(**kwargs)
+ if gen_config is None:
+ self.gen_config = {"max_new_tokens": 50}
+ else:
+ assert "max_new_tokens" in gen_config, "max_new_tokens is a compulsory key in the gen config"
+ self.gen_config = gen_config
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"n": self.n}
+
+ @property
+ def _llm_type(self) -> str:
+ return 'ColossalCloudLLM'
+
+ def set_auth_config(self, **kwargs):
+ url = get_from_dict_or_env(kwargs, "url", "URL")
+ host = get_from_dict_or_env(kwargs, "host", "HOST")
+
+ auth_config = {}
+ auth_config['endpoint'] = url
+ auth_config['Host'] = host
+ self.auth_config = auth_config
+
+ def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
+ """
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered
+
+ Returns:
+ The string generated by the model
+ """
+ # Update the generation arguments
+ for key, value in kwargs.items():
+ if key not in self.valid_gen_para:
+ raise KeyError(f"Invalid generation parameter: '{key}'. Valid keys are: {', '.join(self.valid_gen_para)}")
+ if key in self.gen_config:
+ self.gen_config[key] = value
+
+ resp_text = self.text_completion(prompt, self.gen_config, self.auth_config)
+ # TODO: This may cause an excessive token count
+ if stop is not None:
+ for stopping_words in stop:
+ if stopping_words in resp_text:
+ resp_text = resp_text.split(stopping_words)[0]
+ return resp_text
+
+
+ def text_completion(self, prompt, gen_config, auth_config):
+ # Compulsory parameters (read without popping so repeated calls don't mutate the stored configs)
+ endpoint = auth_config['endpoint']
+ max_new_tokens = gen_config['max_new_tokens']
+ # Optional Parameters
+ optional_params = ['top_k', 'top_p', 'temperature', 'repetition_penalty']
+ gen_config = {key: gen_config[key] for key in optional_params if key in gen_config}
+ # Define the data payload
+ data = {
+ "max_new_tokens": max_new_tokens,
+ "history": [
+ {"instruction": prompt, "response": ""}
+ ],
+ **gen_config
+ }
+ headers = {
+ "Content-Type": "application/json",
+ "Host": auth_config['Host'],
+ }
+ # Make the POST request
+ response = requests.post(endpoint, headers=headers, data=json.dumps(data))
+ response.raise_for_status()  # raise an error if the return code is not 200 (success)
+ # Check the response
+ return response.text
diff --git a/applications/ColossalQA/colossalqa/local/llm.py b/applications/ColossalQA/colossalqa/local/llm.py
new file mode 100644
index 000000000000..ff7346adcf61
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/local/llm.py
@@ -0,0 +1,183 @@
+"""
+API and LLM wrapper classes for running LLMs locally
+
+Usage:
+
+import os
+model_path = os.environ.get("ZH_MODEL_PATH")
+model_name = "chatglm2"
+colossal_api = ColossalAPI(model_name, model_path)
+llm = ColossalLLM(n=1, api=colossal_api)
+TEST_PROMPT_CHATGLM="续写文章:惊蛰一过,春寒加剧。先是料料峭峭,继而雨季开始,"
+logger.info(llm(TEST_PROMPT_CHATGLM, max_new_tokens=100), verbose=True)
+
+"""
+from typing import Any, List, Mapping, Optional
+
+import torch
+from colossalqa.local.utils import get_response, post_http_request
+from colossalqa.mylogging import get_logger
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+logger = get_logger()
+
+
+class ColossalAPI:
+ """
+ API for calling LLM.generate
+ """
+
+ __instances = dict()
+
+ def __init__(self, model_type: str, model_path: str, ckpt_path: str = None) -> None:
+ """
+ Configure the model
+ """
+ if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
+ return
+ else:
+ ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")] = self
+ self.model_type = model_type
+ self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
+
+ if ckpt_path is not None:
+ state_dict = torch.load(ckpt_path)
+ self.model.load_state_dict(state_dict)
+ self.model.to(torch.cuda.current_device())
+
+ # Configure the tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ self.model.eval()
+
+ @staticmethod
+ def get_api(model_type: str, model_path: str, ckpt_path: str = None):
+ if model_type + model_path + (ckpt_path or "") in ColossalAPI.__instances:
+ return ColossalAPI.__instances[model_type + model_path + (ckpt_path or "")]
+ else:
+ return ColossalAPI(model_type, model_path, ckpt_path)
+
+ def generate(self, input: str, **kwargs) -> str:
+ """
+ Generate response given the prompt
+ Args:
+ input: input string
+ **kwargs: language model keyword type arguments, such as top_k, top_p, temperature, max_new_tokens...
+ Returns:
+ output: output string
+ """
+ if self.model_type in ["chatglm", "chatglm2"]:
+ inputs = {
+ k: v.to(torch.cuda.current_device()) for k, v in self.tokenizer(input, return_tensors="pt").items()
+ }
+ else:
+ inputs = {
+ "input_ids": self.tokenizer(input, return_tensors="pt")["input_ids"].to(torch.cuda.current_device())
+ }
+
+ output = self.model.generate(**inputs, **kwargs)
+ output = output.cpu()
+ prompt_len = inputs["input_ids"].size(1)
+ response = output[0, prompt_len:]
+ output = self.tokenizer.decode(response, skip_special_tokens=True)
+ return output
+
+
+class VllmAPI:
+ def __init__(self, host: str = "localhost", port: int = 8077) -> None:
+ # Configure the API for a model served over HTTP
+ self.host = host
+ self.port = port
+ self.url = f"http://{self.host}:{self.port}/generate"
+
+ def generate(self, input: str, **kwargs):
+ output = get_response(post_http_request(input, self.url, **kwargs))[0]
+ return output[len(input) :]
+
+
+class ColossalLLM(LLM):
+ """
+ Langchain LLM wrapper for a local LLM
+ """
+
+ n: int
+ api: Any
+ kwargs = {"max_new_tokens": 100}
+
+ @property
+ def _llm_type(self) -> str:
+ return "custom"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
+ for k in self.kwargs:
+ if k not in kwargs:
+ kwargs[k] = self.kwargs[k]
+
+ generate_args = {k: kwargs[k] for k in kwargs if k not in ["stop", "n"]}
+ out = self.api.generate(prompt, **generate_args)
+ if isinstance(stop, list) and len(stop) != 0:
+ for stopping_words in stop:
+ if stopping_words in out:
+ out = out.split(stopping_words)[0]
+ logger.info(f"{prompt}{out}", verbose=self.verbose)
+ return out
+
+ @property
+ def _identifying_params(self) -> Mapping[str, int]:
+ """Get the identifying parameters."""
+ return {"n": self.n}
+
+
+class VllmLLM(LLM):
+ """
+ Langchain LLM wrapper for a local LLM
+ """
+
+ n: int
+ api: Any
+ kwargs = {"max_new_tokens": 100}
+
+ @property
+ def _llm_type(self) -> str:
+ return "custom"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ for k in self.kwargs:
+ if k not in kwargs:
+ kwargs[k] = self.kwargs[k]
+ logger.info(f"kwargs:{kwargs}\nstop:{stop}\nprompt:{prompt}", verbose=self.verbose)
+ generate_args = {k: kwargs[k] for k in kwargs if k in ["n", "max_tokens", "temperature", "stream"]}
+ out = self.api.generate(prompt, **generate_args)
+ if isinstance(stop, list) and len(stop) != 0:
+ for stopping_words in stop:
+ if stopping_words in out:
+ out = out.split(stopping_words)[0]
+ logger.info(f"{prompt}{out}", verbose=self.verbose)
+ return out
+
+ def set_host_port(self, host: str = "localhost", port: int = 8077, **kwargs) -> None:
+ if "max_tokens" not in kwargs:
+ kwargs["max_tokens"] = 100
+ self.kwargs = kwargs
+ self.api = VllmAPI(host=host, port=port)
+
+ @property
+ def _identifying_params(self) -> Mapping[str, int]:
+ """Get the identifying parameters."""
+ return {"n": self.n}
+
diff --git a/applications/ColossalQA/colossalqa/local/pangu_llm.py b/applications/ColossalQA/colossalqa/local/pangu_llm.py
new file mode 100644
index 000000000000..b8befa92b96f
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/local/pangu_llm.py
@@ -0,0 +1,150 @@
+"""
+LLM wrapper for Pangu
+
+Usage:
+
+# URL: in the Huawei Cloud console, open "盘古大模型套件管理" (Pangu model suite management) -> "服务管理" (service management) -> "模型列表" (model list) -> click "复制路径" (copy path) for the model you want to use
+# USERNAME: in the Huawei Cloud console, under "我的凭证" (my credentials) -> "API凭证" (API credentials), the "IAM用户名" (IAM user name), i.e. the name used to log in to the IAM account
+# PASSWORD: the IAM user's password
+# DOMAIN_NAME: in the Huawei Cloud console, under "我的凭证" (my credentials) -> "API凭证" (API credentials), the "用户名" (account name), i.e. the top-level account that manages the IAM accounts
+
+os.environ["URL"] = ""
+os.environ["URLNAME"] = ""
+os.environ["PASSWORD"] = ""
+os.environ["DOMAIN_NAME"] = ""
+
+pg = Pangu(id=1)
+pg.set_auth_config()
+
+res = pg('你是谁')  # "Who are you?" -> 您好,我是华为盘古大模型。我能够通过和您对话互动为您提供帮助。请问您有什么想问我的吗? ("Hello, I am Huawei's Pangu large model. I can help you through interactive conversation. Is there anything you would like to ask me?")
+"""
+
+import http.client
+import json
+from typing import Any, List, Mapping, Optional
+
+import requests
+from langchain.llms.base import LLM
+from langchain.utils import get_from_dict_or_env
+
+
+class Pangu(LLM):
+ """
+ A custom LLM class that integrates Pangu models
+
+ """
+
+ n: int
+ gen_config: dict = None
+ auth_config: dict = None
+
+ def __init__(self, gen_config=None, **kwargs):
+ super(Pangu, self).__init__(**kwargs)
+ if gen_config is None:
+ self.gen_config = {"user": "User", "max_tokens": 50, "temperature": 0.95, "n": 1}
+ else:
+ self.gen_config = gen_config
+
+ @property
+ def _identifying_params(self) -> Mapping[str, Any]:
+ """Get the identifying parameters."""
+ return {"n": self.n}
+
+ @property
+ def _llm_type(self) -> str:
+ return "pangu"
+
+ def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
+ """
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered
+
+ Returns:
+ The string generated by the model
+ """
+ # Update the generation arguments
+ for key, value in kwargs.items():
+ if key in self.gen_config:
+ self.gen_config[key] = value
+
+ response = self.text_completion(prompt, self.gen_config, self.auth_config)
+ text = response["choices"][0]["text"]
+ if stop is not None:
+ for stopping_words in stop:
+ if stopping_words in text:
+ text = text.split(stopping_words)[0]
+ return text
+
+ def set_auth_config(self, **kwargs):
+ url = get_from_dict_or_env(kwargs, "url", "URL")
+ username = get_from_dict_or_env(kwargs, "username", "USERNAME")
+ password = get_from_dict_or_env(kwargs, "password", "PASSWORD")
+ domain_name = get_from_dict_or_env(kwargs, "domain_name", "DOMAIN_NAME")
+
+ region = url.split(".")[1]
+ auth_config = {}
+ auth_config["endpoint"] = url[url.find("https://") + 8 : url.find(".com") + 4]
+ auth_config["resource_path"] = url[url.find(".com") + 4 :]
+ auth_config["auth_token"] = self.get_latest_auth_token(region, username, password, domain_name)
+ self.auth_config = auth_config
+
+ def get_latest_auth_token(self, region, username, password, domain_name):
+ url = f"https://iam.{region}.myhuaweicloud.com/v3/auth/tokens"
+ payload = json.dumps(
+ {
+ "auth": {
+ "identity": {
+ "methods": ["password"],
+ "password": {"user": {"name": username, "password": password, "domain": {"name": domain_name}}},
+ },
+ "scope": {"project": {"name": region}},
+ }
+ }
+ )
+ headers = {"Content-Type": "application/json"}
+
+ response = requests.request("POST", url, headers=headers, data=payload)
+ return response.headers["X-Subject-Token"]
+
+ def text_completion(self, text, gen_config, auth_config):
+ conn = http.client.HTTPSConnection(auth_config["endpoint"])
+ payload = json.dumps(
+ {
+ "prompt": text,
+ "user": gen_config["user"],
+ "max_tokens": gen_config["max_tokens"],
+ "temperature": gen_config["temperature"],
+ "n": gen_config["n"],
+ }
+ )
+ headers = {
+ "X-Auth-Token": auth_config["auth_token"],
+ "Content-Type": "application/json",
+ }
+ conn.request("POST", auth_config["resource_path"], payload, headers)
+ res = conn.getresponse()
+ data = res.read()
+ data = json.loads(data.decode("utf-8"))
+ return data
+
+ def chat_model(self, messages, gen_config, auth_config):
+ conn = http.client.HTTPSConnection(auth_config["endpoint"])
+ payload = json.dumps(
+ {
+ "messages": messages,
+ "user": gen_config["user"],
+ "max_tokens": gen_config["max_tokens"],
+ "temperature": gen_config["temperature"],
+ "n": gen_config["n"],
+ }
+ )
+ headers = {
+ "X-Auth-Token": auth_config["auth_token"],
+ "Content-Type": "application/json",
+ }
+ conn.request("POST", auth_config["resource_path"], payload, headers)
+ res = conn.getresponse()
+ data = res.read()
+ data = json.loads(data.decode("utf-8"))
+ return data
diff --git a/applications/ColossalQA/colossalqa/local/utils.py b/applications/ColossalQA/colossalqa/local/utils.py
new file mode 100644
index 000000000000..ed90264cad8d
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/local/utils.py
@@ -0,0 +1,29 @@
+"""
+Generation utilities
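+
+Usage (a minimal sketch; assumes a vLLM-style server listening at the URL):
+
+resp = post_http_request("Hello", "http://localhost:8077/generate", max_tokens=50)
+texts = get_response(resp)  # list of generated strings (the prompt is included in each)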
+"""
+import json
+from typing import List
+
+import requests
+
+
+def post_http_request(
+ prompt: str, api_url: str, n: int = 1, max_tokens: int = 100, temperature: float = 0.0, stream: bool = False
+) -> requests.Response:
+ headers = {"User-Agent": "Test Client"}
+ pload = {
+ "prompt": prompt,
+ "n": 1,
+ "use_beam_search": False,
+ "temperature": temperature,
+ "max_tokens": max_tokens,
+ "stream": stream,
+ }
+ response = requests.post(api_url, headers=headers, json=pload, stream=True, timeout=3)
+ return response
+
+
+def get_response(response: requests.Response) -> List[str]:
+ data = json.loads(response.content)
+ output = data["text"]
+ return output
diff --git a/applications/ColossalQA/colossalqa/memory.py b/applications/ColossalQA/colossalqa/memory.py
new file mode 100644
index 000000000000..255df68a367e
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/memory.py
@@ -0,0 +1,168 @@
+"""
+Implements a memory class for storing conversation history.
+Supports long-term (summarized) and short-term (buffered) memory.
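+
+Usage (a minimal sketch; `llm` is any LangChain-compatible LLM):
+
+memory = ConversationBufferWithSummary(llm=llm, max_tokens=2000)
+memory.save_context({"question": "Hi"}, {"output": "Hello!"})
+print(memory.format_dialogue())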
+"""
+from typing import Any, Dict, List
+
+from colossalqa.chain.memory.summary import ConversationSummaryMemory
+from colossalqa.chain.retrieval_qa.load_chain import load_qa_chain
+from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
+from langchain.schema import BaseChatMessageHistory
+from langchain.schema.messages import BaseMessage
+from langchain.schema.retriever import BaseRetriever
+from pydantic import Field
+
+
+class ConversationBufferWithSummary(ConversationSummaryMemory):
+ """Memory class that buffers the most recent conversation and recursively summarizes older turns."""
+
+ # Store the most recent conversation history
+ buffered_history: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
+ # Temp buffer
+ summarized_history_temp: BaseChatMessageHistory = Field(default_factory=ChatMessageHistory)
+ human_prefix: str = "Human"
+ ai_prefix: str = "Assistant"
+ buffer: str = "" # Formated conversation in str
+ existing_summary: str = "" # Summarization of stale converstion in str
+ # Define key to pass information about entities into prompt.
+ memory_key: str = "chat_history"
+ input_key: str = "question"
+ retriever: BaseRetriever = None
+ max_tokens: int = 2000
+ chain: BaseCombineDocumentsChain = None
+ input_chain_type_kwargs: Dict = {}
+
+ @property
+ def buffer(self) -> Any:
+ """String buffer of memory."""
+ return self.buffer_as_messages if self.return_messages else self.buffer_as_str
+
+ @property
+ def buffer_as_str(self) -> str:
+ """Exposes the buffer as a string in case return_messages is True."""
+ return self.format_dialogue()
+
+ @property
+ def buffer_as_messages(self) -> List[BaseMessage]:
+ """Exposes the buffer as a list of messages in case return_messages is False."""
+ return self.buffered_history.messages
+
+ def clear(self):
+ """Clear all the memory"""
+ self.buffered_history.clear()
+ self.summarized_history_temp.clear()
+
+ def initiate_document_retrieval_chain(
+ self, llm: Any, prompt_template: Any, retriever: Any, chain_type_kwargs: Dict[str, Any] = {}
+ ) -> None:
+ """
+ Initiate a retrieval chain so that the length of the final prompt can be calculated.
+ Args:
+ llm: the language model for the retrieval chain (we won't actually return its output)
+ prompt_template: the prompt template for constructing the retrieval chain
+ retriever: the retriever for the retrieval chain
+ chain_type_kwargs: the kwargs for the retrieval chain
+ Note: the max prompt length (max_tokens), memory_key and input_key are fields of this class
+ """
+ self.retriever = retriever
+ input_chain_type_kwargs = {k: v for k, v in chain_type_kwargs.items() if k not in [self.memory_key]}
+ self.input_chain_type_kwargs = input_chain_type_kwargs
+ self.chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_template, **self.input_chain_type_kwargs)
+
+ @property
+ def memory_variables(self) -> List[str]:
+ """Define the variables we are providing to the prompt."""
+ return [self.memory_key]
+
+ def format_dialogue(self, lang: str = "en") -> str:
+ """Format memory into two parts--- summarization of historical conversation and most recent conversation"""
+ if len(self.summarized_history_temp.messages) != 0:
+ for i in range(int(len(self.summarized_history_temp.messages) / 2)):
+ self.existing_summary = (
+ self.predict_new_summary(
+ self.summarized_history_temp.messages[i * 2 : i * 2 + 2], self.existing_summary, stop=["\n\n"]
+ )
+ .strip()
+ .split("\n")[0]
+ .strip()
+ )
+ for i in range(int(len(self.summarized_history_temp.messages) / 2)):
+ self.summarized_history_temp.messages.pop(0)
+ self.summarized_history_temp.messages.pop(0)
+ conversation_buffer = []
+ for t in self.buffered_history.messages:
+ if t.type == "human":
+ prefix = self.human_prefix
+ else:
+ prefix = self.ai_prefix
+ conversation_buffer.append(prefix + ": " + t.content)
+ conversation_buffer = "\n".join(conversation_buffer)
+ if len(self.existing_summary) > 0:
+ if lang == "en":
+ message = f"A summarization of historical conversation:\n{self.existing_summary}\nMost recent conversation:\n{conversation_buffer}"
+ elif lang == "zh":
+ message = f"历史对话概要:\n{self.existing_summary}\n最近的对话:\n{conversation_buffer}"
+ else:
+ raise ValueError("Unsupported language")
+ return message
+ else:
+ message = conversation_buffer
+ return message
+
+ def get_conversation_length(self):
+ """Get the length of the formatted conversation"""
+ prompt = self.format_dialogue()
+ length = self.llm.get_num_tokens(prompt)
+ return length
+
+ def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+ """Load the memory variables.
+ Summarize oversize conversation to fit into the length constraint defined by max_tokene
+ Args:
+ inputs: the kwargs of the chain of your definition
+ Returns:
+ a dict that maps from memory key to the formated dialogue
+ the formated dialogue has the following format
+ if conversation is too long:
+ A summarization of historical conversation:
+ {summarization}
+ Most recent conversation:
+ Human: XXX
+ Assistant: XXX
+ ...
+ otherwise
+ Human: XXX
+ Assistant: XXX
+ ...
+ """
+ # Calculate remain length
+ if "input_documents" in inputs:
+ # Run in a retrieval qa chain
+ docs = inputs["input_documents"]
+ else:
+ # For testing purposes
+ docs = self.retriever.get_relevant_documents(inputs[self.input_key])
+ inputs[self.memory_key] = ""
+ inputs = {k: v for k, v in inputs.items() if k in [self.chain.input_key, self.input_key, self.memory_key]}
+ prompt_length = self.chain.prompt_length(docs, **inputs)
+ remain = self.max_tokens - prompt_length
+ while self.get_conversation_length() > remain:
+ if len(self.buffered_history.messages) <= 2:
+ raise RuntimeError("Exeeed max_tokens, trunck size of retrieved documents is too large")
+ temp = self.buffered_history.messages.pop(0)
+ self.summarized_history_temp.messages.append(temp)
+ temp = self.buffered_history.messages.pop(0)
+ self.summarized_history_temp.messages.append(temp)
+ return {self.memory_key: self.format_dialogue()}
+
+ def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
+ """Save context from this conversation to buffer."""
+ input_str, output_str = self._get_input_output(inputs, outputs)
+ self.buffered_history.add_user_message(input_str.strip())
+ self.buffered_history.add_ai_message(output_str.strip())
diff --git a/applications/ColossalQA/colossalqa/mylogging.py b/applications/ColossalQA/colossalqa/mylogging.py
new file mode 100644
index 000000000000..574c33b41685
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/mylogging.py
@@ -0,0 +1,92 @@
+"""
+Class for logging with extra control for debugging
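+
+Usage (a minimal sketch):
+
+logger = get_logger()
+logger.info("this message is only printed when verbose=True", verbose=True)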
+"""
+import logging
+
+
+class ColossalQALogger:
+ """This is a distributed event logger class essentially based on :class:`logging`.
+
+ Args:
+ name (str): The name of the logger.
+
+ Note:
+ Logging types: ``info``, ``warning``, ``debug`` and ``error``
+ """
+
+ __instances = dict()
+
+ def __init__(self, name):
+ if name in ColossalQALogger.__instances:
+ raise ValueError("Logger with the same name has been created")
+ else:
+ self._name = name
+ self._logger = logging.getLogger(name)
+
+ ColossalQALogger.__instances[name] = self
+
+ @staticmethod
+ def get_instance(name: str):
+ """Get the unique single logger instance based on name.
+
+ Args:
+ name (str): The name of the logger.
+
+ Returns:
+ ColossalQALogger: A ColossalQALogger object
+ """
+ if name in ColossalQALogger.__instances:
+ return ColossalQALogger.__instances[name]
+ else:
+ logger = ColossalQALogger(name=name)
+ return logger
+
+ def info(self, message: str, verbose: bool = False) -> None:
+ """Log an info message.
+
+ Args:
+ message (str): The message to be logged.
+ verbose (bool): Whether to print the message to stdout.
+ """
+ if verbose:
+ logging.basicConfig(level=logging.INFO)
+ self._logger.info(message)
+
+ def warning(self, message: str, verbose: bool = False) -> None:
+ """Log a warning message.
+
+ Args:
+ message (str): The message to be logged.
+ verbose (bool): Whether to print the message to stdout.
+ """
+ if verbose:
+ self._logger.warning(message)
+
+ def debug(self, message: str, verbose: bool = False) -> None:
+ """Log a debug message.
+
+ Args:
+ message (str): The message to be logged.
+ verbose (bool): Whether to print the message to stdout.
+ """
+ if verbose:
+ self._logger.debug(message)
+
+ def error(self, message: str) -> None:
+ """Log an error message.
+
+ Args:
+ message (str): The message to be logged.
+ """
+ self._logger.error(message)
+
+
+def get_logger(name: str = None, level=logging.INFO) -> ColossalQALogger:
+ """
+ Get the logger by name; if name is None, return the default logger
+ """
+ if name:
+ logger = ColossalQALogger.get_instance(name=name)
+ else:
+ logger = ColossalQALogger.get_instance(name="colossalqa")
+ return logger
diff --git a/applications/ColossalQA/colossalqa/prompt/README.md b/applications/ColossalQA/colossalqa/prompt/README.md
new file mode 100644
index 000000000000..e5c74906b113
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/prompt/README.md
@@ -0,0 +1,144 @@
+# Prompt Design Guide
+
+For the retriever conversation system, users can customize three prompts.
+
+## The Retrieval QA Prompt
+This is the prompt for retrieval QA. Its inputs are the user's query, the retrieved documents, and the historical conversation.
+
+### Chinese
+```
+你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
+如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
+使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
+
+背景信息:
+[retrieved documents]
+
+聊天记录:
+[historical conversation, overlength chat history will be summarized]
+
+用户: [question]
+Assistant:
+```
+
+### English
+```
+[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If the answer cannot be inferred based on the given context, please don't share false information.<</SYS>>
+Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
+
+context:
+[retrieved documents]
+
+chat history
+[historical conversation, overlength chat history will be summarized]
+
+Human: {question}
+Assistant:
+```
+
+## Summarization Prompt
+This prompt is used by the memory module to recursively summarize overlength conversations, shrinking the length of the prompt.
+
+## Disambiguation Prompt
+This prompt is used to perform zero-shot reference resolution, disambiguating entity references within the user's questions.
+
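+All three prompts are standard LangChain `PromptTemplate` objects (see `colossalqa/prompt/prompt.py`), so they can be swapped for custom ones. A minimal sketch (the template text is illustrative):
+
+```python
+from langchain.prompts.prompt import PromptTemplate
+
+my_qa_prompt = PromptTemplate(
+    template="context:\n{context}\n\nchat history:\n{chat_history}\n\nquestion: {question}\nanswer:",
+    input_variables=["question", "chat_history", "context"],
+)
+```
+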
+## Final Prompt Examples
+Assume k=3 for the retriever.
+
+### English
+Note that the "[INST] <>...<>" template is the specific prompt format used in LLaMA2.
+#### Normal Length
+```
+[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If the answer cannot be inferred based on the given context, please don't share false information.<</SYS>>
+Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
+
+context:
+[document 1]
+
+[document 2]
+
+[document 3]
+
+chat history
+Human: XXX
+Assistant: XXX
+...
+
+Human: {question}
+Assistant:
+```
+
+#### Overlength
+```
+[INST] <<SYS>>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+If the answer cannot be inferred based on the given context, please don't share false information.<</SYS>>
+Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
+
+context:
+[document 1]
+
+[document 2]
+
+[document 3]
+
+chat history
+A summarization of historical conversation:
+[one line summary of historical conversation]
+Most recent conversation:
+Human: XXX
+Assistant: XXX
+...
+
+Human: {question}
+Assistant:
+```
+
+### Chinese
+#### Normal Length
+```
+你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
+如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
+使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
+
+背景信息:
+[document 1]
+
+[document 2]
+
+[document 3]
+
+聊天记录:
+用户: XXX
+Assistant: XXX
+...
+
+用户: [question]
+Assistant:
+```
+
+#### Overlength
+```
+你是一个善于解答用户问题的AI助手。在保证安全的前提下,回答问题要尽可能有帮助。你的答案不应该包含任何有害的、不道德的、种族主义的、性别歧视的、危险的或非法的内容。请确保你的回答是公正和积极的。
+如果不能根据给定的上下文推断出答案,请不要分享虚假、不确定的信息。
+使用提供的背景信息和聊天记录对用户的输入作出回应或继续对话。您应该只生成一个回复。不需要跟进回答。请使用中文作答。
+
+背景信息:
+[document 1]
+
+[document 2]
+
+[document 3]
+
+聊天记录:
+历史对话概要:
+[one line summary of historical conversation]
+最近的对话:
+用户: XXX
+Assistant: XXX
+...
+
+用户: [question]
+Assistant:
+```
diff --git a/applications/ColossalQA/colossalqa/prompt/prompt.py b/applications/ColossalQA/colossalqa/prompt/prompt.py
new file mode 100644
index 000000000000..a7723078689e
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/prompt/prompt.py
@@ -0,0 +1,124 @@
+"""
+All custom prompt templates are defined here.
+"""
+
+from langchain.prompts.prompt import PromptTemplate
+
+_CUSTOM_SUMMARIZER_TEMPLATE_ZH = """请递进式地总结所提供的当前对话,将当前对话的摘要内容添加到先前已有的摘要上,返回一个融合了当前对话的新的摘要。
+
+例1:
+已有的摘要:
+人类问Assistant对人工智能的看法。人工智能认为人工智能是一种善的力量。
+
+新的对话内容:
+人类: 为什么你认为人工智能是一种好的力量?
+Assistant: 因为人工智能将帮助人类充分发挥潜力。
+
+新的摘要:
+人类问Assistant对人工智能的看法。人工智能认为人工智能是一种积极的力量,因为它将帮助人类充分发挥潜力。
+示例结束
+
+已有的摘要:
+{summary}
+
+新的对话内容:
+{new_lines}
+
+新的摘要:"""
+
+
+# Chinese retrieval qa prompt
+
+_ZH_RETRIEVAL_QA_PROMPT = """<指令>根据下列支持文档和对话历史,简洁和专业地来回答问题。如果无法从支持文档中得到答案,请说 “根据已知信息无法回答该问题”。回答中请不要涉及支持文档中没有提及的信息,答案请使用中文。 指令>
+
+{context}
+
+<对话历史>
+{chat_history}
+</对话历史>
+
+<问题>{question}</问题>
+答案:"""
+
+ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["无法回答该问题"]
+ZH_RETRIEVAL_QA_REJECTION_ANSWER = "抱歉,根据提供的信息无法回答该问题。"
+
+
+_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE = """使用提供的参考案例判断客户遇到的故障所属的故障原因分类。
+
+背景信息:
+{context}
+
+客服记录:
+{question}
+故障原因分类:"""
+
+_ZH_DISAMBIGUATION_PROMPT = """你是一个乐于助人、恭敬而诚实的助手。你总是按照指示去做。
+请用聊天记录中提到的具体名称或实体名称替换给定句子中的任何模糊或有歧义的指代,如果没有提供聊天记录或句子中不包含模糊或有歧义的指代,则只输出原始句子。您的输出应该是消除歧义的句子本身(与“消除歧义的句子:”在同一行中),并且不包含任何其他内容。
+
+下面是一个例子:
+聊天记录:
+用户: 我有一个朋友,张三。你认识他吗?
+Assistant: 我认识一个叫张三的人
+
+句子: 他最喜欢的食物是什么?
+消除歧义的句子: 张三最喜欢的食物是什么?
+
+聊天记录:
+{chat_history}
+
+句子: {input}
+消除歧义的句子:"""
+
+# English retrieval qa prompt
+
+_EN_RETRIEVAL_QA_PROMPT = """[INST] <>Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist content.
+If the answer cannot be infered based on the given context, please say "I cannot answer the question based on the information given.".<>
+Use the context and chat history to answer the question.
+
+context:
+{context}
+
+chat history
+{chat_history}
+
+question: {question}
+answer:"""
+EN_RETRIEVAL_QA_TRIGGER_KEYWORDS = ["cannot answer the question"]
+EN_RETRIEVAL_QA_REJECTION_ANSWER = "Sorry, this question cannot be answered based on the information provided."
+
+_EN_DISAMBIGUATION_PROMPT = """[INST] <>You are a helpful, respectful and honest assistant. You always follow the instruction.<>
+Please replace any ambiguous references in the given sentence with the specific names or entities mentioned in the chat history or just output the original sentence if no chat history is provided or if the sentence doesn't contain ambiguous references. Your output should be the disambiguated sentence itself (in the same line as "disambiguated sentence:") and contain nothing else.
+
+Here is an example:
+Chat history:
+Human: I have a friend, Mike. Do you know him?
+Assistant: Yes, I know a person named Mike
+
+sentence: What's his favorite food?
+disambiguated sentence: What's Mike's favorite food?
+[/INST]
+Chat history:
+{chat_history}
+
+sentence: {input}
+disambiguated sentence:"""
+
+
+PROMPT_RETRIEVAL_QA_EN = PromptTemplate(
+ template=_EN_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
+)
+
+PROMPT_DISAMBIGUATE_EN = PromptTemplate(template=_EN_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
+
+SUMMARY_PROMPT_ZH = PromptTemplate(input_variables=["summary", "new_lines"], template=_CUSTOM_SUMMARIZER_TEMPLATE_ZH)
+
+PROMPT_DISAMBIGUATE_ZH = PromptTemplate(template=_ZH_DISAMBIGUATION_PROMPT, input_variables=["chat_history", "input"])
+
+PROMPT_RETRIEVAL_QA_ZH = PromptTemplate(
+ template=_ZH_RETRIEVAL_QA_PROMPT, input_variables=["question", "chat_history", "context"]
+)
+
+PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH = PromptTemplate(
+ template=_ZH_RETRIEVAL_CLASSIFICATION_USE_CASE, input_variables=["question", "context"]
+)
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_en.py b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
new file mode 100644
index 000000000000..d2626321d68d
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_en.py
@@ -0,0 +1,87 @@
+"""
+Script for an English retrieval-based conversation system
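+
+Usage (a minimal sketch; `retriever` is a configured CustomRetriever and the model path is illustrative):
+
+conversation = EnglishRetrievalConversation.from_retriever(retriever, model_path="/path/to/llama-2-7b", model_name="llama")
+answer, memory = conversation.run("What is ColossalAI?", None)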
+"""
+from typing import Tuple
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.mylogging import get_logger
+from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_EN, PROMPT_RETRIEVAL_QA_EN
+from colossalqa.retriever import CustomRetriever
+from langchain import LLMChain
+
+logger = get_logger()
+
+
+class EnglishRetrievalConversation:
+ """
+ Wrapper class for the English retrieval conversation system
+ """
+
+ def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
+ """
+ Set up the retrieval QA chain for English retrieval-based QA
+ """
+ logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
+ colossal_api = ColossalAPI.get_api(model_name, model_path)
+ self.llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Define the retriever
+ self.retriever = retriever
+
+ # Define the chain to preprocess the input
+ # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
+ # The prompt used here is the disambiguation prompt
+ self.llm_chain_disambiguate = LLMChain(
+ llm=self.llm,
+ prompt=PROMPT_DISAMBIGUATE_EN,
+ llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
+ )
+
+ self.retriever.set_rephrase_handler(self.disambiguity)
+ # Define memory with summarization ability
+ self.memory = ConversationBufferWithSummary(
+ llm=self.llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
+ )
+ self.memory.initiate_document_retrieval_chain(
+ self.llm,
+ PROMPT_RETRIEVAL_QA_EN,
+ self.retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+ self.retrieval_chain = RetrievalQA.from_chain_type(
+ llm=self.llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=self.retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": self.memory},
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
+ )
+
+ def disambiguity(self, input: str):
+ out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ @classmethod
+ def from_retriever(
+ cls, retriever: CustomRetriever, model_path: str, model_name: str
+ ) -> "EnglishRetrievalConversation":
+ return cls(retriever, model_path, model_name)
+
+ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
+ if memory:
+ # TODO add translation chain here
+ self.memory.buffered_history.messages = memory.buffered_history.messages
+ self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
+ return (
+ self.retrieval_chain.run(
+ query=user_input,
+ stop=[self.memory.human_prefix + ": "],
+ rejection_trigger_keywrods=["cannot answer the question"],
+ rejection_answer="Sorry, this question cannot be answered based on the information provided.",
+ ).split("\n")[0],
+ self.memory,
+ )
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
new file mode 100644
index 000000000000..76bec715fb6e
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_universal.py
@@ -0,0 +1,138 @@
+"""
+Multilingual retrieval-based conversation system
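+
+Usage (a minimal sketch; model paths, names and data files are illustrative):
+
+session = UniversalRetrievalConversation(
+    zh_model_path="/path/to/chatglm2-6b", zh_model_name="chatglm2",
+    en_model_path="/path/to/llama-2-7b", en_model_name="llama",
+    sql_file_path="sql_file.db",
+    files_zh=[{"data_path": "data_zh.txt", "name": "company_data_zh"}],
+    files_en=[{"data_path": "data_en.txt", "name": "company_data_en"}],
+)
+print(session.run("What is ColossalAI?", which_language="en"))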
+"""
+from typing import List
+
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.mylogging import get_logger
+from colossalqa.retrieval_conversation_en import EnglishRetrievalConversation
+from colossalqa.retrieval_conversation_zh import ChineseRetrievalConversation
+from colossalqa.retriever import CustomRetriever
+from colossalqa.text_splitter import ChineseTextSplitter
+from colossalqa.utils import detect_lang_naive
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+
+logger = get_logger()
+
+
+class UniversalRetrievalConversation:
+ """
+ Wrapper class for bilingual retrieval conversation system
+ """
+
+ def __init__(
+ self,
+ embedding_model_path: str = "moka-ai/m3e-base",
+ embedding_model_device: str = "cpu",
+ zh_model_path: str = None,
+ zh_model_name: str = None,
+ en_model_path: str = None,
+ en_model_name: str = None,
+ sql_file_path: str = None,
+ files_zh: List[List[str]] = None,
+ files_en: List[List[str]] = None,
+ text_splitter_chunk_size=100,
+ text_splitter_chunk_overlap=10,
+ ) -> None:
+ """
+ Wrapper for the multilingual retrieval QA class (Chinese + English)
+ Args:
+ embedding_model_path: local path or HuggingFace repo of the embedding model
+ embedding_model_device: device to run the embedding model on, e.g. "cpu" or "cuda"
+ files_zh: [{"data_path": ..., "name": ...}, ...] defines the files used as supporting documents for Chinese retrieval QA
+ files_en: [{"data_path": ..., "name": ...}, ...] defines the files used as supporting documents for English retrieval QA
+ """
+ self.embedding = HuggingFaceEmbeddings(
+ model_name=embedding_model_path,
+ model_kwargs={"device": embedding_model_device},
+ encode_kwargs={"normalize_embeddings": False},
+ )
+ print("Select files for constructing Chinese retriever")
+ docs_zh = self.load_supporting_docs(
+ files=files_zh,
+ text_splitter=ChineseTextSplitter(
+ chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
+ ),
+ )
+ # Create retriever
+ self.information_retriever_zh = CustomRetriever(
+ k=3, sql_file_path=sql_file_path.replace(".db", "_zh.db"), verbose=True
+ )
+ self.information_retriever_zh.add_documents(
+ docs=docs_zh, cleanup="incremental", mode="by_source", embedding=self.embedding
+ )
+
+ print("Select files for constructing English retriever")
+ docs_en = self.load_supporting_docs(
+ files=files_en,
+ text_splitter=RecursiveCharacterTextSplitter(
+ chunk_size=text_splitter_chunk_size, chunk_overlap=text_splitter_chunk_overlap
+ ),
+ )
+ # Create retriever
+ self.information_retriever_en = CustomRetriever(
+ k=3, sql_file_path=sql_file_path.replace(".db", "_en.db"), verbose=True
+ )
+ self.information_retriever_en.add_documents(
+ docs=docs_en, cleanup="incremental", mode="by_source", embedding=self.embedding
+ )
+
+ self.chinese_retrieval_conversation = ChineseRetrievalConversation.from_retriever(
+ self.information_retriever_zh, model_path=zh_model_path, model_name=zh_model_name
+ )
+ self.english_retrieval_conversation = EnglishRetrievalConversation.from_retriever(
+ self.information_retriever_en, model_path=en_model_path, model_name=en_model_name
+ )
+ self.memory = None
+
+ def load_supporting_docs(self, files: List[List[str]] = None, text_splitter: TextSplitter = None):
+ """
+ Load supporting documents, currently, all documents will be stored in one vector store
+ """
+ documents = []
+ if files:
+ for file in files:
+ retriever_data = DocumentLoader([[file["data_path"], file["name"]]]).all_data
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+ else:
+ while True:
+ file = input("Select a file to load or press Enter to exit:")
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+ separator = input(
+ "Enter a separator to force separating text into chunks, if no separator is given, the defaut separator is '\\n\\n', press ENTER directly to skip:"
+ )
+ separator = separator if separator != "" else "\n\n"
+ retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
+
+ # Split
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+ return documents
+
+ def start_test_session(self):
+ """
+ Simple multilingual session for testing purposes, with a naive language selection mechanism
+ """
+ while True:
+ user_input = input("User: ")
+ lang = detect_lang_naive(user_input)
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = self.run(user_input, which_language=lang)
+ print(f"Agent: {agent_response}")
+
+ def run(self, user_input: str, which_language: str):
+ """
+ Generate the response given the user input and a string indicating the required language of the output
+ """
+ assert which_language in ["zh", "en"]
+ if which_language == "zh":
+ agent_response, self.memory = self.chinese_retrieval_conversation.run(user_input, self.memory)
+ else:
+ agent_response, self.memory = self.english_retrieval_conversation.run(user_input, self.memory)
+ return agent_response.split("\n")[0]
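+
+# A minimal usage sketch (illustrative only): the class name is assumed from
+# this module's constructor, and the file paths are placeholders pointing at
+# the sample data shipped with this PR.
+#
+# session = UniversalRetrievalConversation(
+#     files_en=[{"data_path": "./data/data_sample/companies.txt", "name": "companies"}],
+#     files_zh=[{"data_path": "./data/data_sample/companies_zh.txt", "name": "companies_zh"}],
+# )
+# print(session.run("Where is Marriott headquartered?", which_language="en"))
+# session.start_test_session()  # interactive loop; type END to quit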
diff --git a/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
new file mode 100644
index 000000000000..484be21c1553
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/retrieval_conversation_zh.py
@@ -0,0 +1,94 @@
+"""
+Script for a Chinese retrieval-based conversation system backed by ChatGLM
+"""
+from typing import Tuple
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.mylogging import get_logger
+from colossalqa.prompt.prompt import PROMPT_DISAMBIGUATE_ZH, PROMPT_RETRIEVAL_QA_ZH, SUMMARY_PROMPT_ZH
+from colossalqa.retriever import CustomRetriever
+from langchain import LLMChain
+
+logger = get_logger()
+
+
+class ChineseRetrievalConversation:
+ """
+ Wrapper class for Chinese retrieval conversation system
+ """
+
+ def __init__(self, retriever: CustomRetriever, model_path: str, model_name: str) -> None:
+ """
+ Set up the retrieval QA chain for Chinese retrieval-based QA
+ """
+ # Local coati api
+ logger.info(f"model_name: {model_name}; model_path: {model_path}", verbose=True)
+ colossal_api = ColossalAPI.get_api(model_name, model_path)
+ self.llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Define the retriever
+ self.retriever = retriever
+
+ # Define the chain to preprocess the input
+ # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
+ # The prompt used here is the disambiguation prompt
+ self.llm_chain_disambiguate = LLMChain(
+ llm=self.llm,
+ prompt=PROMPT_DISAMBIGUATE_ZH,
+ llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True},
+ )
+
+ self.retriever.set_rephrase_handler(self.disambiguity)
+ # Define memory with summarization ability
+ self.memory = ConversationBufferWithSummary(
+ llm=self.llm,
+ prompt=SUMMARY_PROMPT_ZH,
+ human_prefix="用户",
+ ai_prefix="Assistant",
+ max_tokens=2000,
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
+ )
+ self.memory.initiate_document_retrieval_chain(
+ self.llm,
+ PROMPT_RETRIEVAL_QA_ZH,
+ self.retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+ self.retrieval_chain = RetrievalQA.from_chain_type(
+ llm=self.llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=self.retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_ZH, "memory": self.memory},
+ llm_kwargs={"max_new_tokens": 150, "temperature": 0.9, "do_sample": True},
+ )
+
+ def disambiguity(self, input: str):
+ """Rephrase the user input into a standalone question using the chat history"""
+ out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ @classmethod
+ def from_retriever(
+ cls, retriever: CustomRetriever, model_path: str, model_name: str
+ ) -> "ChineseRetrievalConversation":
+ return cls(retriever, model_path, model_name)
+
+ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
+ if memory:
+ # TODO add translation chain here
+ self.memory.buffered_history.messages = memory.buffered_history.messages
+ self.memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
+ return (
+ self.retrieval_chain.run(
+ query=user_input,
+ stop=["答案>"],
+ doc_prefix="支持文档",
+ rejection_trigger_keywrods=["无法回答该问题"],
+ rejection_answer="抱歉,根据提供的信息无法回答该问题。",
+ ).split("\n")[0],
+ self.memory,
+ )
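+
+# A minimal usage sketch (illustrative only): the retriever must be populated
+# with documents first, and the model path is a placeholder for a local
+# ChatGLM checkpoint.
+#
+# from colossalqa.retriever import CustomRetriever
+# retriever = CustomRetriever(k=3, sql_file_path="./sql_files", verbose=True)
+# # ... populate via retriever.add_documents(docs=..., embedding=...) ...
+# chat = ChineseRetrievalConversation.from_retriever(
+#     retriever, model_path="<path-to-chatglm2>", model_name="chatglm2"
+# )
+# answer, memory = chat.run("公司的总部在哪里?", None)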
diff --git a/applications/ColossalQA/colossalqa/retriever.py b/applications/ColossalQA/colossalqa/retriever.py
new file mode 100644
index 000000000000..9ea6d5b080cd
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/retriever.py
@@ -0,0 +1,166 @@
+"""
+Code for custom retriever with incremental update
+"""
+import copy
+import hashlib
+import os
+from collections import defaultdict
+from typing import Any, Callable, Dict, List
+
+from colossalqa.mylogging import get_logger
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain.embeddings.base import Embeddings
+from langchain.indexes import SQLRecordManager, index
+from langchain.schema.retriever import BaseRetriever, Document
+from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.chroma import Chroma
+
+logger = get_logger()
+
+
+class CustomRetriever(BaseRetriever):
+ """
+ Custom retriever class with support for incremental update of indexes
+ """
+
+ vector_stores: Dict[str, VectorStore] = {}
+ sql_index_database: Dict[str, str] = {}
+ record_managers: Dict[str, SQLRecordManager] = {}
+ sql_db_chains = []
+ k = 3
+ rephrase_handler: Callable = None
+ buffer: List = []
+ buffer_size: int = 5
+ verbose: bool = False
+ sql_file_path: str = None
+
+ @classmethod
+ def from_documents(
+ cls,
+ documents: List[Document],
+ embeddings: Embeddings,
+ **kwargs: Any,
+ ) -> BaseRetriever:
+ k = kwargs.pop("k", 3)
+ cleanup = kwargs.pop("cleanup", "incremental")
+ mode = kwargs.pop("mode", "by_source")
+ ret = cls(k=k)
+ ret.add_documents(documents, embedding=embeddings, cleanup=cleanup, mode=mode)
+ return ret
+
+ def add_documents(
+ self,
+ docs: List[Document] = [],
+ cleanup: str = "incremental",
+ mode: str = "by_source",
+ embedding: Embeddings = None,
+ ) -> None:
+ """
+ Add documents to retriever
+ Args:
+ docs: the documents to add
+ cleanup: choose from "incremental" (update embeddings, skip existing embeddings) and "full" (destroy and rebuild retriever)
+ mode: choose from "by_source" (documents are grouped by source) and "merge" (documents are merged into one vector store)
+ """
+ if cleanup == "full":
+ # Cleanup
+ for source in self.vector_stores:
+ os.remove(self.sql_index_database[source])
+ # Add documents
+ data_by_source = defaultdict(list)
+ if mode == "by_source":
+ for doc in docs:
+ data_by_source[doc.metadata["source"]].append(doc)
+ elif mode == "merge":
+ data_by_source["merged"] = docs
+ for source in data_by_source:
+ if source not in self.vector_stores:
+ hash_encoding = hashlib.sha3_224(source.encode()).hexdigest()
+ if os.path.exists(f"{self.sql_file_path}/{hash_encoding}.db"):
+ # Remove the stale file
+ os.remove(f"{self.sql_file_path}/{hash_encoding}.db")
+ # Create a new sql database to store the index; one .db file per source is kept under sql_file_path
+ sql_path = f"sqlite:///{self.sql_file_path}/{hash_encoding}.db"
+ self.vector_stores[source] = Chroma(embedding_function=embedding, collection_name=hash_encoding)
+ self.sql_index_database[source] = f"{self.sql_file_path}/{hash_encoding}.db"
+ self.record_managers[source] = SQLRecordManager(source, db_url=sql_path)
+ self.record_managers[source].create_schema()
+ index(
+ data_by_source[source],
+ self.record_managers[source],
+ self.vector_stores[source],
+ cleanup=cleanup,
+ source_id_key="source",
+ )
+
+ def __del__(self):
+ for source in self.sql_index_database:
+ if os.path.exists(self.sql_index_database[source]):
+ os.remove(self.sql_index_database[source])
+
+ def set_sql_database_chain(self, db_chains) -> None:
+ """
+ Set sql agent chains to retrieve information from sql databases.
+ Not used in this version.
+ """
+ self.sql_db_chains = db_chains
+
+ def set_rephrase_handler(self, handler: Callable = None) -> None:
+ """
+ Set a handler to preprocess the input str before feeding it into the retriever
+ """
+ self.rephrase_handler = handler
+
+ def _get_relevant_documents(
+ self,
+ query: str,
+ *,
+ run_manager: CallbackManagerForRetrieverRun = None,
+ score_threshold: float = None,
+ return_scores: bool = False,
+ ) -> List[Document]:
+ """
+ This function is called by the retriever to get the relevant documents.
+ Recently visited queries are cached in the buffer; if the query is in the buffer, the cached documents are returned directly.
+
+ Args:
+ query: the query to be searched
+ run_manager: the callback manager for retriever run
+ Returns:
+ documents: the relevant documents
+ """
+ for buffered_doc in self.buffer:
+ if buffered_doc[0] == query:
+ return buffered_doc[1]
+ query_ = str(query)
+ # Rephrase the query with the handler, if one is set, before searching the vector stores
+ if self.rephrase_handler:
+ query = self.rephrase_handler(query)
+ documents = []
+ for k in self.vector_stores:
+ # Retrieve documents from each retriever
+ vectorstore = self.vector_stores[k]
+ documents.extend(vectorstore.similarity_search_with_score(query, self.k, score_threshold=score_threshold))
+ # Return the top k documents among all retrievers
+ documents = sorted(documents, key=lambda x: x[1], reverse=False)[: self.k]
+ if return_scores:
+ # Return score
+ documents = copy.deepcopy(documents)
+ for doc in documents:
+ doc[0].metadata["score"] = doc[1]
+ documents = [doc[0] for doc in documents]
+ # Retrieve documents from sql database (not applicable for the local chains)
+ for sql_chain in self.sql_db_chains:
+ documents.append(
+ Document(
+ page_content=f"Query: {query} Answer: {sql_chain.run(query)}", metadata={"source": "sql_query"}
+ )
+ )
+ if len(self.buffer) < self.buffer_size:
+ self.buffer.append([query_, documents])
+ else:
+ self.buffer.pop(0)
+ self.buffer.append([query_, documents])
+ logger.info(f"retrieved documents:\n{str(documents)}", verbose=self.verbose)
+ return documents
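+
+# A minimal usage sketch for incremental indexing (illustrative only; the
+# embedding model id and paths are placeholders):
+#
+# from langchain.embeddings import HuggingFaceEmbeddings
+# from langchain.schema import Document
+#
+# embedding = HuggingFaceEmbeddings(model_name="<embedding-model>")
+# docs = [Document(page_content="AERCO was founded in 1949.", metadata={"source": "companies"})]
+# retriever = CustomRetriever(k=3, sql_file_path="./index_dir", verbose=True)
+# retriever.add_documents(docs=docs, cleanup="incremental", mode="by_source", embedding=embedding)
+# # Re-adding unchanged docs skips re-embedding them ("incremental");
+# # cleanup="full" would delete the per-source .db index files first.
+# hits = retriever.get_relevant_documents("When was AERCO founded?")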
diff --git a/applications/ColossalQA/colossalqa/text_splitter/__init__.py b/applications/ColossalQA/colossalqa/text_splitter/__init__.py
new file mode 100644
index 000000000000..b56fdfe8d582
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/text_splitter/__init__.py
@@ -0,0 +1 @@
+from .chinese_text_splitter import ChineseTextSplitter
diff --git a/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py
new file mode 100644
index 000000000000..3815f5ed2621
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/text_splitter/chinese_text_splitter.py
@@ -0,0 +1,56 @@
+"""
+Code for Chinese text splitter
+"""
+from typing import Any, List, Optional
+
+from colossalqa.text_splitter.utils import get_cleaned_paragraph
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+class ChineseTextSplitter(RecursiveCharacterTextSplitter):
+ def __init__(self, separators: Optional[List[str]] = None, is_separator_regex: bool = False, **kwargs: Any):
+ self._separators = separators or ["\n\n", "\n", ",", "。", "!", "?", "?"]
+ if "chunk_size" not in kwargs:
+ kwargs["chunk_size"] = 50
+ if "chunk_overlap" not in kwargs:
+ kwargs["chunk_overlap"] = 10
+ # Pass the (possibly defaulted) separators to the parent class so the
+ # Chinese defaults are not silently replaced by the English ones
+ super().__init__(separators=self._separators, keep_separator=True, **kwargs)
+ self._is_separator_regex = is_separator_regex
+
+ def split_text(self, text: str) -> List[str]:
+ """Return the list of separated text chunks"""
+ cleaned_paragraph = get_cleaned_paragraph(text)
+ splitted = []
+ for paragraph in cleaned_paragraph:
+ segs = super().split_text(paragraph)
+ for i in range(len(segs) - 1):
+ if segs[i][-1] not in self._separators:
+ pos = text.find(segs[i])
+ pos_end = pos + len(segs[i])
+ if i > 0:
+ last_sentence_start = max([text.rfind(m, 0, pos) for m in ["。", "!", "?"]])
+ pos = last_sentence_start + 1
+ segs[i] = str(text[pos:pos_end])
+ if i != len(segs) - 1:
+ next_sentence_end = max([text.find(m, pos_end) for m in ["。", "!", "?"]])
+ segs[i] = str(text[pos : next_sentence_end + 1])
+ splitted.append(segs[i])
+ if len(splitted) <= 1:
+ return splitted
+ splitted_text = []
+ i = 1
+ if splitted[0] not in splitted[1]:
+ splitted_text.append([splitted[0], 0])
+ if splitted[-1] not in splitted[-2]:
+ splitted_text.append([splitted[-1], len(splitted) - 1])
+ while i < len(splitted) - 1:
+ if splitted[i] not in splitted[i + 1] and splitted[i] not in splitted[i - 1]:
+ splitted_text.append([splitted[i], i])
+ i += 1
+ splitted_text = sorted(splitted_text, key=lambda x: x[1])
+ splitted_text = [splitted_text[i][0] for i in range(len(splitted_text))]
+ ret = []
+ for s in splitted_text:
+ if s not in ret:
+ ret.append(s)
+ return ret
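+
+# A minimal usage sketch (illustrative only): with the defaults above
+# (chunk_size=50, chunk_overlap=10), chunks are re-aligned so that they end
+# on sentence-final punctuation (。!?) where possible.
+#
+# splitter = ChineseTextSplitter()
+# for chunk in splitter.split_text("第一句话。第二句话稍微长一点。第三句话!"):
+#     print(chunk)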
diff --git a/applications/ColossalQA/colossalqa/text_splitter/utils.py b/applications/ColossalQA/colossalqa/text_splitter/utils.py
new file mode 100644
index 000000000000..250b46d9742a
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/text_splitter/utils.py
@@ -0,0 +1,19 @@
+import re
+from typing import List
+
+
+def remove_format(text: str) -> str:
+ # if the count of \t, \n, \r, \v, \f is more than 3, keep the text unchanged; otherwise replace them with spaces
+ if len(re.findall(r"\s", text.replace(" ", ""))) > 3:
+ # likely a line of a table, keep as is
+ return text
+ return re.sub(r"\s", " ", text)
+
+
+# Remove redundant newlines and clean each line
+def get_cleaned_paragraph(s: str) -> List[str]:
+ text = str(s)
+ text = re.sub(r"\n{3,}", r"\n", text) # replace \n\n\n... with \n
+ text = re.sub("\n\n", "", text)
+ lines = text.split("\n")
+ lines_remove_format = [remove_format(line) for line in lines]
+ return lines_remove_format
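+
+# A quick illustration of the helpers above: blank lines are collapsed and
+# stray whitespace is replaced with spaces, while whitespace-heavy lines
+# (likely table rows) are kept verbatim.
+#
+# get_cleaned_paragraph("a\n\n\nb\tc")  # -> ["a", "b c"]
+# remove_format("a\tb\tc\td\te")        # returned unchanged: looks like a table row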
diff --git a/applications/ColossalQA/colossalqa/utils.py b/applications/ColossalQA/colossalqa/utils.py
new file mode 100644
index 000000000000..cd8c3e5acec8
--- /dev/null
+++ b/applications/ColossalQA/colossalqa/utils.py
@@ -0,0 +1,61 @@
+import re
+from typing import Union
+
+from colossalqa.mylogging import get_logger
+from sqlalchemy import Engine, MetaData, create_engine
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.ext.declarative import declarative_base
+
+logger = get_logger()
+
+
+def drop_table(engine: Engine) -> None:
+ """
+ Drop all existing tables
+ """
+ Base = declarative_base()
+ metadata = MetaData()
+ metadata.reflect(bind=engine)
+ for key in metadata.tables:
+ table = metadata.tables[key]
+ if table is not None:
+ Base.metadata.drop_all(engine, [table], checkfirst=True)
+
+
+def create_empty_sql_database(database_uri):
+ engine = None
+ try:
+ # Create an SQLAlchemy engine to connect to the database
+ engine = create_engine(database_uri)
+
+ # Connecting creates the database if it does not already exist
+ engine.connect()
+
+ logger.info(f"Database created at {database_uri}")
+ except SQLAlchemyError as e:
+ logger.error(f"Error creating database: {str(e)}")
+ # engine is None if creation failed
+ return engine, database_uri
+
+
+def destroy_sql_database(sql_engine: Union[Engine, str]) -> None:
+ """
+ Destroy an SQL database: drop all tables and dispose of the engine
+ """
+ if isinstance(sql_engine, str):
+ sql_engine = create_engine(sql_engine)
+ drop_table(sql_engine)
+ sql_engine.dispose()
+ sql_engine = None
+
+
+def detect_lang_naive(s):
+ """
+ Naive function for language detection; should be replaced by a dedicated language-detection layer
+ """
+ remove_nota = "[’·°–!\"#$%&'()*+,-./:;<=>?@,。?★、…【】()《》?“”‘’![\\]^_`{|}~]+"
+ s = re.sub(remove_nota, "", s)
+ s = re.sub("[0-9]", "", s).strip()
+ res = re.sub("[a-zA-Z]", "", s).strip()
+ if len(res) <= 0:
+ return "en"
+ else:
+ return "zh"
diff --git a/applications/ColossalQA/data/data_sample/companies.txt b/applications/ColossalQA/data/data_sample/companies.txt
new file mode 100644
index 000000000000..05c6148f18a5
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/companies.txt
@@ -0,0 +1,6 @@
+Overview The Straits Times is the English flagship daily of SPH Media, one of the leading media companies in Asia. Launched on July 15, 1845, its comprehensive coverage of news from home and around the world makes The Straits Times the most-read newspaper in Singapore. Quality news, in-depth analyses, impactful commentaries and breaking stories are packaged to give readers riveting accounts of events in Singapore, the region, and beyond. The most read newspaper in Singapore, both in terms of print and digital, it reaches 1.33 million people every day. The Straits Times' key strength is in its world class coverage of news outside Singapore. With 20 bureaus in major cities around the world, The Straits Times correspondents bring world news to readers on a Singapore platter, helping readers to appreciate world events from a Singaporean perspective. Website http://www.straitstimes.com Phone 63196319Phone number is 63196319 Industry Newspaper Publishing Company size 1,001-5,000 employees 183 on LinkedIn Includes members with current employer listed as The Straits Times, including part-time roles. Headquarters Singapore, Singapore Founded 1845 Specialties News and Digital media
+About With over 500 properties worldwide, Marriott Hotels has reimagined hospitality to exceed the expectations of business, group, and leisure travelers.
+Marriott Hotels, Marriott’s flagship brand of quality-tier, full-service hotels and resorts, provides consistent, dependable and genuinely caring experiences to guests on their terms. Marriott is a brilliant host to guests who effortlessly blend life and work, and who are inspired by how modern travel enhances them both. Our hotels offer warm, professional service; sophisticated yet functional guest room design; lobby spaces that facilitate working, dining and socializing; restaurants and bars serving international cuisine prepared simply and from the freshest ingredients; meeting and event spaces and services that are gold standard; and expansive, 24-hour fitness facilities.
+Overview AERCO International, Inc. is a recognized leader in delivering cost-effective, condensing commercial boilers, high-efficiency water heaters across a variety of markets including education, lodging, government, office buildings, healthcare, industrial and multifamily housing. AERCO's system design approach provides customer-specific solutions that deliver superior building performance at a lower operating cost while assuring uptime reliability. When AERCO was founded in 1949, it introduced a revolutionary design for an indirect-fired water heater that heated water on demand, and without storage, at a controlled temperature. This innovation became today's standard for water heaters, maximizing the recovery of latent heat energy and significantly increasing operating efficiency. AERCO continued to innovate and in 1988, introduced the first condensing and fully modulating boiler and water heater to the commercial market. The modulating capability of these products, still unsurpassed more than 25 years later, matches the equipment's output to real-time heating demand, ensuring the units draw no more fuel to operate than is absolutely necessary. This not only saves precious energy, but also ensures money doesn't needlessly disappear "up the stack." AERCO differentiates itself through a solution-based model, leveraging decades of engineering experience and industry application expertise to understand each customer’s unique needs. By partnering directly with customers and end-users to understand their project-specific requirements, AERCO provides tailored application solutions that are comprised of original product technologies including high efficiency condensing products, compact footprints, high turndown ratios, unique fuel delivery, leading control systems and proprietary design elements that combine to deliver up to 99% efficiency. Website http://www.aerco.com Phone 845-580-8000Phone number is 845-580-8000 Industry Industrial Machinery Manufacturing Company size 51-200 employees 119 on LinkedIn Includes members with current employer listed as AERCO International, Inc., including part-time roles. Headquarters Blauvelt, NY Founded 1949 Specialties Leading manufacturer of condensing boilers, water heating and energy recovery products and The originator of semi-instantaneous water heating
+Prince PLC: Overview We are a global leader of quality water solutions for residential, industrial, municipal, and commercial settings. Our family of brands offers one of the most varied product lines in the world, with world-class, water-related solutions focused on: • Plumbing & Flow Control • Water Quality & Conditioning • Water Reuse & Drainage • HVAC • Municipal Waterworks Strategic Goals Watts Water is traded on the New York Stock Exchange under the symbol “WTS.” As a public company, growing shareholder value is critical. To that end, we focus on a five-part Global Strategy: Growth, Commercial Excellence, Operational Excellence, “One Watts Water,” and a Talent & Performance Culture. Follow us on all social media platforms @WattsWater Website http://www.watts.com/ Industry Wholesale Building Materials Company size 5,001-10,000 employees 2,248 on LinkedIn Includes members with current employer listed as Watts Water Technologies, including part-time roles. Headquarters North Andover, MA Specialties Plumbing, HVAC, Water Quality, Gas, Conditioning, Waterworks, and Drainage
+About Courtyard Hotels is Marriott International’s largest hotel brand, with more than 1,100 hotels in over 50 countries worldwide. So, no matter where passion takes you, you’ll find us there to help you follow it. Proud members of Marriott Bonvoy.
\ No newline at end of file
diff --git a/applications/ColossalQA/data/data_sample/companies_zh.txt b/applications/ColossalQA/data/data_sample/companies_zh.txt
new file mode 100644
index 000000000000..a67a93590ee8
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/companies_zh.txt
@@ -0,0 +1,6 @@
+《海峡时报》是SPH传媒旗下的英文旗舰日报,SPH传媒是亚洲领先的传媒公司之一。《海峡时报》创刊于1845年7月15日,全面报道国内外新闻,是新加坡发行量最大的报纸。高质量的新闻、深入的分析、有影响力的评论和突发事件,为读者提供新加坡、该地区乃至其他地区的引人入胜的事件报道。无论是纸媒还是电子版,它都是新加坡阅读量最大的报纸,每天有133万人阅读。《海峡时报》的主要优势在于它对新加坡以外新闻的世界级报道。《海峡时报》记者在全球主要城市设有20个分社,用新加坡的盘子把世界新闻带给读者,帮助读者从新加坡的角度了解世界大事。网站http://www.straitstimes.com电话63196319电话63196319工业报纸出版公司规模1,001-5,000员工LinkedIn 183包括目前雇主为海峡时报的成员,包括兼职工作。总部位于新加坡,新加坡成立于1845年,专业从事新闻和数字媒体
+万豪酒店在全球拥有500多家酒店,以超越商务、团体和休闲旅客的期望,重塑酒店服务。
+万豪酒店(Marriott Hotels)是万豪旗下优质、全方位服务酒店和度假村的旗舰品牌,为客人提供始终如一、可靠和真诚关怀的体验。万豪是一个出色的主人,客人可以轻松地将生活和工作融合在一起,并受到现代旅行如何增强两者的启发。我们的酒店提供热情、专业的服务;精致而实用的客房设计;大堂空间,方便工作、餐饮和社交;餐厅和酒吧提供简单的国际美食和最新鲜的食材;会议及活动场地及服务均属黄金标准;还有宽敞的24小时健身设施。
+AERCO International, Inc.是公认的领导者,为教育、住宿、政府、办公楼、医疗保健、工业和多户住宅等各种市场提供具有成本效益的冷凝商用锅炉和高效热水器。AERCO的系统设计方法为客户提供特定的解决方案,以较低的运营成本提供卓越的建筑性能,同时确保正常运行时间的可靠性。AERCO成立于1949年,它推出了一种革命性的设计,用于间接燃烧热水器,在控制温度下按需加热水,而无需储存。这一创新成为当今热水器的标准,最大限度地回收潜热能量,显著提高运行效率。AERCO不断创新,并于1988年向商业市场推出了第一台冷凝和全调制锅炉和热水器。这些产品的调制能力,在超过25年后仍然无与伦比,使设备的输出与实时加热需求相匹配,确保机组不会消耗更多的燃料来运行,除非绝对必要。这不仅节省了宝贵的能源,还确保了钱不会不必要地消失在“堆栈”上。AERCO通过基于解决方案的模式脱颖而出,利用数十年的工程经验和行业应用专业知识来了解每个客户的独特需求。通过与客户和最终用户直接合作,了解他们的项目具体要求,AERCO提供量身定制的应用解决方案,这些解决方案由原创产品技术组成,包括高效冷凝产品,紧凑的足迹,高降压比,独特的燃料输送,领先的控制系统和专有设计元素,结合起来可提供高达99%的效率。网址http://www.aerco.com电话845-580- 8000电话号码845-580-8000工业工业机械制造公司规模51-200名员工LinkedIn上包括当前雇主AERCO International, Inc的成员,包括兼职职位。总部成立于1949年,纽约州布劳维尔特,专长:冷凝锅炉,水加热和能源回收产品的领先制造商,半瞬时水加热的鼻祖
+Prince PLC:概述Prince PLC是为住宅、工业、市政和商业环境提供优质水解决方案的全球领导者。我们的品牌家族提供世界上最多样化的产品线之一,拥有世界级的水相关解决方案,专注于:•管道和流量控制•水质和调理•水再利用和排水•hvac•市政水务战略目标瓦茨水务在纽约证券交易所上市,代码为“WTS”。作为一家上市公司,股东价值的增长至关重要。为此,我们将重点放在五部分全球战略上:增长、卓越商业、卓越运营、“一瓦茨水”以及人才与绩效文化。在所有社交媒体平台关注我们@WattsWater网站http://www.watts.com/行业批发建材公司规模5,001-10,000名员工领英2,248名包括目前雇主为WattsWater Technologies的成员,包括兼职职位。总部北安多弗,MA专业管道,暖通空调,水质,气体,空调,自来水厂和排水
+万怡酒店是万豪国际最大的酒店品牌,在全球50多个国家拥有1100多家酒店。所以,无论你的激情带你去哪里,你都会发现我们会帮助你追随它。万豪酒店的骄傲会员。
\ No newline at end of file
diff --git a/applications/ColossalQA/data/data_sample/csv_organization_100.csv b/applications/ColossalQA/data/data_sample/csv_organization_100.csv
new file mode 100644
index 000000000000..dbe97d5fd774
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/csv_organization_100.csv
@@ -0,0 +1,101 @@
+Index,Organization Id,Company Name,Website,Country,Description,Founded,Industry,Number of employees
+1,FAB0d41d5b5d22c,Ferrell LLC,https://price.net/,Papua New Guinea,Horizontal empowering knowledgebase,1990,Plastics,3498
+2,6A7EdDEA9FaDC52,"Mckinney, Riley and Day",http://www.hall-buchanan.info/,Finland,User-centric system-worthy leverage,2015,Glass / Ceramics / Concrete,4952
+3,0bFED1ADAE4bcC1,Hester Ltd,http://sullivan-reed.com/,China,Switchable scalable moratorium,1971,Public Safety,5287
+4,2bFC1Be8a4ce42f,Holder-Sellers,https://becker.com/,Turkmenistan,De-engineered systemic artificial intelligence,2004,Automotive,921
+5,9eE8A6a4Eb96C24,Mayer Group,http://www.brewer.com/,Mauritius,Synchronized needs-based challenge,1991,Transportation,7870
+6,cC757116fe1C085,Henry-Thompson,http://morse.net/,Bahamas,Face-to-face well-modulated customer loyalty,1992,Primary / Secondary Education,4914
+7,219233e8aFF1BC3,Hansen-Everett,https://www.kidd.org/,Pakistan,Seamless disintermediate collaboration,2018,Publishing Industry,7832
+8,ccc93DCF81a31CD,Mcintosh-Mora,https://www.brooks.com/,Heard Island and McDonald Islands,Centralized attitude-oriented capability,1970,Import / Export,4389
+9,0B4F93aA06ED03e,Carr Inc,http://ross.com/,Kuwait,Distributed impactful customer loyalty,1996,Plastics,8167
+10,738b5aDe6B1C6A5,Gaines Inc,http://sandoval-hooper.com/,Uzbekistan,Multi-lateral scalable protocol,1997,Outsourcing / Offshoring,9698
+11,AE61b8Ffebbc476,Kidd Group,http://www.lyons.com/,Bouvet Island (Bouvetoya),Proactive foreground paradigm,2001,Primary / Secondary Education,7473
+12,eb3B7D06cCdD609,Crane-Clarke,https://www.sandoval.com/,Denmark,Front-line clear-thinking encryption,2014,Food / Beverages,9011
+13,8D0c29189C9798B,"Keller, Campos and Black",https://www.garner.info/,Liberia,Ameliorated directional emulation,2020,Museums / Institutions,2862
+14,D2c91cc03CA394c,Glover-Pope,http://www.silva.biz/,United Arab Emirates,Persevering contextually-based approach,2013,Medical Practice,9079
+15,C8AC1eaf9C036F4,Pacheco-Spears,https://aguilar.com/,Sweden,Secured logistical synergy,1984,Maritime,769
+16,b5D10A14f7a8AfE,Hodge-Ayers,http://www.archer-elliott.com/,Honduras,Future-proofed radical implementation,1990,Facilities Services,8508
+17,68139b5C4De03B4,"Bowers, Guerra and Krause",http://www.carrillo-nicholson.com/,Uganda,De-engineered transitional strategy,1972,Primary / Secondary Education,6986
+18,5c2EffEfdba2BdF,Mckenzie-Melton,http://montoya-thompson.com/,Hong Kong,Reverse-engineered heuristic alliance,1998,Investment Management / Hedge Fund / Private Equity,4589
+19,ba179F19F7925f5,Branch-Mann,http://www.lozano.com/,Botswana,Adaptive intangible frame,1999,Architecture / Planning,7961
+20,c1Ce9B350BAc66b,Weiss and Sons,https://barrett.com/,Korea,Sharable optimal functionalities,2011,Plastics,5984
+21,8de40AC4e6EaCa4,"Velez, Payne and Coffey",http://burton.com/,Luxembourg,Mandatory coherent synergy,1986,Wholesale,5010
+22,Aad86a4F0385F2d,Harrell LLC,http://www.frey-rosario.com/,Guadeloupe,Reverse-engineered mission-critical moratorium,2018,Construction,2185
+23,22aC3FFd64fD703,"Eaton, Reynolds and Vargas",http://www.freeman.biz/,Monaco,Self-enabling multi-tasking process improvement,2014,Luxury Goods / Jewelry,8987
+24,5Ec4C272bCf085c,Robbins-Cummings,http://donaldson-wilkins.com/,Belgium,Organic non-volatile hierarchy,1991,Pharmaceuticals,5038
+25,5fDBeA8BB91a000,Jenkins Inc,http://www.kirk.biz/,South Africa,Front-line systematic help-desk,2002,Insurance,1215
+26,dFfD6a6F9AC2d9C,"Greene, Benjamin and Novak",http://www.kent.net/,Romania,Centralized leadingedge moratorium,2012,Museums / Institutions,4941
+27,4B217cC5a0674C5,"Dickson, Richmond and Clay",http://everett.com/,Czech Republic,Team-oriented tangible complexity,1980,Real Estate / Mortgage,3122
+28,88b1f1cDcf59a37,Prince-David,http://thompson.com/,Christmas Island,Virtual holistic methodology,1970,Banking / Mortgage,1046
+29,f9F7bBCAEeC360F,Ayala LLC,http://www.zhang.com/,Philippines,Open-source zero administration hierarchy,2021,Legal Services,7664
+30,7Cb3AeFcE4Ba31e,Rivas Group,https://hebert.org/,Australia,Open-architected well-modulated capacity,1998,Logistics / Procurement,4155
+31,ccBcC32adcbc530,"Sloan, Mays and Whitehead",http://lawson.com/,Chad,Face-to-face high-level conglomeration,1997,Civil Engineering,365
+32,f5afd686b3d05F5,"Durham, Allen and Barnes",http://chan-stafford.org/,Zimbabwe,Synergistic web-enabled framework,1993,Mechanical or Industrial Engineering,6135
+33,38C6cfC5074Fa5e,Fritz-Franklin,http://www.lambert.com/,Nepal,Automated 4thgeneration website,1972,Hospitality,4516
+34,5Cd7efccCcba38f,Burch-Ewing,http://cline.net/,Taiwan,User-centric 4thgeneration system engine,1981,Venture Capital / VC,7443
+35,9E6Acb51e3F9d6F,"Glass, Barrera and Turner",https://dunlap.com/,Kyrgyz Republic,Multi-channeled 3rdgeneration open system,2020,Utilities,2610
+36,4D4d7E18321eaeC,Pineda-Cox,http://aguilar.org/,Bolivia,Fundamental asynchronous capability,2010,Human Resources / HR,1312
+37,485f5d06B938F2b,"Baker, Mccann and Macdonald",http://www.anderson-barker.com/,Kenya,Cross-group user-facing focus group,2013,Legislative Office,1638
+38,19E3a5Bf6dBDc4F,Cuevas-Moss,https://dodson-castaneda.net/,Guatemala,Extended human-resource intranet,1994,Music,9995
+39,6883A965c7b68F7,Hahn PLC,http://newman.com/,Belarus,Organic logistical leverage,2012,Electrical / Electronic Manufacturing,3715
+40,AC5B7AA74Aa4A2E,"Valentine, Ferguson and Kramer",http://stuart.net/,Jersey,Centralized secondary time-frame,1997,Non - Profit / Volunteering,3585
+41,decab0D5027CA6a,Arroyo Inc,https://www.turner.com/,Grenada,Managed demand-driven website,2006,Writing / Editing,9067
+42,dF084FbBb613eea,Walls LLC,http://www.reese-vasquez.biz/,Cape Verde,Self-enabling fresh-thinking installation,1989,Investment Management / Hedge Fund / Private Equity,1678
+43,A2D89Ab9bCcAd4e,"Mitchell, Warren and Schneider",https://fox.biz/,Trinidad and Tobago,Enhanced intangible time-frame,2021,Capital Markets / Hedge Fund / Private Equity,3816
+44,77aDc905434a49f,Prince PLC,https://www.watts.com/,Sweden,Profit-focused coherent installation,2016,Individual / Family Services,7645
+45,235fdEFE2cfDa5F,Brock-Blackwell,http://www.small.com/,Benin,Secured foreground emulation,1986,Online Publishing,7034
+46,1eD64cFe986BBbE,Walton-Barnett,https://ashley-schaefer.com/,Western Sahara,Right-sized clear-thinking flexibility,2001,Luxury Goods / Jewelry,1746
+47,CbBbFcdd0eaE2cF,Bartlett-Arroyo,https://cruz.com/,Northern Mariana Islands,Realigned didactic function,1976,Civic / Social Organization,3987
+48,49aECbDaE6aBD53,"Wallace, Madden and Morris",http://www.blevins-fernandez.biz/,Germany,Persistent real-time customer loyalty,2016,Pharmaceuticals,9443
+49,7b3fe6e7E72bFa4,Berg-Sparks,https://cisneros-love.com/,Canada,Stand-alone static implementation,1974,Arts / Crafts,2073
+50,c6DedA82A8aef7E,Gonzales Ltd,http://bird.com/,Tonga,Managed human-resource policy,1988,Consumer Goods,9069
+51,7D9FBF85cdC3871,Lawson and Sons,https://www.wong.com/,French Southern Territories,Compatible analyzing intranet,2021,Arts / Crafts,3527
+52,7dd18Fb7cB07b65,"Mcguire, Mcconnell and Olsen",https://melton-briggs.com/,Korea,Profound client-server frame,1988,Printing,8445
+53,EF5B55FadccB8Fe,Charles-Phillips,https://bowman.com/,Cote d'Ivoire,Monitored client-server implementation,2012,Mental Health Care,3450
+54,f8D4B99e11fAF5D,Odom Ltd,https://www.humphrey-hess.com/,Cote d'Ivoire,Advanced static process improvement,2012,Management Consulting,1825
+55,e24D21BFd3bF1E5,Richard PLC,https://holden-coleman.net/,Mayotte,Object-based optimizing model,1971,Broadcast Media,4942
+56,B9BdfEB6D3Ca44E,Sampson Ltd,https://blevins.com/,Cayman Islands,Intuitive local adapter,2005,Farming,1418
+57,2a74D6f3D3B268e,"Cherry, Le and Callahan",https://waller-delacruz.biz/,Nigeria,Universal human-resource collaboration,2017,Entertainment / Movie Production,7202
+58,Bf3F3f62c8aBC33,Cherry PLC,https://www.avila.info/,Marshall Islands,Persistent tertiary website,1980,Plastics,8245
+59,aeBe26B80a7a23c,Melton-Nichols,https://kennedy.com/,Palau,User-friendly clear-thinking productivity,2021,Legislative Office,8741
+60,aAeb29ad43886C6,Potter-Walsh,http://thomas-french.org/,Turkey,Optional non-volatile open system,2008,Human Resources / HR,6923
+61,bD1bc6bB6d1FeD3,Freeman-Chen,https://mathis.com/,Timor-Leste,Phased next generation adapter,1973,International Trade / Development,346
+62,EB9f456e8b7022a,Soto Group,https://norris.info/,Vietnam,Enterprise-wide executive installation,1988,Business Supplies / Equipment,9097
+63,Dfef38C51D8DAe3,"Poole, Cruz and Whitney",https://reed.info/,Reunion,Balanced analyzing groupware,1978,Marketing / Advertising / Sales,2992
+64,055ffEfB2Dd95B0,Riley Ltd,http://wiley.com/,Brazil,Optional exuding superstructure,1986,Textiles,9315
+65,cBfe4dbAE1699da,"Erickson, Andrews and Bailey",https://www.hobbs-grant.com/,Eritrea,Vision-oriented secondary project,2014,Consumer Electronics,7829
+66,fdFbecbadcdCdf1,"Wilkinson, Charles and Arroyo",http://hunter-mcfarland.com/,United States Virgin Islands,Assimilated 24/7 archive,1996,Building Materials,602
+67,5DCb8A5a5ca03c0,Floyd Ltd,http://www.whitney.com/,Falkland Islands (Malvinas),Function-based fault-tolerant concept,2017,Public Relations / PR,2911
+68,ce57DCbcFD6d618,Newman-Galloway,https://www.scott.com/,Luxembourg,Enhanced foreground collaboration,1987,Information Technology / IT,3934
+69,5aaD187dc929371,Frazier-Butler,https://www.daugherty-farley.info/,Northern Mariana Islands,Persistent interactive circuit,1972,Outsourcing / Offshoring,5130
+70,902D7Ac8b6d476b,Newton Inc,https://www.richmond-manning.info/,Netherlands Antilles,Fundamental stable info-mediaries,1976,Military Industry,563
+71,32BB9Ff4d939788,Duffy-Levy,https://www.potter.com/,Guernsey,Diverse exuding installation,1982,Wireless,6146
+72,adcB0afbE58bAe3,Wagner LLC,https://decker-esparza.com/,Uruguay,Reactive attitude-oriented toolset,1987,International Affairs,6874
+73,dfcA1c84AdB61Ac,Mccall-Holmes,http://www.dean.com/,Benin,Object-based value-added database,2009,Legal Services,696
+74,208044AC2fe52F3,Massey LLC,https://frazier.biz/,Suriname,Configurable zero administration Graphical User Interface,1986,Accounting,5004
+75,f3C365f0c1A0623,Hicks LLC,http://alvarez.biz/,Pakistan,Quality-focused client-server Graphical User Interface,1970,Computer Software / Engineering,8480
+76,ec5Bdd3CBAfaB93,"Cole, Russell and Avery",http://www.blankenship.com/,Mongolia,De-engineered fault-tolerant challenge,2000,Law Enforcement,7012
+77,DDB19Be7eeB56B4,Cummings-Rojas,https://simon-pearson.com/,Svalbard & Jan Mayen Islands,User-centric modular customer loyalty,2012,Financial Services,7529
+78,dd6CA3d0bc3cAfc,"Beasley, Greene and Mahoney",http://www.petersen-lawrence.com/,Togo,Extended content-based methodology,1976,Religious Institutions,869
+79,A0B9d56e61070e3,"Beasley, Sims and Allison",http://burke.info/,Latvia,Secured zero tolerance hub,1972,Facilities Services,6182
+80,cBa7EFe5D05Adaf,Crawford-Rivera,https://black-ramirez.org/,Cuba,Persevering exuding budgetary management,1999,Online Publishing,7805
+81,Ea3f6D52Ec73563,Montes-Hensley,https://krueger.org/,Liechtenstein,Multi-tiered secondary productivity,2009,Printing,8433
+82,bC0CEd48A8000E0,Velazquez-Odom,https://stokes.com/,Djibouti,Streamlined 6thgeneration function,2002,Alternative Dispute Resolution,4044
+83,c89b9b59BC4baa1,Eaton-Morales,https://www.reeves-graham.com/,Micronesia,Customer-focused explicit frame,1990,Capital Markets / Hedge Fund / Private Equity,7013
+84,FEC51bce8421a7b,"Roberson, Pennington and Palmer",http://www.keith-fisher.com/,Cameroon,Adaptive bi-directional hierarchy,1993,Telecommunications,5571
+85,e0E8e27eAc9CAd5,"George, Russo and Guerra",https://drake.com/,Sweden,Centralized non-volatile capability,1989,Military Industry,2880
+86,B97a6CF9bf5983C,Davila Inc,https://mcconnell.info/,Cocos (Keeling) Islands,Profit-focused dedicated frame,2017,Consumer Electronics,2215
+87,a0a6f9b3DbcBEb5,Mays-Preston,http://www.browning-key.com/,Mali,User-centric heuristic focus group,2006,Military Industry,5786
+88,8cC1bDa330a5871,Pineda-Morton,https://www.carr.com/,United States Virgin Islands,Grass-roots methodical info-mediaries,1991,Printing,6168
+89,ED889CB2FE9cbd3,Huang and Sons,https://www.bolton.com/,Eritrea,Re-contextualized dynamic hierarchy,1981,Semiconductors,7484
+90,F4Dc1417BC6cb8f,Gilbert-Simon,https://www.bradford.biz/,Burundi,Grass-roots radical parallelism,1973,Newspapers / Journalism,1927
+91,7ABc3c7ecA03B34,Sampson-Griffith,http://hendricks.org/,Benin,Multi-layered composite paradigm,1972,Textiles,3881
+92,4e0719FBE38e0aB,Miles-Dominguez,http://www.turner.com/,Gibraltar,Organized empowering forecast,1996,Civic / Social Organization,897
+93,dEbDAAeDfaed00A,Rowe and Sons,https://www.simpson.org/,El Salvador,Balanced multimedia knowledgebase,1978,Facilities Services,8172
+94,61BDeCfeFD0cEF5,"Valenzuela, Holmes and Rowland",https://www.dorsey.net/,Taiwan,Persistent tertiary focus group,1999,Transportation,1483
+95,4e91eD25f486110,"Best, Wade and Shepard",https://zimmerman.com/,Zimbabwe,Innovative background definition,1991,Gambling / Casinos,4873
+96,0a0bfFbBbB8eC7c,Holmes Group,https://mcdowell.org/,Ethiopia,Right-sized zero tolerance focus group,1975,Photography,2988
+97,BA6Cd9Dae2Efd62,Good Ltd,http://duffy.com/,Anguilla,Reverse-engineered composite moratorium,1971,Consumer Services,4292
+98,E7df80C60Abd7f9,Clements-Espinoza,http://www.flowers.net/,Falkland Islands (Malvinas),Progressive modular hub,1991,Broadcast Media,236
+99,AFc285dbE2fEd24,Mendez Inc,https://www.burke.net/,Kyrgyz Republic,User-friendly exuding migration,1993,Education Management,339
+100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785
\ No newline at end of file
diff --git a/applications/ColossalQA/data/data_sample/custom_service.json b/applications/ColossalQA/data/data_sample/custom_service.json
new file mode 100644
index 000000000000..33dbdf705707
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/custom_service.json
@@ -0,0 +1 @@
+{"features":[{"feature_idx":0,"name":"query","type":{"dtype":"string","_type":"Value"}},{"feature_idx":1,"name":"response","type":{"dtype":"string","_type":"Value"}}],"rows":[{"row_idx":0,"row":{"query":"My order hasn't arrived yet.","response":"We apologize for the inconvenience. Can you please provide your order number so we can investigate?"},"truncated_cells":[]},{"row_idx":1,"row":{"query":"I received a damaged product.","response":"We apologize for the inconvenience. Can you please provide a photo of the damaged product so we can assist you further?"},"truncated_cells":[]},{"row_idx":2,"row":{"query":"I need to return an item.","response":"Certainly. Please provide your order number and reason for return, and we will provide you with instructions on how to proceed."},"truncated_cells":[]},{"row_idx":3,"row":{"query":"I want to change my shipping address.","response":"No problem. Can you please provide your order number and the new shipping address you'd like to use?"},"truncated_cells":[]},{"row_idx":4,"row":{"query":"I have a question about my bill.","response":"We'd be happy to help. Can you please provide your account number and a brief description of your question?"},"truncated_cells":[]},{"row_idx":5,"row":{"query":"How do I cancel my subscription?","response":"We're sorry to hear that. Can you please provide your account email and the name of the subscription you'd like to cancel?"},"truncated_cells":[]},{"row_idx":6,"row":{"query":"Can I get a refund for my purchase?","response":"We understand your concern. Please provide your order number and the reason for your refund request so we can assist you further."},"truncated_cells":[]},{"row_idx":7,"row":{"query":"I'd like to track my order.","response":"Of course. Can you please provide your order number so we can check the current status for you?"},"truncated_cells":[]},{"row_idx":8,"row":{"query":"My account has been locked.","response":"We apologize for the inconvenience. Can you please provide your account email so we can help you regain access?"},"truncated_cells":[]},{"row_idx":9,"row":{"query":"I can't find the item I'm looking for.","response":"We're here to help. Can you please provide a description or product name of the item you're looking for so we can assist you?"},"truncated_cells":[]},{"row_idx":10,"row":{"query":"I'm having trouble applying a promo code.","response":"We apologize for the inconvenience. Can you please provide the promo code you're trying to use, and we'll help you resolve the issue?"},"truncated_cells":[]},{"row_idx":11,"row":{"query":"How do I update my payment information?","response":"We can help with that. Can you please provide your account email so we can guide you through updating your payment information?"},"truncated_cells":[]},{"row_idx":12,"row":{"query":"I haven't received a response to my email inquiry.","response":"We apologize for the delay. Can you please provide your ticket number or account email so we can follow up on your inquiry?"},"truncated_cells":[]},{"row_idx":13,"row":{"query":"Can I change the size of an item in my order?","response":"Certainly. Can you please provide your order number and the details of the item you'd like to change?"},"truncated_cells":[]},{"row_idx":14,"row":{"query":"What is the status of my warranty claim?","response":"We'd be happy to check for you. 
Can you please provide your claim number or the product's serial number?"},"truncated_cells":[]},{"row_idx":15,"row":{"query":"Is my item in stock?","response":"To check availability, can you please provide the product name or SKU so we can assist you?"},"truncated_cells":[]},{"row_idx":16,"row":{"query":"My promo code has expired. Can I still use it?","response":"Unfortunately, expired promo codes cannot be used. However, please provide the code, and we'll see if there's a similar offer currently available."},"truncated_cells":[]},{"row_idx":17,"row":{"query":"How do I reset my password?","response":"We can help with that. Please provide your account email, and we'll send you instructions on how to reset your password."},"truncated_cells":[]},{"row_idx":18,"row":{"query":"Can I place a bulk order?","response":"Certainly. Can you please provide the product name or SKU, and the quantity you'd like to order so we can check availability and pricing?"},"truncated_cells":[]},{"row_idx":19,"row":{"query":"How do I apply for a job at your company?","response":"Thank you for your interest. Can you please provide your email address so we can send you information on how to apply?"},"truncated_cells":[]},{"row_idx":20,"row":{"query":"Can I add an item to an existing order?","response":"We'll do our best to help. Can you please provide your order number and the details of the item you'd like to add?"},"truncated_cells":[]},{"row_idx":21,"row":{"query":"How long does shipping take?","response":"We'd be happy to provide an estimate. Can you please provide your shipping destination and the product name or SKU?"},"truncated_cells":[]},{"row_idx":22,"row":{"query":"I was charged twice for my order.","response":"We apologize for the inconvenience. Can you please provide your order number so we can investigate and resolve the issue?"},"truncated_cells":[]},{"row_idx":23,"row":{"query":"Where can I find your sizing chart?","response":"We'd be happy to help. Can you please provide the product name or SKU so we can direct you to the appropriate sizing chart?"},"truncated_cells":[]},{"row_idx":24,"row":{"query":"I need technical support for a product.","response":"We're here to help. Can you please provide the product name, SKU, or serial number, and a description of the issue you're experiencing?"},"truncated_cells":[]},{"row_idx":25,"row":{"query":"Can I exchange an item for a different color?","response":"Certainly. Can you please provide your order number and the details of the item you'd like to exchange?"},"truncated_cells":[]},{"row_idx":26,"row":{"query":"How do I apply for a store credit card?","response":"Thank you for your interest. Can you please provide your email address so we can send you information on how to apply?"},"truncated_cells":[]},{"row_idx":27,"row":{"query":"Where is your store located?","response":"We'd be happy to help you find a store near you. Can you please provide your city and state or zip code?"},"truncated_cells":[]},{"row_idx":28,"row":{"query":"I can't log into my account.","response":"We apologize for the inconvenience. Can you please provide your account email so we can help you troubleshoot the issue?"},"truncated_cells":[]},{"row_idx":29,"row":{"query":"Do you offer gift wrapping?","response":"We do offer gift wrapping for select items. Can you please provide the product name or SKU so we can confirm if gift wrapping is available?"},"truncated_cells":[]},{"row_idx":30,"row":{"query":"Can I schedule a delivery?","response":"Certainly. 
Can you please provide your order number and the preferred date and time for delivery?"},"truncated_cells":[]},{"row_idx":31,"row":{"query":"How do I unsubscribe from your newsletter?","response":"We're sorry to see you go. Can you please provide your email address so we can remove you from our mailing list?"},"truncated_cells":[]},{"row_idx":32,"row":{"query":"What is your return policy?","response":"We'd be happy to provide information on our return policy. Can you please provide the product name or SKU so we can give you the most accurate information?"},"truncated_cells":[]},{"row_idx":33,"row":{"query":"Do you offer price matching?","response":"We do offer price matching in certain cases. Can you please provide the product name or SKU and the competitor's pricing information?"},"truncated_cells":[]},{"row_idx":34,"row":{"query":"Are there any current promotions or sales?","response":"We'd be happy to inform you of any current promotions. Can you please provide the product name or SKU you're interested in, or the type of promotion you're looking for?"},"truncated_cells":[]},{"row_idx":35,"row":{"query":"My coupon isn't working.","response":"We apologize for the inconvenience. Can you please provide the coupon code and the product name or SKU you're trying to apply it to?"},"truncated_cells":[]},{"row_idx":36,"row":{"query":"Can you help me with a product recommendation?","response":"Of course! Can you please provide some details about what you're looking for, such as product type, features, or price range?"},"truncated_cells":[]},{"row_idx":37,"row":{"query":"Do you offer international shipping?","response":"We'd be happy to help. Can you please provide your shipping destination and the product name or SKU you're interested in?"},"truncated_cells":[]},{"row_idx":38,"row":{"query":"Can I place an order over the phone?","response":"Certainly. Can you please provide your phone number and the best time to reach you, and one of our representatives will contact you to place the order?"},"truncated_cells":[]},{"row_idx":39,"row":{"query":"How do I know if my order went through?","response":"We can help you with that. Can you please provide your account email or order number so we can check the status for you?"},"truncated_cells":[]},{"row_idx":40,"row":{"query":"Can I use multiple promo codes on one order?","response":"In most cases, only one promo code can be applied per order. Can you please provide the promo codes you're trying to use so we can check their compatibility?"},"truncated_cells":[]},{"row_idx":41,"row":{"query":"What are your customer service hours?","response":"Our customer service team is available 24/7 to assist you. Please feel free to reach out with any questions or concerns."},"truncated_cells":[]},{"row_idx":42,"row":{"query":"Where can I find the user manual for my product?","response":"We'd be happy to help. Can you please provide the product name, SKU, or serial number so we can direct you to the appropriate user manual?"},"truncated_cells":[]},{"row_idx":43,"row":{"query":"Do you offer a warranty on your products?","response":"We do offer warranties on select products. Can you please provide the product name or SKU so we can provide you with warranty information?"},"truncated_cells":[]},{"row_idx":44,"row":{"query":"Can I place an order for in-store pickup?","response":"Certainly. 
Can you please provide the product name or SKU and the store location where you'd like to pick up your order?"},"truncated_cells":[]},{"row_idx":45,"row":{"query":"How do I sign up for your rewards program?","response":"Thank you for your interest in our rewards program. Can you please provide your email address so we can send you information on how to sign up?"},"truncated_cells":[]},{"row_idx":46,"row":{"query":"Can I pay with a gift card online?","response":"Yes, you can use gift cards for online purchases. Can you please provide the gift card number and the product name or SKU you're interested in?"},"truncated_cells":[]},{"row_idx":47,"row":{"query":"I can't find my order confirmation email.","response":"We apologize for the inconvenience. Can you please provide your account email or order number so we can resend the confirmation email?"},"truncated_cells":[]},{"row_idx":48,"row":{"query":"Do you offer a military discount?","response":"Yes, we do offer a military discount. Can you please provide your military ID or email address so we can apply the discount to your account?"},"truncated_cells":[]},{"row_idx":49,"row":{"query":"What is the processing time for my order?","response":"We'd be happy to provide an estimate. Can you please provide your order number or the product name or SKU?"},"truncated_cells":[]},{"row_idx":50,"row":{"query":"How do I update my shipping preferences?","response":"We can help with that. Can you please provide your account email so we can guide you through updating your shipping preferences?"},"truncated_cells":[]},{"row_idx":51,"row":{"query":"Is my payment information secure?","response":"Yes, we take security very seriously. Can you please provide your account email so we can verify the security measures in place for your payment information?"},"truncated_cells":[]},{"row_idx":52,"row":{"query":"Can I pre-order an item?","response":"Certainly. Can you please provide the product name or SKU and your email address so we can notify you when pre-orders are available?"},"truncated_cells":[]},{"row_idx":53,"row":{"query":"How do I use a gift card in-store?","response":"To use a gift card in-store, simply present the gift card at the time of purchase. Can you please provide the gift card number so we can check the balance for you?"},"truncated_cells":[]},{"row_idx":54,"row":{"query":"Do you have a loyalty program?","response":"Yes, we do have a loyalty program. Can you please provide your email address so we can send you information on how to join and enjoy the benefits?"},"truncated_cells":[]},{"row_idx":55,"row":{"query":"Is there a mobile app for your store?","response":"Yes, we do have a mobile app. Can you please provide your email address so we can send you a link to download the app and instructions on how to use it?"},"truncated_cells":[]},{"row_idx":56,"row":{"query":"I need help assembling my product.","response":"We're here to help. Can you please provide the product name, SKU, or serial number, and a description of the issue you're experiencing during assembly?"},"truncated_cells":[]},{"row_idx":57,"row":{"query":"Do you offer financing options?","response":"We do offer financing options for select purchases. Can you please provide the product name or SKU and your email address so we can send you more information?"},"truncated_cells":[]},{"row_idx":58,"row":{"query":"Can I reserve an item in-store?","response":"Certainly. 
Can you please provide the product name or SKU and the store location where you'd like to reserve the item?"},"truncated_cells":[]},{"row_idx":59,"row":{"query":"How do I get a price adjustment for a recent purchase?","response":"We'd be happy to help. Can you please provide your order number and the product name or SKU for the item you'd like a price adjustment on?"},"truncated_cells":[]},{"row_idx":60,"row":{"query":"How do I change my email preferences?","response":"We can help with that. Can you please provide your account email so we can guide you through updating your email preferences?"},"truncated_cells":[]},{"row_idx":61,"row":{"query":"Can I use my store credit online?","response":"Yes, you can use store credit for online purchases. Can you please provide the store credit number and the product name or SKU you're interested in?"},"truncated_cells":[]},{"row_idx":62,"row":{"query":"What are the washing instructions for this item?","response":"We'd be happy to help. Can you please provide the product name or SKU so we can provide you with the proper washing instructions?"},"truncated_cells":[]},{"row_idx":63,"row":{"query":"Can I get a replacement part for my product?","response":"Certainly. Can you please provide the product name, SKU, or serial number, and a description of the part you need?"},"truncated_cells":[]},{"row_idx":64,"row":{"query":"Do you offer free shipping?","response":"We do offer free shipping on select orders. Can you please provide the product name or SKU and your shipping destination so we can check if your order qualifies?"},"truncated_cells":[]},{"row_idx":65,"row":{"query":"Can I place a custom order?","response":"We'd be happy to assist you. Can you please provide the product name or SKU and a description of the customizations you'd like?"},"truncated_cells":[]},{"row_idx":66,"row":{"query":"How do I report a problem with your website?","response":"We appreciate your feedback. Can you please provide a description of the issue you're experiencing and your email address so we can follow up with you?"},"truncated_cells":[]},{"row_idx":67,"row":{"query":"What is your policy on price adjustments?","response":"We'd be happy to provide information on our price adjustment policy. Can you please provide the product name or SKU so we can give you the most accurate information?"},"truncated_cells":[]},{"row_idx":68,"row":{"query":"Do you have any upcoming sales or events?","response":"We'd be happy to inform you of any upcoming sales or events. Can you please provide your email address so we can keep you updated?"},"truncated_cells":[]},{"row_idx":69,"row":{"query":"How do I schedule a consultation or appointment?","response":"We'd be happy to help. Can you please provide your name, phone number, and the service you're interested in so we can schedule your appointment?"},"truncated_cells":[]},{"row_idx":70,"row":{"query":"Can I get a copy of my receipt?","response":"Certainly. Can you please provide your order number or account email so we can locate your receipt and send you a copy?"},"truncated_cells":[]},{"row_idx":71,"row":{"query":"Can I use a competitor's coupon at your store?","response":"In some cases, we may accept competitor coupons. Can you please provide the competitor's coupon code and the product name or SKU you'd like to apply it to?"},"truncated_cells":[]},{"row_idx":72,"row":{"query":"Do you have a recycling program?","response":"Yes, we do have a recycling program. 
Can you please provide your email address so we can send you information on how to participate?"},"truncated_cells":[]},{"row_idx":73,"row":{"query":"How do I report a lost or stolen gift card?","response":"We're sorry to hear that. Can you please provide the gift card number, if available, and your email address so we can assist you further?"},"truncated_cells":[]}],"num_rows_total":74,"num_rows_per_page":100}
diff --git a/applications/ColossalQA/data/data_sample/custom_service_classification.json b/applications/ColossalQA/data/data_sample/custom_service_classification.json
new file mode 100644
index 000000000000..2d7e4b05d217
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/custom_service_classification.json
@@ -0,0 +1,64 @@
+{
+ "data": [
+ {
+ "key": "客户反映手机无法接收短信,但可以正常拨打电话,已确认手机号码正常,需要处理。",
+ "value": "故障原因分类: 短信接收问题"
+ },
+ {
+ "key": "客户申请开通国际漫游服务,但在目的地无法使用手机信号,已核实客户所在地国家为不支持漫游的区域,已通知客户。",
+ "value": "故障原因分类: 国际漫游服务"
+ },
+ {
+ "key": "客户称手机信号时强时弱,经过测试发现在不同区域信号确实存在波动,属于正常现象。",
+ "value": "故障原因分类: 信号强弱波动"
+ },
+ {
+ "key": "客户反映在家中无法连接Wi-Fi,建议检查路由器或尝试更换位置。",
+ "value": "故障原因分类: 家庭网络问题"
+ },
+ {
+ "key": "客户申请更换新的SIM卡,因旧卡损坏,已为客户办理新卡。",
+ "value": "故障原因分类: SIM卡更换"
+ },
+ {
+ "key": "客户反映通话时听不清对方声音,经检查发现是手机内置扬声器故障,建议维修。",
+ "value": "故障原因分类: 扬声器故障"
+ },
+ {
+ "key": "客户手机丢失,请求挂失并办理新卡,已为客户挂失旧卡并补办新卡。",
+ "value": "故障原因分类: 挂失与补办"
+ },
+ {
+ "key": "客户反映在市区内无法使用手机信号,经排查发现信号塔维护,属于暂时性故障。",
+ "value": "故障原因分类: 信号塔维护"
+ },
+ {
+ "key": "客户反映手机充电时出现过热情况,建议更换充电器。",
+ "value": "故障原因分类: 充电器故障"
+ },
+ {
+ "key": "客户要求关闭数据漫游功能,已为客户关闭。",
+ "value": "故障原因分类: 关闭数据漫游"
+ },
+ {
+ "key": "客户申请办理家庭套餐业务,已为客户办理。",
+ "value": "故障原因分类: 家庭套餐办理"
+ },
+ {
+ "key": "客户反映在商场内无法使用手机信号,建议检查手机信号设置。",
+ "value": "故障原因分类: 手机信号设置"
+ },
+ {
+ "key": "客户申请开通国际长途业务,已为客户办理。",
+ "value": "故障原因分类: 国际长途业务办理"
+ },
+ {
+ "key": "客户反映手机屏幕出现蓝屏,建议客户前往维修。",
+ "value": "故障原因分类: 手机屏幕故障"
+ },
+ {
+ "key": "客户申请办理免流量业务,已为客户办理。",
+ "value": "故障原因分类: 免流量业务办理"
+ }
+ ]
+}
diff --git a/applications/ColossalQA/data/data_sample/custom_service_preprocessed.json b/applications/ColossalQA/data/data_sample/custom_service_preprocessed.json
new file mode 100644
index 000000000000..f4e095f162af
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/custom_service_preprocessed.json
@@ -0,0 +1 @@
+{"data": [{"key": "My order hasn't arrived yet.", "value": "We apologize for the inconvenience. Can you please provide your order number so we can investigate?"}, {"key": "I received a damaged product.", "value": "We apologize for the inconvenience. Can you please provide a photo of the damaged product so we can assist you further?"}, {"key": "I need to return an item.", "value": "Certainly. Please provide your order number and reason for return, and we will provide you with instructions on how to proceed."}, {"key": "I want to change my shipping address.", "value": "No problem. Can you please provide your order number and the new shipping address you'd like to use?"}, {"key": "I have a question about my bill.", "value": "We'd be happy to help. Can you please provide your account number and a brief description of your question?"}, {"key": "How do I cancel my subscription?", "value": "We're sorry to hear that. Can you please provide your account email and the name of the subscription you'd like to cancel?"}, {"key": "Can I get a refund for my purchase?", "value": "We understand your concern. Please provide your order number and the reason for your refund request so we can assist you further."}, {"key": "I'd like to track my order.", "value": "Of course. Can you please provide your order number so we can check the current status for you?"}, {"key": "My account has been locked.", "value": "We apologize for the inconvenience. Can you please provide your account email so we can help you regain access?"}, {"key": "I can't find the item I'm looking for.", "value": "We're here to help. Can you please provide a description or product name of the item you're looking for so we can assist you?"}, {"key": "I'm having trouble applying a promo code.", "value": "We apologize for the inconvenience. Can you please provide the promo code you're trying to use, and we'll help you resolve the issue?"}, {"key": "How do I update my payment information?", "value": "We can help with that. Can you please provide your account email so we can guide you through updating your payment information?"}, {"key": "I haven't received a response to my email inquiry.", "value": "We apologize for the delay. Can you please provide your ticket number or account email so we can follow up on your inquiry?"}, {"key": "Can I change the size of an item in my order?", "value": "Certainly. Can you please provide your order number and the details of the item you'd like to change?"}, {"key": "What is the status of my warranty claim?", "value": "We'd be happy to check for you. Can you please provide your claim number or the product's serial number?"}, {"key": "Is my item in stock?", "value": "To check availability, can you please provide the product name or SKU so we can assist you?"}, {"key": "My promo code has expired. Can I still use it?", "value": "Unfortunately, expired promo codes cannot be used. However, please provide the code, and we'll see if there's a similar offer currently available."}, {"key": "How do I reset my password?", "value": "We can help with that. Please provide your account email, and we'll send you instructions on how to reset your password."}, {"key": "Can I place a bulk order?", "value": "Certainly. Can you please provide the product name or SKU, and the quantity you'd like to order so we can check availability and pricing?"}, {"key": "How do I apply for a job at your company?", "value": "Thank you for your interest. 
Can you please provide your email address so we can send you information on how to apply?"}, {"key": "Can I add an item to an existing order?", "value": "We'll do our best to help. Can you please provide your order number and the details of the item you'd like to add?"}, {"key": "How long does shipping take?", "value": "We'd be happy to provide an estimate. Can you please provide your shipping destination and the product name or SKU?"}, {"key": "I was charged twice for my order.", "value": "We apologize for the inconvenience. Can you please provide your order number so we can investigate and resolve the issue?"}, {"key": "Where can I find your sizing chart?", "value": "We'd be happy to help. Can you please provide the product name or SKU so we can direct you to the appropriate sizing chart?"}, {"key": "I need technical support for a product.", "value": "We're here to help. Can you please provide the product name, SKU, or serial number, and a description of the issue you're experiencing?"}, {"key": "Can I exchange an item for a different color?", "value": "Certainly. Can you please provide your order number and the details of the item you'd like to exchange?"}, {"key": "How do I apply for a store credit card?", "value": "Thank you for your interest. Can you please provide your email address so we can send you information on how to apply?"}, {"key": "Where is your store located?", "value": "We'd be happy to help you find a store near you. Can you please provide your city and state or zip code?"}, {"key": "I can't log into my account.", "value": "We apologize for the inconvenience. Can you please provide your account email so we can help you troubleshoot the issue?"}, {"key": "Do you offer gift wrapping?", "value": "We do offer gift wrapping for select items. Can you please provide the product name or SKU so we can confirm if gift wrapping is available?"}, {"key": "Can I schedule a delivery?", "value": "Certainly. Can you please provide your order number and the preferred date and time for delivery?"}, {"key": "How do I unsubscribe from your newsletter?", "value": "We're sorry to see you go. Can you please provide your email address so we can remove you from our mailing list?"}, {"key": "What is your return policy?", "value": "We'd be happy to provide information on our return policy. Can you please provide the product name or SKU so we can give you the most accurate information?"}, {"key": "Do you offer price matching?", "value": "We do offer price matching in certain cases. Can you please provide the product name or SKU and the competitor's pricing information?"}, {"key": "Are there any current promotions or sales?", "value": "We'd be happy to inform you of any current promotions. Can you please provide the product name or SKU you're interested in, or the type of promotion you're looking for?"}, {"key": "My coupon isn't working.", "value": "We apologize for the inconvenience. Can you please provide the coupon code and the product name or SKU you're trying to apply it to?"}, {"key": "Can you help me with a product recommendation?", "value": "Of course! Can you please provide some details about what you're looking for, such as product type, features, or price range?"}, {"key": "Do you offer international shipping?", "value": "We'd be happy to help. Can you please provide your shipping destination and the product name or SKU you're interested in?"}, {"key": "Can I place an order over the phone?", "value": "Certainly. 
Can you please provide your phone number and the best time to reach you, and one of our representatives will contact you to place the order?"}, {"key": "How do I know if my order went through?", "value": "We can help you with that. Can you please provide your account email or order number so we can check the status for you?"}, {"key": "Can I use multiple promo codes on one order?", "value": "In most cases, only one promo code can be applied per order. Can you please provide the promo codes you're trying to use so we can check their compatibility?"}, {"key": "What are your customer service hours?", "value": "Our customer service team is available 24/7 to assist you. Please feel free to reach out with any questions or concerns."}, {"key": "Where can I find the user manual for my product?", "value": "We'd be happy to help. Can you please provide the product name, SKU, or serial number so we can direct you to the appropriate user manual?"}, {"key": "Do you offer a warranty on your products?", "value": "We do offer warranties on select products. Can you please provide the product name or SKU so we can provide you with warranty information?"}, {"key": "Can I place an order for in-store pickup?", "value": "Certainly. Can you please provide the product name or SKU and the store location where you'd like to pick up your order?"}, {"key": "How do I sign up for your rewards program?", "value": "Thank you for your interest in our rewards program. Can you please provide your email address so we can send you information on how to sign up?"}, {"key": "Can I pay with a gift card online?", "value": "Yes, you can use gift cards for online purchases. Can you please provide the gift card number and the product name or SKU you're interested in?"}, {"key": "I can't find my order confirmation email.", "value": "We apologize for the inconvenience. Can you please provide your account email or order number so we can resend the confirmation email?"}, {"key": "Do you offer a military discount?", "value": "Yes, we do offer a military discount. Can you please provide your military ID or email address so we can apply the discount to your account?"}, {"key": "What is the processing time for my order?", "value": "We'd be happy to provide an estimate. Can you please provide your order number or the product name or SKU?"}, {"key": "How do I update my shipping preferences?", "value": "We can help with that. Can you please provide your account email so we can guide you through updating your shipping preferences?"}, {"key": "Is my payment information secure?", "value": "Yes, we take security very seriously. Can you please provide your account email so we can verify the security measures in place for your payment information?"}, {"key": "Can I pre-order an item?", "value": "Certainly. Can you please provide the product name or SKU and your email address so we can notify you when pre-orders are available?"}, {"key": "How do I use a gift card in-store?", "value": "To use a gift card in-store, simply present the gift card at the time of purchase. Can you please provide the gift card number so we can check the balance for you?"}, {"key": "Do you have a loyalty program?", "value": "Yes, we do have a loyalty program. Can you please provide your email address so we can send you information on how to join and enjoy the benefits?"}, {"key": "Is there a mobile app for your store?", "value": "Yes, we do have a mobile app. 
Can you please provide your email address so we can send you a link to download the app and instructions on how to use it?"}, {"key": "I need help assembling my product.", "value": "We're here to help. Can you please provide the product name, SKU, or serial number, and a description of the issue you're experiencing during assembly?"}, {"key": "Do you offer financing options?", "value": "We do offer financing options for select purchases. Can you please provide the product name or SKU and your email address so we can send you more information?"}, {"key": "Can I reserve an item in-store?", "value": "Certainly. Can you please provide the product name or SKU and the store location where you'd like to reserve the item?"}, {"key": "How do I get a price adjustment for a recent purchase?", "value": "We'd be happy to help. Can you please provide your order number and the product name or SKU for the item you'd like a price adjustment on?"}, {"key": "How do I change my email preferences?", "value": "We can help with that. Can you please provide your account email so we can guide you through updating your email preferences?"}, {"key": "Can I use my store credit online?", "value": "Yes, you can use store credit for online purchases. Can you please provide the store credit number and the product name or SKU you're interested in?"}, {"key": "What are the washing instructions for this item?", "value": "We'd be happy to help. Can you please provide the product name or SKU so we can provide you with the proper washing instructions?"}, {"key": "Can I get a replacement part for my product?", "value": "Certainly. Can you please provide the product name, SKU, or serial number, and a description of the part you need?"}, {"key": "Do you offer free shipping?", "value": "We do offer free shipping on select orders. Can you please provide the product name or SKU and your shipping destination so we can check if your order qualifies?"}, {"key": "Can I place a custom order?", "value": "We'd be happy to assist you. Can you please provide the product name or SKU and a description of the customizations you'd like?"}, {"key": "How do I report a problem with your website?", "value": "We appreciate your feedback. Can you please provide a description of the issue you're experiencing and your email address so we can follow up with you?"}, {"key": "What is your policy on price adjustments?", "value": "We'd be happy to provide information on our price adjustment policy. Can you please provide the product name or SKU so we can give you the most accurate information?"}, {"key": "Do you have any upcoming sales or events?", "value": "We'd be happy to inform you of any upcoming sales or events. Can you please provide your email address so we can keep you updated?"}, {"key": "How do I schedule a consultation or appointment?", "value": "We'd be happy to help. Can you please provide your name, phone number, and the service you're interested in so we can schedule your appointment?"}, {"key": "Can I get a copy of my receipt?", "value": "Certainly. Can you please provide your order number or account email so we can locate your receipt and send you a copy?"}, {"key": "Can I use a competitor's coupon at your store?", "value": "In some cases, we may accept competitor coupons. Can you please provide the competitor's coupon code and the product name or SKU you'd like to apply it to?"}, {"key": "Do you have a recycling program?", "value": "Yes, we do have a recycling program. 
Can you please provide your email address so we can send you information on how to participate?"}, {"key": "How do I report a lost or stolen gift card?", "value": "We're sorry to hear that. Can you please provide the gift card number, if available, and your email address so we can assist you further?"}]}
diff --git a/applications/ColossalQA/data/data_sample/luchen_zh.txt b/applications/ColossalQA/data/data_sample/luchen_zh.txt
new file mode 100644
index 000000000000..afd7fc306fad
--- /dev/null
+++ b/applications/ColossalQA/data/data_sample/luchen_zh.txt
@@ -0,0 +1 @@
+潞晨科技是一家致力于“解放AI生产力”的全球性公司,技术团队核心成员来自美国加州伯克利、斯坦福、新加坡国立、南洋理工、清华、北大等国内外知名高校。在高性能计算、人工智能、分布式系统等方面已有十余年的技术积累,并在国际顶级学术刊物或会议发表论文近百篇。公司核心产品面向大模型时代的通用深度学习系统 Colossal-AI,可实现高效快速部署AI大模型训练和推理,降低AI大模型应用成本。公司在种子轮、天使轮融资已获得“清科中国早期投资机构30强”前三甲创新工场、真格基金、蓝驰创投的600万美元投资。
diff --git a/applications/ColossalQA/data/tests/64KB.json b/applications/ColossalQA/data/tests/64KB.json
new file mode 100644
index 000000000000..99278dc5c79a
--- /dev/null
+++ b/applications/ColossalQA/data/tests/64KB.json
@@ -0,0 +1,7 @@
+{
+ "data":[
+ {"content":"Donec lobortis eleifend condimentum. Cras dictum dolor lacinia lectus vehicula rutrum. Maecenas quis nisi nunc. Nam tristique feugiat est vitae mollis. Maecenas quis nisi nunc."},
+ {"content":"Aliquam sollicitudin ante ligula, eget malesuada nibh efficitur et. Pellentesque massa sem, scelerisque sit amet odio id, cursus tempor urna. Etiam congue dignissim volutpat. Vestibulum pharetra libero et velit gravida euismod."}
+ ],
+ "name":"player"
+}
\ No newline at end of file
diff --git a/applications/ColossalQA/data/tests/companies.csv b/applications/ColossalQA/data/tests/companies.csv
new file mode 100644
index 000000000000..93dcac9f39ae
--- /dev/null
+++ b/applications/ColossalQA/data/tests/companies.csv
@@ -0,0 +1,101 @@
+Index,Organization Id,Name,Website,Country,Description,Founded,Industry,Number of employees
+1,FAB0d41d5b5d22c,Ferrell LLC,https://price.net/,Papua New Guinea,Horizontal empowering knowledgebase,1990,Plastics,3498
+2,6A7EdDEA9FaDC52,"Mckinney, Riley and Day",http://www.hall-buchanan.info/,Finland,User-centric system-worthy leverage,2015,Glass / Ceramics / Concrete,4952
+3,0bFED1ADAE4bcC1,Hester Ltd,http://sullivan-reed.com/,China,Switchable scalable moratorium,1971,Public Safety,5287
+4,2bFC1Be8a4ce42f,Holder-Sellers,https://becker.com/,Turkmenistan,De-engineered systemic artificial intelligence,2004,Automotive,921
+5,9eE8A6a4Eb96C24,Mayer Group,http://www.brewer.com/,Mauritius,Synchronized needs-based challenge,1991,Transportation,7870
+6,cC757116fe1C085,Henry-Thompson,http://morse.net/,Bahamas,Face-to-face well-modulated customer loyalty,1992,Primary / Secondary Education,4914
+7,219233e8aFF1BC3,Hansen-Everett,https://www.kidd.org/,Pakistan,Seamless disintermediate collaboration,2018,Publishing Industry,7832
+8,ccc93DCF81a31CD,Mcintosh-Mora,https://www.brooks.com/,Heard Island and McDonald Islands,Centralized attitude-oriented capability,1970,Import / Export,4389
+9,0B4F93aA06ED03e,Carr Inc,http://ross.com/,Kuwait,Distributed impactful customer loyalty,1996,Plastics,8167
+10,738b5aDe6B1C6A5,Gaines Inc,http://sandoval-hooper.com/,Uzbekistan,Multi-lateral scalable protocol,1997,Outsourcing / Offshoring,9698
+11,AE61b8Ffebbc476,Kidd Group,http://www.lyons.com/,Bouvet Island (Bouvetoya),Proactive foreground paradigm,2001,Primary / Secondary Education,7473
+12,eb3B7D06cCdD609,Crane-Clarke,https://www.sandoval.com/,Denmark,Front-line clear-thinking encryption,2014,Food / Beverages,9011
+13,8D0c29189C9798B,"Keller, Campos and Black",https://www.garner.info/,Liberia,Ameliorated directional emulation,2020,Museums / Institutions,2862
+14,D2c91cc03CA394c,Glover-Pope,http://www.silva.biz/,United Arab Emirates,Persevering contextually-based approach,2013,Medical Practice,9079
+15,C8AC1eaf9C036F4,Pacheco-Spears,https://aguilar.com/,Sweden,Secured logistical synergy,1984,Maritime,769
+16,b5D10A14f7a8AfE,Hodge-Ayers,http://www.archer-elliott.com/,Honduras,Future-proofed radical implementation,1990,Facilities Services,8508
+17,68139b5C4De03B4,"Bowers, Guerra and Krause",http://www.carrillo-nicholson.com/,Uganda,De-engineered transitional strategy,1972,Primary / Secondary Education,6986
+18,5c2EffEfdba2BdF,Mckenzie-Melton,http://montoya-thompson.com/,Hong Kong,Reverse-engineered heuristic alliance,1998,Investment Management / Hedge Fund / Private Equity,4589
+19,ba179F19F7925f5,Branch-Mann,http://www.lozano.com/,Botswana,Adaptive intangible frame,1999,Architecture / Planning,7961
+20,c1Ce9B350BAc66b,Weiss and Sons,https://barrett.com/,Korea,Sharable optimal functionalities,2011,Plastics,5984
+21,8de40AC4e6EaCa4,"Velez, Payne and Coffey",http://burton.com/,Luxembourg,Mandatory coherent synergy,1986,Wholesale,5010
+22,Aad86a4F0385F2d,Harrell LLC,http://www.frey-rosario.com/,Guadeloupe,Reverse-engineered mission-critical moratorium,2018,Construction,2185
+23,22aC3FFd64fD703,"Eaton, Reynolds and Vargas",http://www.freeman.biz/,Monaco,Self-enabling multi-tasking process improvement,2014,Luxury Goods / Jewelry,8987
+24,5Ec4C272bCf085c,Robbins-Cummings,http://donaldson-wilkins.com/,Belgium,Organic non-volatile hierarchy,1991,Pharmaceuticals,5038
+25,5fDBeA8BB91a000,Jenkins Inc,http://www.kirk.biz/,South Africa,Front-line systematic help-desk,2002,Insurance,1215
+26,dFfD6a6F9AC2d9C,"Greene, Benjamin and Novak",http://www.kent.net/,Romania,Centralized leadingedge moratorium,2012,Museums / Institutions,4941
+27,4B217cC5a0674C5,"Dickson, Richmond and Clay",http://everett.com/,Czech Republic,Team-oriented tangible complexity,1980,Real Estate / Mortgage,3122
+28,88b1f1cDcf59a37,Prince-David,http://thompson.com/,Christmas Island,Virtual holistic methodology,1970,Banking / Mortgage,1046
+29,f9F7bBCAEeC360F,Ayala LLC,http://www.zhang.com/,Philippines,Open-source zero administration hierarchy,2021,Legal Services,7664
+30,7Cb3AeFcE4Ba31e,Rivas Group,https://hebert.org/,Australia,Open-architected well-modulated capacity,1998,Logistics / Procurement,4155
+31,ccBcC32adcbc530,"Sloan, Mays and Whitehead",http://lawson.com/,Chad,Face-to-face high-level conglomeration,1997,Civil Engineering,365
+32,f5afd686b3d05F5,"Durham, Allen and Barnes",http://chan-stafford.org/,Zimbabwe,Synergistic web-enabled framework,1993,Mechanical or Industrial Engineering,6135
+33,38C6cfC5074Fa5e,Fritz-Franklin,http://www.lambert.com/,Nepal,Automated 4thgeneration website,1972,Hospitality,4516
+34,5Cd7efccCcba38f,Burch-Ewing,http://cline.net/,Taiwan,User-centric 4thgeneration system engine,1981,Venture Capital / VC,7443
+35,9E6Acb51e3F9d6F,"Glass, Barrera and Turner",https://dunlap.com/,Kyrgyz Republic,Multi-channeled 3rdgeneration open system,2020,Utilities,2610
+36,4D4d7E18321eaeC,Pineda-Cox,http://aguilar.org/,Bolivia,Fundamental asynchronous capability,2010,Human Resources / HR,1312
+37,485f5d06B938F2b,"Baker, Mccann and Macdonald",http://www.anderson-barker.com/,Kenya,Cross-group user-facing focus group,2013,Legislative Office,1638
+38,19E3a5Bf6dBDc4F,Cuevas-Moss,https://dodson-castaneda.net/,Guatemala,Extended human-resource intranet,1994,Music,9995
+39,6883A965c7b68F7,Hahn PLC,http://newman.com/,Belarus,Organic logistical leverage,2012,Electrical / Electronic Manufacturing,3715
+40,AC5B7AA74Aa4A2E,"Valentine, Ferguson and Kramer",http://stuart.net/,Jersey,Centralized secondary time-frame,1997,Non - Profit / Volunteering,3585
+41,decab0D5027CA6a,Arroyo Inc,https://www.turner.com/,Grenada,Managed demand-driven website,2006,Writing / Editing,9067
+42,dF084FbBb613eea,Walls LLC,http://www.reese-vasquez.biz/,Cape Verde,Self-enabling fresh-thinking installation,1989,Investment Management / Hedge Fund / Private Equity,1678
+43,A2D89Ab9bCcAd4e,"Mitchell, Warren and Schneider",https://fox.biz/,Trinidad and Tobago,Enhanced intangible time-frame,2021,Capital Markets / Hedge Fund / Private Equity,3816
+44,77aDc905434a49f,Prince PLC,https://www.watts.com/,Sweden,Profit-focused coherent installation,2016,Individual / Family Services,7645
+45,235fdEFE2cfDa5F,Brock-Blackwell,http://www.small.com/,Benin,Secured foreground emulation,1986,Online Publishing,7034
+46,1eD64cFe986BBbE,Walton-Barnett,https://ashley-schaefer.com/,Western Sahara,Right-sized clear-thinking flexibility,2001,Luxury Goods / Jewelry,1746
+47,CbBbFcdd0eaE2cF,Bartlett-Arroyo,https://cruz.com/,Northern Mariana Islands,Realigned didactic function,1976,Civic / Social Organization,3987
+48,49aECbDaE6aBD53,"Wallace, Madden and Morris",http://www.blevins-fernandez.biz/,Germany,Persistent real-time customer loyalty,2016,Pharmaceuticals,9443
+49,7b3fe6e7E72bFa4,Berg-Sparks,https://cisneros-love.com/,Canada,Stand-alone static implementation,1974,Arts / Crafts,2073
+50,c6DedA82A8aef7E,Gonzales Ltd,http://bird.com/,Tonga,Managed human-resource policy,1988,Consumer Goods,9069
+51,7D9FBF85cdC3871,Lawson and Sons,https://www.wong.com/,French Southern Territories,Compatible analyzing intranet,2021,Arts / Crafts,3527
+52,7dd18Fb7cB07b65,"Mcguire, Mcconnell and Olsen",https://melton-briggs.com/,Korea,Profound client-server frame,1988,Printing,8445
+53,EF5B55FadccB8Fe,Charles-Phillips,https://bowman.com/,Cote d'Ivoire,Monitored client-server implementation,2012,Mental Health Care,3450
+54,f8D4B99e11fAF5D,Odom Ltd,https://www.humphrey-hess.com/,Cote d'Ivoire,Advanced static process improvement,2012,Management Consulting,1825
+55,e24D21BFd3bF1E5,Richard PLC,https://holden-coleman.net/,Mayotte,Object-based optimizing model,1971,Broadcast Media,4942
+56,B9BdfEB6D3Ca44E,Sampson Ltd,https://blevins.com/,Cayman Islands,Intuitive local adapter,2005,Farming,1418
+57,2a74D6f3D3B268e,"Cherry, Le and Callahan",https://waller-delacruz.biz/,Nigeria,Universal human-resource collaboration,2017,Entertainment / Movie Production,7202
+58,Bf3F3f62c8aBC33,Cherry PLC,https://www.avila.info/,Marshall Islands,Persistent tertiary website,1980,Plastics,8245
+59,aeBe26B80a7a23c,Melton-Nichols,https://kennedy.com/,Palau,User-friendly clear-thinking productivity,2021,Legislative Office,8741
+60,aAeb29ad43886C6,Potter-Walsh,http://thomas-french.org/,Turkey,Optional non-volatile open system,2008,Human Resources / HR,6923
+61,bD1bc6bB6d1FeD3,Freeman-Chen,https://mathis.com/,Timor-Leste,Phased next generation adapter,1973,International Trade / Development,346
+62,EB9f456e8b7022a,Soto Group,https://norris.info/,Vietnam,Enterprise-wide executive installation,1988,Business Supplies / Equipment,9097
+63,Dfef38C51D8DAe3,"Poole, Cruz and Whitney",https://reed.info/,Reunion,Balanced analyzing groupware,1978,Marketing / Advertising / Sales,2992
+64,055ffEfB2Dd95B0,Riley Ltd,http://wiley.com/,Brazil,Optional exuding superstructure,1986,Textiles,9315
+65,cBfe4dbAE1699da,"Erickson, Andrews and Bailey",https://www.hobbs-grant.com/,Eritrea,Vision-oriented secondary project,2014,Consumer Electronics,7829
+66,fdFbecbadcdCdf1,"Wilkinson, Charles and Arroyo",http://hunter-mcfarland.com/,United States Virgin Islands,Assimilated 24/7 archive,1996,Building Materials,602
+67,5DCb8A5a5ca03c0,Floyd Ltd,http://www.whitney.com/,Falkland Islands (Malvinas),Function-based fault-tolerant concept,2017,Public Relations / PR,2911
+68,ce57DCbcFD6d618,Newman-Galloway,https://www.scott.com/,Luxembourg,Enhanced foreground collaboration,1987,Information Technology / IT,3934
+69,5aaD187dc929371,Frazier-Butler,https://www.daugherty-farley.info/,Northern Mariana Islands,Persistent interactive circuit,1972,Outsourcing / Offshoring,5130
+70,902D7Ac8b6d476b,Newton Inc,https://www.richmond-manning.info/,Netherlands Antilles,Fundamental stable info-mediaries,1976,Military Industry,563
+71,32BB9Ff4d939788,Duffy-Levy,https://www.potter.com/,Guernsey,Diverse exuding installation,1982,Wireless,6146
+72,adcB0afbE58bAe3,Wagner LLC,https://decker-esparza.com/,Uruguay,Reactive attitude-oriented toolset,1987,International Affairs,6874
+73,dfcA1c84AdB61Ac,Mccall-Holmes,http://www.dean.com/,Benin,Object-based value-added database,2009,Legal Services,696
+74,208044AC2fe52F3,Massey LLC,https://frazier.biz/,Suriname,Configurable zero administration Graphical User Interface,1986,Accounting,5004
+75,f3C365f0c1A0623,Hicks LLC,http://alvarez.biz/,Pakistan,Quality-focused client-server Graphical User Interface,1970,Computer Software / Engineering,8480
+76,ec5Bdd3CBAfaB93,"Cole, Russell and Avery",http://www.blankenship.com/,Mongolia,De-engineered fault-tolerant challenge,2000,Law Enforcement,7012
+77,DDB19Be7eeB56B4,Cummings-Rojas,https://simon-pearson.com/,Svalbard & Jan Mayen Islands,User-centric modular customer loyalty,2012,Financial Services,7529
+78,dd6CA3d0bc3cAfc,"Beasley, Greene and Mahoney",http://www.petersen-lawrence.com/,Togo,Extended content-based methodology,1976,Religious Institutions,869
+79,A0B9d56e61070e3,"Beasley, Sims and Allison",http://burke.info/,Latvia,Secured zero tolerance hub,1972,Facilities Services,6182
+80,cBa7EFe5D05Adaf,Crawford-Rivera,https://black-ramirez.org/,Cuba,Persevering exuding budgetary management,1999,Online Publishing,7805
+81,Ea3f6D52Ec73563,Montes-Hensley,https://krueger.org/,Liechtenstein,Multi-tiered secondary productivity,2009,Printing,8433
+82,bC0CEd48A8000E0,Velazquez-Odom,https://stokes.com/,Djibouti,Streamlined 6thgeneration function,2002,Alternative Dispute Resolution,4044
+83,c89b9b59BC4baa1,Eaton-Morales,https://www.reeves-graham.com/,Micronesia,Customer-focused explicit frame,1990,Capital Markets / Hedge Fund / Private Equity,7013
+84,FEC51bce8421a7b,"Roberson, Pennington and Palmer",http://www.keith-fisher.com/,Cameroon,Adaptive bi-directional hierarchy,1993,Telecommunications,5571
+85,e0E8e27eAc9CAd5,"George, Russo and Guerra",https://drake.com/,Sweden,Centralized non-volatile capability,1989,Military Industry,2880
+86,B97a6CF9bf5983C,Davila Inc,https://mcconnell.info/,Cocos (Keeling) Islands,Profit-focused dedicated frame,2017,Consumer Electronics,2215
+87,a0a6f9b3DbcBEb5,Mays-Preston,http://www.browning-key.com/,Mali,User-centric heuristic focus group,2006,Military Industry,5786
+88,8cC1bDa330a5871,Pineda-Morton,https://www.carr.com/,United States Virgin Islands,Grass-roots methodical info-mediaries,1991,Printing,6168
+89,ED889CB2FE9cbd3,Huang and Sons,https://www.bolton.com/,Eritrea,Re-contextualized dynamic hierarchy,1981,Semiconductors,7484
+90,F4Dc1417BC6cb8f,Gilbert-Simon,https://www.bradford.biz/,Burundi,Grass-roots radical parallelism,1973,Newspapers / Journalism,1927
+91,7ABc3c7ecA03B34,Sampson-Griffith,http://hendricks.org/,Benin,Multi-layered composite paradigm,1972,Textiles,3881
+92,4e0719FBE38e0aB,Miles-Dominguez,http://www.turner.com/,Gibraltar,Organized empowering forecast,1996,Civic / Social Organization,897
+93,dEbDAAeDfaed00A,Rowe and Sons,https://www.simpson.org/,El Salvador,Balanced multimedia knowledgebase,1978,Facilities Services,8172
+94,61BDeCfeFD0cEF5,"Valenzuela, Holmes and Rowland",https://www.dorsey.net/,Taiwan,Persistent tertiary focus group,1999,Transportation,1483
+95,4e91eD25f486110,"Best, Wade and Shepard",https://zimmerman.com/,Zimbabwe,Innovative background definition,1991,Gambling / Casinos,4873
+96,0a0bfFbBbB8eC7c,Holmes Group,https://mcdowell.org/,Ethiopia,Right-sized zero tolerance focus group,1975,Photography,2988
+97,BA6Cd9Dae2Efd62,Good Ltd,http://duffy.com/,Anguilla,Reverse-engineered composite moratorium,1971,Consumer Services,4292
+98,E7df80C60Abd7f9,Clements-Espinoza,http://www.flowers.net/,Falkland Islands (Malvinas),Progressive modular hub,1991,Broadcast Media,236
+99,AFc285dbE2fEd24,Mendez Inc,https://www.burke.net/,Kyrgyz Republic,User-friendly exuding migration,1993,Education Management,339
+100,e9eB5A60Cef8354,Watkins-Kaiser,http://www.herring.com/,Togo,Synergistic background access,2009,Financial Services,2785
diff --git a/applications/ColossalQA/data/tests/sample-pdf-file.pdf b/applications/ColossalQA/data/tests/sample-pdf-file.pdf
new file mode 100644
index 000000000000..4b6eea24d6ea
Binary files /dev/null and b/applications/ColossalQA/data/tests/sample-pdf-file.pdf differ
diff --git a/applications/ColossalQA/data/tests/test.html b/applications/ColossalQA/data/tests/test.html
new file mode 100644
index 000000000000..5ad21421d827
--- /dev/null
+++ b/applications/ColossalQA/data/tests/test.html
@@ -0,0 +1,1970 @@
+ logging — Logging facility for Python — Python 3.11.5 documentation
+
This module defines functions and classes which implement a flexible event
+logging system for applications and libraries.
+
The key benefit of having the logging API provided by a standard library module
+is that all Python modules can participate in logging, so your application log
+can include your own messages integrated with messages from third-party
+modules.
The module provides a lot of functionality and flexibility. If you are
+unfamiliar with logging, the best way to get to grips with it is to view the
+tutorials (see the links above and on the right).
+
The basic classes defined by the module, together with their functions, are
+listed below.
+
+
Loggers expose the interface that application code directly uses.
+
Handlers send the log records (created by loggers) to the appropriate
+destination.
+
Filters provide a finer grained facility for determining which log records
+to output.
+
Formatters specify the layout of log records in the final output.
Loggers have the following attributes and methods. Note that Loggers should
+NEVER be instantiated directly, but always through the module-level function
+logging.getLogger(name). Multiple calls to getLogger() with the same
+name will always return a reference to the same Logger object.
+
The name is potentially a period-separated hierarchical value, like
+foo.bar.baz (though it could also be just plain foo, for example).
+Loggers that are further down in the hierarchical list are children of loggers
+higher up in the list. For example, given a logger with a name of foo,
+loggers with names of foo.bar, foo.bar.baz, and foo.bam are all
+descendants of foo. The logger name hierarchy is analogous to the Python
+package hierarchy, and identical to it if you organise your loggers on a
+per-module basis using the recommended construction
+logging.getLogger(__name__). That’s because in a module, __name__
+is the module’s name in the Python package namespace.
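+
+A minimal sketch of the identity guarantee described above (the module name here is illustrative):
+
+import logging
+
+log = logging.getLogger("mypkg.mymodule")   # what logging.getLogger(__name__) yields in mypkg/mymodule.py
+assert log is logging.getLogger("mypkg.mymodule")   # same name, same Logger object
+assert log.name == "mypkg.mymodule"
+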
If this attribute evaluates to true, events logged to this logger will be
+passed to the handlers of higher level (ancestor) loggers, in addition to
+any handlers attached to this logger. Messages are passed directly to the
+ancestor loggers’ handlers - neither the level nor filters of the ancestor
+loggers in question are considered.
+
If this evaluates to false, logging messages are not passed to the handlers
+of ancestor loggers.
+
Spelling it out with an example: If the propagate attribute of the logger named
+A.B.C evaluates to true, any event logged to A.B.C via a method call such as
+logging.getLogger('A.B.C').error(...) will [subject to passing that logger’s
+level and filter settings] be passed in turn to any handlers attached to loggers
+named A.B, A and the root logger, after first being passed to any handlers
+attached to A.B.C. If any logger in the chain A.B.C, A.B, A has its
+propagate attribute set to false, then that is the last logger whose handlers
+are offered the event to handle, and propagation stops at that point.
+
The constructor sets this attribute to True.
+
+
Note
+
If you attach a handler to a logger and one or more of its
+ancestors, it may emit the same record multiple times. In general, you
+should not need to attach a handler to more than one logger - if you just
+attach it to the appropriate logger which is highest in the logger
+hierarchy, then it will see all events logged by all descendant loggers,
+provided that their propagate setting is left set to True. A common
+scenario is to attach handlers only to the root logger, and to let
+propagation take care of the rest.
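+
+A minimal sketch of that root-handler pattern (the handler choice and logger name are illustrative):
+
+import logging
+
+root = logging.getLogger()
+root.addHandler(logging.StreamHandler())   # one handler, attached at the top of the hierarchy
+root.setLevel(logging.INFO)
+
+logging.getLogger("app.db").info("reaches the root handler via propagation")
+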
Sets the threshold for this logger to level. Logging messages which are less
+severe than level will be ignored; logging messages which have severity level
+or higher will be emitted by whichever handler or handlers service this logger,
+unless a handler’s level has been set to a higher severity level than level.
+
When a logger is created, the level is set to NOTSET (which causes
+all messages to be processed when the logger is the root logger, or delegation
+to the parent when the logger is a non-root logger). Note that the root logger
+is created with level WARNING.
+
The term ‘delegation to the parent’ means that if a logger has a level of
+NOTSET, its chain of ancestor loggers is traversed until either an ancestor with
+a level other than NOTSET is found, or the root is reached.
+
If an ancestor is found with a level other than NOTSET, then that ancestor’s
+level is treated as the effective level of the logger where the ancestor search
+began, and is used to determine how a logging event is handled.
+
If the root is reached, and it has a level of NOTSET, then all messages will be
+processed. Otherwise, the root’s level will be used as the effective level.
Changed in version 3.2: The level parameter now accepts a string representation of the
+level such as ‘INFO’ as an alternative to the integer constants
+such as INFO. Note, however, that levels are internally stored
+as integers, and methods such as e.g. getEffectiveLevel() and
+isEnabledFor() will return/expect to be passed integers.
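+
+For example, delegation to the parent can be observed directly (logger names are illustrative):
+
+import logging
+
+logging.getLogger("a").setLevel(logging.ERROR)
+child = logging.getLogger("a.b")                   # the child's own level stays NOTSET
+assert child.getEffectiveLevel() == logging.ERROR  # inherited from "a"
+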
Indicates if a message of severity level would be processed by this logger.
+This method checks first the module-level level set by
+logging.disable(level) and then the logger’s effective level as determined
+by getEffectiveLevel().
Indicates the effective level for this logger. If a value other than
+NOTSET has been set using setLevel(), it is returned. Otherwise,
+the hierarchy is traversed towards the root until a value other than
+NOTSET is found, and that value is returned. The value returned is
+an integer, typically one of logging.DEBUG, logging.INFO
+etc.
Returns a logger which is a descendant to this logger, as determined by the suffix.
+Thus, logging.getLogger('abc').getChild('def.ghi') would return the same
+logger as would be returned by logging.getLogger('abc.def.ghi'). This is a
+convenience method, useful when the parent logger is named using e.g. __name__
+rather than a literal string.
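+
+Restating that equivalence as code:
+
+import logging
+
+assert logging.getLogger("abc").getChild("def.ghi") is logging.getLogger("abc.def.ghi")
+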
Logs a message with level DEBUG on this logger. The msg is the
+message format string, and the args are the arguments which are merged into
+msg using the string formatting operator. (Note that this means that you can
+use keywords in the format string, together with a single dictionary argument.)
+No % formatting operation is performed on msg when no args are supplied.
+
There are four keyword arguments in kwargs which are inspected:
+exc_info, stack_info, stacklevel and extra.
+
If exc_info does not evaluate as false, it causes exception information to be
+added to the logging message. If an exception tuple (in the format returned by
+sys.exc_info()) or an exception instance is provided, it is used;
+otherwise, sys.exc_info() is called to get the exception information.
+
The second optional keyword argument is stack_info, which defaults to
+False. If true, stack information is added to the logging
+message, including the actual logging call. Note that this is not the same
+stack information as that displayed through specifying exc_info: The
+former is stack frames from the bottom of the stack up to the logging call
+in the current thread, whereas the latter is information about stack frames
+which have been unwound, following an exception, while searching for
+exception handlers.
+
You can specify stack_info independently of exc_info, e.g. to just show
+how you got to a certain point in your code, even when no exceptions were
+raised. The stack frames are printed following a header line which says:
+
Stack (most recent call last):
+
+
+
This mimics the Traceback (most recent call last): which is used when
+displaying exception frames.
+
The third optional keyword argument is stacklevel, which defaults to 1.
+If greater than 1, the corresponding number of stack frames are skipped
+when computing the line number and function name set in the LogRecord
+created for the logging event. This can be used in logging helpers so that
+the function name, filename and line number recorded are not the information
+for the helper function/method, but rather its caller. The name of this
+parameter mirrors the equivalent one in the warnings module.
+
The fourth keyword argument is extra which can be used to pass a
+dictionary which is used to populate the __dict__ of the LogRecord
+created for the logging event with user-defined attributes. These custom
+attributes can then be used as you like. For example, they could be
+incorporated into logged messages. For example:
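+
+A sketch of such a format string and extra dictionary, consistent with the 'clientip' and 'user' fields discussed below (the concrete values are illustrative):
+
+import logging
+
+FORMAT = '%(asctime)s %(clientip)-15s %(user)-8s %(message)s'
+logging.basicConfig(format=FORMAT)
+d = {'clientip': '192.168.0.1', 'user': 'fbloggs'}   # illustrative values
+logger = logging.getLogger('tcpserver')
+logger.warning('Protocol problem: %s', 'connection reset', extra=d)
+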
The keys in the dictionary passed in extra should not clash with the keys used
+by the logging system. (See the section on LogRecord attributes for more
+information on which keys are used by the logging system.)
+
If you choose to use these attributes in logged messages, you need to exercise
+some care. In the above example, for instance, the Formatter has been
+set up with a format string which expects ‘clientip’ and ‘user’ in the attribute
+dictionary of the LogRecord. If these are missing, the message will
+not be logged because a string formatting exception will occur. So in this case,
+you always need to pass the extra dictionary with these keys.
+
While this might be annoying, this feature is intended for use in specialized
+circumstances, such as multi-threaded servers where the same code executes in
+many contexts, and interesting conditions which arise are dependent on this
+context (such as remote client IP address and authenticated user name, in the
+above example). In such circumstances, it is likely that specialized
+Formatters would be used with particular Handlers.
+
If no handler is attached to this logger (or any of its ancestors,
+taking into account the relevant Logger.propagate attributes),
+the message will be sent to the handler set on lastResort.
+
+
Changed in version 3.2: The stack_info parameter was added.
+
+
+
Changed in version 3.5: The exc_info parameter can now accept exception instances.
+
+
+
Changed in version 3.8: The stacklevel parameter was added.
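+
+A short sketch of exc_info and stack_info in use (the logger name is illustrative):
+
+import logging
+
+log = logging.getLogger("app")
+try:
+    1 / 0
+except ZeroDivisionError:
+    log.error("division failed", exc_info=True)   # traceback of the handled exception is appended
+log.debug("checkpoint", stack_info=True)          # the current call stack is appended
+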
Logs a message with level ERROR on this logger. The arguments are
+interpreted as for debug(). Exception info is added to the logging
+message. This method should only be called from an exception handler.
Apply this logger’s filters to the record and return True if the
+record is to be processed. The filters are consulted in turn, until one of
+them returns a false value. If none of them return a false value, the record
+will be processed (passed to handlers). If one returns a false value, no
+further processing of the record occurs.
Finds the caller’s source filename and line number. Returns the filename, line
+number, function name and stack information as a 4-element tuple. The stack
+information is returned as None unless stack_info is True.
+
The stacklevel parameter is passed from code calling the debug()
+and other APIs. If greater than 1, the excess is used to skip stack frames
+before determining the values to be returned. This will generally be useful
+when calling logging APIs from helper/wrapper code, so that the information
+in the event log refers not to the helper/wrapper code, but to the code that
+calls it.
Handles a record by passing it to all handlers associated with this logger and
+its ancestors (until a false value of propagate is found). This method is used
+for unpickled records received from a socket, as well as those created locally.
+Logger-level filtering is applied using filter().
Checks to see if this logger has any handlers configured. This is done by
+looking for handlers in this logger and its parents in the logger hierarchy.
+Returns True if a handler was found, else False. The method stops searching
+up the hierarchy whenever a logger with the ‘propagate’ attribute set to
+false is found - that will be the last logger which is checked for the
+existence of handlers.
+
+
New in version 3.2.
+
Changed in version 3.7: Loggers can now be pickled and unpickled.
The numeric values of logging levels are given in the following table. These are
+primarily of interest if you want to define your own levels, and need them to
+have specific values relative to the predefined levels. If you define a level
+with the same numeric value, it overwrites the predefined value; the predefined
+name is lost.
NOTSET (numeric value 0): When set on a logger, indicates that ancestor loggers are to be consulted to determine the effective level. If that still resolves to NOTSET, then all events are logged. When set on a handler, all events are handled.
WARNING (numeric value 30): An indication that something unexpected happened, or that a problem might occur in the near future (e.g. ‘disk space low’). The software is still working as expected.
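+
+A minimal sketch of defining a custom level relative to the predefined values (the TRACE name and the value 5 are assumptions):
+
+import logging
+
+TRACE = 5                                  # sits below DEBUG (10)
+logging.addLevelName(TRACE, "TRACE")
+logging.getLogger("app").log(TRACE, "very verbose message")
+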
Handlers have the following attributes and methods. Note that Handler
+is never instantiated directly; this class acts as a base for more useful
+subclasses. However, the __init__() method in subclasses needs to call
+Handler.__init__().
Initializes the Handler instance by setting its level, setting the list
+of filters to the empty list and creating a lock (using createLock()) for
+serializing access to an I/O mechanism.
Sets the threshold for this handler to level. Logging messages which are
+less severe than level will be ignored. When a handler is created, the
+level is set to NOTSET (which causes all messages to be
+processed).
Changed in version 3.2: The level parameter now accepts a string representation of the
+level such as ‘INFO’ as an alternative to the integer constants
+such as INFO.
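+
+For example, a handler's level filters independently of the logger's level (names are illustrative):
+
+import logging
+
+handler = logging.StreamHandler()
+handler.setLevel(logging.ERROR)       # this handler ignores anything below ERROR
+log = logging.getLogger("app")
+log.setLevel(logging.DEBUG)           # the logger passes DEBUG and above to its handlers
+log.addHandler(handler)
+log.warning("dropped by the handler's level")
+log.error("emitted")
+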
Apply this handler’s filters to the record and return True if the
+record is to be processed. The filters are consulted in turn, until one of
+them returns a false value. If none of them return a false value, the record
+will be emitted. If one returns a false value, the handler will not emit the
+record.
Tidy up any resources used by the handler. This version does no output but
+removes the handler from an internal list of handlers which is closed when
+shutdown() is called. Subclasses should ensure that this gets called
+from overridden close() methods.
Conditionally emits the specified logging record, depending on filters which may
+have been added to the handler. Wraps the actual emission of the record with
+acquisition/release of the I/O thread lock.
This method should be called from handlers when an exception is encountered
+during an emit() call. If the module-level attribute
+raiseExceptions is False, exceptions get silently ignored. This is
+what is mostly wanted for a logging system - most users will not care about
+errors in the logging system, they are more interested in application
+errors. You could, however, replace this with a custom handler if you wish.
+The specified record is the one which was being processed when the exception
+occurred. (The default value of raiseExceptions is True, as that is
+more useful during development).
Do whatever it takes to actually log the specified logging record. This version
+is intended to be implemented by subclasses and so raises a
+NotImplementedError.
+
+
Warning
+
This method is called after a handler-level lock is acquired, which
+is released after this method returns. When you override this method, note
+that you should be careful when calling anything that invokes other parts of
+the logging API which might do locking, because that might result in a
+deadlock. Specifically:
+
+
Logging configuration APIs acquire the module-level lock, and then
+individual handler-level locks as those handlers are configured.
+
Many logging APIs lock the module-level lock. If such an API is called
+from this method, it could cause a deadlock if a configuration call is
+made on another thread, because that thread will try to acquire the
+module-level lock before the handler-level lock, whereas this thread
+tries to acquire the module-level lock after the handler-level lock
+(because in this method, the handler-level lock has already been acquired).
+
+
For a list of handlers included as standard, see logging.handlers.
Formatter objects have the following attributes and methods. They are
+responsible for converting a LogRecord to (usually) a string which can
+be interpreted by either a human or an external system. The base
+Formatter allows a formatting string to be specified. If none is
+supplied, the default value of '%(message)s' is used, which just includes
+the message in the logging call. To have additional items of information in the
+formatted output (such as a timestamp), keep reading.
+
A Formatter can be initialized with a format string which makes use of knowledge
+of the LogRecord attributes - such as the default value mentioned above
+making use of the fact that the user’s message and arguments are pre-formatted
+into a LogRecord’s message attribute. This format string contains
+standard Python %-style mapping keys. See section printf-style String Formatting
+for more information on string formatting.
Returns a new instance of the Formatter class. The instance is
+initialized with a format string for the message as a whole, as well as a
+format string for the date/time portion of a message. If no fmt is
+specified, '%(message)s' is used. If no datefmt is specified, a format
+is used which is described in the formatTime() documentation.
+
The style parameter can be one of ‘%’, ‘{’ or ‘$’ and determines how
+the format string will be merged with its data: using one of %-formatting,
+str.format() or string.Template. This only applies to the
+format string fmt (e.g. '%(message)s' or {message}), not to the
+actual log messages passed to Logger.debug etc; see
+Using particular formatting styles throughout your application for more information on using {- and $-formatting
+for log messages.
+
The defaults parameter can be a dictionary with default values to use in
+custom fields. For example:
+logging.Formatter('%(ip)s %(message)s', defaults={"ip": None})
+
+
Changed in version 3.2: The style parameter was added.
+
+
+
Changed in version 3.8: The validate parameter was added. Incorrect or mismatched style and fmt
+will raise a ValueError.
+For example: logging.Formatter('%(asctime)s - %(message)s', style='{').
+
+
+
Changed in version 3.10: The defaults parameter was added.
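+
+A minimal sketch of a {-style formatter with a defaults mapping (the field names are illustrative):
+
+import logging
+
+fmt = logging.Formatter("{asctime} {levelname} {ip} {message}", style="{",
+                        defaults={"ip": "-"})   # "-" is shown when no ip is supplied
+handler = logging.StreamHandler()
+handler.setFormatter(fmt)
+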
The record’s attribute dictionary is used as the operand to a string
+formatting operation. Returns the resulting string. Before formatting the
+dictionary, a couple of preparatory steps are carried out. The message
+attribute of the record is computed using msg % args. If the
+formatting string contains '(asctime)', formatTime() is called
+to format the event time. If there is exception information, it is
+formatted using formatException() and appended to the message. Note
+that the formatted exception information is cached in attribute
+exc_text. This is useful because the exception information can be
+pickled and sent across the wire, but you should be careful if you have
+more than one Formatter subclass which customizes the formatting
+of exception information. In this case, you will have to clear the cached
+value (by setting the exc_text attribute to None) after a formatter
+has done its formatting, so that the next formatter to handle the event
+doesn’t use the cached value, but recalculates it afresh.
+
If stack information is available, it’s appended after the exception
+information, using formatStack() to transform it if necessary.
This method should be called from format() by a formatter which
+wants to make use of a formatted time. This method can be overridden in
+formatters to provide for any specific requirement, but the basic behavior
+is as follows: if datefmt (a string) is specified, it is used with
+time.strftime() to format the creation time of the
+record. Otherwise, the format ‘%Y-%m-%d %H:%M:%S,uuu’ is used, where the
+uuu part is a millisecond value and the other letters are as per the
+time.strftime() documentation. An example time in this format is
+2003-01-23 00:29:50,411. The resulting string is returned.
+
This function uses a user-configurable function to convert the creation
+time to a tuple. By default, time.localtime() is used; to change
+this for a particular formatter instance, set the converter attribute
+to a function with the same signature as time.localtime() or
+time.gmtime(). To change it for all formatters, for example if you
+want all logging times to be shown in GMT, set the converter
+attribute in the Formatter class.
+
+
Changed in version 3.3: Previously, the default format was hard-coded as in this example:
+2010-09-06 22:38:15,292 where the part before the comma is
+handled by a strptime format string ('%Y-%m-%d %H:%M:%S'), and the
+part after the comma is a millisecond value. Because strptime does not
+have a format placeholder for milliseconds, the millisecond value is
+appended using another format string, '%s,%03d' — and both of these
+format strings have been hardcoded into this method. With the change,
+these strings are defined as class-level attributes which can be
+overridden at the instance level when desired. The names of the
+attributes are default_time_format (for the strptime format string)
+and default_msec_format (for appending the millisecond value).
+
+
+
Changed in version 3.9: The default_msec_format can be None.
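+
+For example, to render a particular formatter's times in GMT, as described above:
+
+import logging
+import time
+
+utc = logging.Formatter("%(asctime)s %(message)s")
+utc.converter = time.gmtime   # this instance now formats record creation times in GMT
+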
Formats the specified exception information (a standard exception tuple as
+returned by sys.exc_info()) as a string. This default implementation
+just uses traceback.print_exception(). The resulting string is
+returned.
Formats the specified stack information (a string as returned by
+traceback.print_stack(), but with the last newline removed) as a
+string. This default implementation just returns the input value.
A base formatter class suitable for subclassing when you want to format a
+number of records. You can pass a Formatter instance which you want
+to use to format each line (that corresponds to a single record). If not
+specified, the default formatter (which just outputs the event message) is
+used as the line formatter.
Return a header for a list of records. The base implementation just
+returns the empty string. You will need to override this method if you
+want specific behaviour, e.g. to show the count of records, a title or a
+separator line.
Return a footer for a list of records. The base implementation just
+returns the empty string. You will need to override this method if you
+want specific behaviour, e.g. to show the count of records or a separator
+line.
Return formatted text for a list of records. The base implementation
+just returns the empty string if there are no records; otherwise, it
+returns the concatenation of the header, each record formatted with the
+line formatter, and the footer.
Filters can be used by Handlers and Loggers for more sophisticated
+filtering than is provided by levels. The base filter class only allows events
+which are below a certain point in the logger hierarchy. For example, a filter
+initialized with ‘A.B’ will allow events logged by loggers ‘A.B’, ‘A.B.C’,
+‘A.B.C.D’, ‘A.B.D’ etc. but not ‘A.BB’, ‘B.A.B’ etc. If initialized with the
+empty string, all events are passed.
Returns an instance of the Filter class. If name is specified, it
+names a logger which, together with its children, will have its events allowed
+through the filter. If name is the empty string, allows every event.
Is the specified record to be logged? Returns zero for no, nonzero for
+yes. If deemed appropriate, the record may be modified in-place by this
+method.
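+
+A minimal sketch of the base filter described above (the logger names follow the example in the text):
+
+import logging
+
+handler = logging.StreamHandler()
+handler.addFilter(logging.Filter("A.B"))   # passes "A.B", "A.B.C", ...; blocks "A.BB", "B.A.B"
+logging.getLogger("A.B").addHandler(handler)
+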
+
Note that filters attached to handlers are consulted before an event is
+emitted by the handler, whereas filters attached to loggers are consulted
+whenever an event is logged (using debug(), info(),
+etc.), before sending an event to handlers. This means that events which have
+been generated by descendant loggers will not be filtered by a logger’s filter
+setting, unless the filter has also been applied to those descendant loggers.
+
You don’t actually need to subclass Filter: you can pass any instance
+which has a filter method with the same semantics.
+
+
Changed in version 3.2: You don’t need to create specialized Filter classes, or use other
+classes with a filter method: you can use a function (or other
+callable) as a filter. The filtering logic will check to see if the filter
+object has a filter attribute: if it does, it’s assumed to be a
+Filter and its filter() method is called. Otherwise, it’s
+assumed to be a callable and called with the record as the single
+parameter. The returned value should conform to that returned by
+filter().
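+
+A sketch of a plain callable used as a filter, per the change described above (the function is an assumption, not from the text):
+
+import logging
+
+def drop_debug(record):
+    return record.levelno > logging.DEBUG   # a false return suppresses the record
+
+logging.getLogger("app").addFilter(drop_debug)
+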
+
+
Although filters are used primarily to filter records based on more
+sophisticated criteria than levels, they get to see every record which is
+processed by the handler or logger they’re attached to: this can be useful if
+you want to do things like counting how many records were processed by a
+particular logger or handler, or adding, changing or removing attributes in
+the LogRecord being processed. Obviously changing the LogRecord needs
+to be done with some care, but it does allow the injection of contextual
+information into logs (see Using Filters to impart contextual information).
LogRecord instances are created automatically by the Logger
+every time something is logged, and can be created manually via
+makeLogRecord() (for example, from a pickled event received over the
+wire).
Contains all the information pertinent to the event being logged.
+
The primary information is passed in msg and args,
+which are combined using msg % args to create
+the message attribute of the record.
+
+
Parameters
+
+
name (str) – The name of the logger used to log the event
+represented by this LogRecord.
+Note that the logger name in the LogRecord
+will always have this value,
+even though it may be emitted by a handler
+attached to a different (ancestor) logger.
+
level (int) – The numeric level of the logging event
+(such as 10 for DEBUG, 20 for INFO, etc).
+Note that this is converted to two attributes of the LogRecord:
+levelno for the numeric value
+and levelname for the corresponding level name.
+
pathname (str) – The full string path of the source file
+where the logging call was made.
+
lineno (int) – The line number in the source file
+where the logging call was made.
+
msg (Any) – The event description message,
+which can be a %-format string with placeholders for variable data,
+or an arbitrary object (see Using arbitrary objects as messages).
+
args (tuple | dict[str, Any]) – Variable data to merge into the msg argument
+to obtain the event description.
Returns the message for this LogRecord instance after merging any
+user-supplied arguments with the message. If the user-supplied message
+argument to the logging call is not a string, str() is called on it to
+convert it to a string. This allows use of user-defined classes as
+messages, whose __str__ method can return the actual format string to
+be used.
+
Changed in version 3.2: The creation of a LogRecord has been made more configurable by
+providing a factory which is used to create the record. The factory can be
+set using getLogRecordFactory() and setLogRecordFactory()
+(see this for the factory’s signature).
+
+
This functionality can be used to inject your own values into a
+LogRecord at creation time. You can use the following pattern:
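+
+A sketch of that pattern (the attribute name and value here are illustrative):
+
+import logging
+
+old_factory = logging.getLogRecordFactory()
+
+def record_factory(*args, **kwargs):
+    record = old_factory(*args, **kwargs)
+    record.custom_attribute = 0xdecafbad   # inject a custom value into every record
+    return record
+
+logging.setLogRecordFactory(record_factory)
+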
With this pattern, multiple factories could be chained, and as long
+as they don’t overwrite each other’s attributes or unintentionally
+overwrite the standard attributes listed above, there should be no
+surprises.
The LogRecord has a number of attributes, most of which are derived from the
+parameters to the constructor. (Note that the names do not always correspond
+exactly between the LogRecord constructor parameters and the LogRecord
+attributes.) These attributes can be used to merge data from the record into
+the format string. The following table lists (in alphabetical order) the
+attribute names, their meanings and the corresponding placeholder in a %-style
+format string.
+
If you are using {}-formatting (str.format()), you can use
+{attrname} as the placeholder in the format string. If you are using
+$-formatting (string.Template), use the form ${attrname}. In
+both cases, of course, replace attrname with the actual attribute name
+you want to use.
+
In the case of {}-formatting, you can specify formatting flags by placing them
+after the attribute name, separated from it with a colon. For example: a
+placeholder of {msecs:03d} would format a millisecond value of 4 as
+004. Refer to the str.format() documentation for full details on
+the options available to you.
+
Attribute name | Format | Description
args | You shouldn't need to format this yourself. | The tuple of arguments merged into msg to produce message, or a dict whose values are used for the merge (when there is only one argument, and it is a dictionary).
asctime | %(asctime)s | Human-readable time when the LogRecord was created. By default this is of the form ‘2003-07-08 16:49:45,896’ (the numbers after the comma are millisecond portion of the time).
lineno | %(lineno)d | Source line number where the logging call was issued (if available).
message | %(message)s | The logged message, computed as msg % args. This is set when Formatter.format() is invoked.
module | %(module)s | Module (name portion of filename).
msecs | %(msecs)d | Millisecond portion of the time when the LogRecord was created.
msg | You shouldn't need to format this yourself. | The format string passed in the original logging call. Merged with args to produce message, or an arbitrary object (see Using arbitrary objects as messages).
name | %(name)s | Name of the logger used to log the call.
pathname | %(pathname)s | Full pathname of the source file where the logging call was issued (if available).
process | %(process)d | Process ID (if available).
processName | %(processName)s | Process name (if available).
relativeCreated | %(relativeCreated)d | Time in milliseconds when the LogRecord was created, relative to the time the logging module was loaded.
stack_info | You shouldn't need to format this yourself. | Stack frame information (where available) from the bottom of the stack in the current thread, up to and including the stack frame of the logging call which resulted in the creation of this record.
Modifies the message and/or keyword arguments passed to a logging call in
+order to insert contextual information. This implementation takes the object
+passed as extra to the constructor and adds it to kwargs using key
+‘extra’. The return value is a (msg, kwargs) tuple which has the
+(possibly modified) versions of the arguments passed in.
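+
+A minimal sketch of this behaviour:
+
+def process(self, msg, kwargs):
+    kwargs["extra"] = self.extra  # make the adapter’s context available when formatting
+    return msg, kwargs
+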
The logging module is intended to be thread-safe without any special work
+needing to be done by its clients. It achieves this through using threading
+locks; there is one lock to serialize access to the module’s shared data, and
+each handler also creates a lock to serialize access to its underlying I/O.
+
If you are implementing asynchronous signal handlers using the signal
+module, you may not be able to use logging from within such handlers. This is
+because lock implementations in the threading module are not always
+re-entrant, and so cannot be invoked from such signal handlers.
Return a logger with the specified name or, if name is None, return a
+logger which is the root logger of the hierarchy. If specified, the name is
+typically a dot-separated hierarchical name like ‘a’, ‘a.b’ or ‘a.b.c.d’.
+Choice of these names is entirely up to the developer who is using logging.
+
All calls to this function with a given name return the same logger instance.
+This means that logger instances never need to be passed between different parts
+of an application.
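+
+For example, repeated calls with the same name return the same object:
+
+import logging
+
+assert logging.getLogger("a.b") is logging.getLogger("a.b")
+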
Return either the standard Logger class, or the last class passed to
+setLoggerClass(). This function may be called from within a new class
+definition, to ensure that installing a customized Logger class will
+not undo customizations already applied by other code. For example:
+
class MyLogger(logging.getLoggerClass()):
+    # ... override behaviour here
+
Return a callable which is used to create a LogRecord.
+
+
New in version 3.2: This function has been provided, along with setLogRecordFactory(),
+to allow developers more control over how the LogRecord
+representing a logging event is constructed.
+
+
See setLogRecordFactory() for more information about how the
+factory is called.
Logs a message with level DEBUG on the root logger. The msg is the
+message format string, and the args are the arguments which are merged into
+msg using the string formatting operator. (Note that this means that you can
+use keywords in the format string, together with a single dictionary argument.)
+
There are three keyword arguments in kwargs which are inspected: exc_info
+which, if it does not evaluate as false, causes exception information to be
+added to the logging message. If an exception tuple (in the format returned by
+sys.exc_info()) or an exception instance is provided, it is used;
+otherwise, sys.exc_info() is called to get the exception information.
+
The second optional keyword argument is stack_info, which defaults to
+False. If true, stack information is added to the logging
+message, including the actual logging call. Note that this is not the same
+stack information as that displayed through specifying exc_info: The
+former is stack frames from the bottom of the stack up to the logging call
+in the current thread, whereas the latter is information about stack frames
+which have been unwound, following an exception, while searching for
+exception handlers.
+
You can specify stack_info independently of exc_info, e.g. to just show
+how you got to a certain point in your code, even when no exceptions were
+raised. The stack frames are printed following a header line which says:
+
Stack (most recent call last):
+
+
+
This mimics the Traceback (most recent call last): which is used when
+displaying exception frames.
+
The third optional keyword argument is extra which can be used to pass a
+dictionary which is used to populate the __dict__ of the LogRecord created for
+the logging event with user-defined attributes. These custom attributes can then
+be used as you like. For example, they could be incorporated into logged
+messages:
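+
+FORMAT = '%(asctime)s %(clientip)-15s %(user)-8s %(message)s'
+logging.basicConfig(format=FORMAT)
+d = {'clientip': '192.168.0.1', 'user': 'fbloggs'}
+logging.warning('Protocol problem: %s', 'connection reset', extra=d)
+
+would print something like:
+
+2006-02-08 22:20:02,165 192.168.0.1 fbloggs  Protocol problem: connection reset
+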
The keys in the dictionary passed in extra should not clash with the keys used
+by the logging system. (See the Formatter documentation for more
+information on which keys are used by the logging system.)
+
If you choose to use these attributes in logged messages, you need to exercise
+some care. In the above example, for instance, the Formatter has been
+set up with a format string which expects ‘clientip’ and ‘user’ in the attribute
+dictionary of the LogRecord. If these are missing, the message will not be
+logged because a string formatting exception will occur. So in this case, you
+always need to pass the extra dictionary with these keys.
+
While this might be annoying, this feature is intended for use in specialized
+circumstances, such as multi-threaded servers where the same code executes in
+many contexts, and interesting conditions which arise are dependent on this
+context (such as remote client IP address and authenticated user name, in the
+above example). In such circumstances, it is likely that specialized
+Formatters would be used with particular Handlers.
Logs a message with level ERROR on the root logger. The arguments are
+interpreted as for debug(). Exception information is added to the logging
+message. This function should only be called from an exception handler.
Provides an overriding level level for all loggers which takes precedence over
+the logger’s own level. When the need arises to temporarily throttle logging
+output down across the whole application, this function can be useful. Its
+effect is to disable all logging calls of severity level and below, so that
+if you call it with a value of INFO, then all INFO and DEBUG events would be
+discarded, whereas those of severity WARNING and above would be processed
+according to the logger’s effective level. If
+logging.disable(logging.NOTSET) is called, it effectively removes this
+overriding level, so that logging output again depends on the effective
+levels of individual loggers.
+
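+For example:
+
+logging.disable(logging.INFO)    # discard INFO and DEBUG events everywhere
+logging.disable(logging.NOTSET)  # remove the overriding level again
+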
Note that if you have defined any custom logging level higher than
+CRITICAL (this is not recommended), you won’t be able to rely on the
+default value for the level parameter, but will have to explicitly supply a
+suitable value.
+
+
Changed in version 3.7: The level parameter was defaulted to level CRITICAL. See
+bpo-28524 for more information about this change.
Associates level level with text levelName in an internal dictionary, which is
+used to map numeric levels to a textual representation, for example when a
+Formatter formats a message. This function can also be used to define
+your own levels. The only constraints are that all levels used must be
+registered using this function, levels should be positive integers and they
+should increase in increasing order of severity.
+
+
Note
+
If you are thinking of defining your own levels, please see the
+section on Custom Levels.
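+
+For example, registering an illustrative custom level:
+
+VERBOSE = 15  # an illustrative level between DEBUG (10) and INFO (20)
+logging.addLevelName(VERBOSE, "VERBOSE")
+logging.log(VERBOSE, "only shown when the effective level is 15 or lower")
+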
Returns a mapping from level names to their corresponding logging levels. For example, the
+string “CRITICAL” maps to CRITICAL. The returned mapping is copied from an internal
+mapping on each call to this function.
Returns the textual or numeric representation of logging level level.
+
If level is one of the predefined levels CRITICAL, ERROR,
+WARNING, INFO or DEBUG then you get the
+corresponding string. If you have associated levels with names using
+addLevelName() then the name you have associated with level is
+returned. If a numeric value corresponding to one of the defined levels is
+passed in, the corresponding string representation is returned.
+
The level parameter also accepts a string representation of the level such
+as ‘INFO’. In such cases, this function returns the corresponding numeric
+value of the level.
+
If no matching numeric or string value is passed in, the string
+‘Level %s’ % level is returned.
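+
+For example:
+
+logging.getLevelName(logging.DEBUG)  # returns 'DEBUG'
+logging.getLevelName('DEBUG')        # returns 10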
+
+
Note
+
Levels are internally integers (as they need to be compared in the
+logging logic). This function is used to convert between an integer level
+and the level name displayed in the formatted log output by means of the
+%(levelname)s format specifier (see LogRecord attributes), and
+vice versa.
+
+
+
Changed in version 3.4: In Python versions earlier than 3.4, this function could also be passed a
+text level, and would return the corresponding numeric value of the level.
+This undocumented behaviour was considered a mistake, and was removed in
+Python 3.4, but reinstated in 3.4.2 to retain backward compatibility.
Creates and returns a new LogRecord instance whose attributes are
+defined by attrdict. This function is useful for taking a pickled
+LogRecord attribute dictionary, sent over a socket, and reconstituting
+it as a LogRecord instance at the receiving end.
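+
+For example, a minimal sketch (the attribute values here are illustrative):
+
+attrdict = {"name": "demo", "levelno": logging.INFO, "msg": "reconstituted record"}
+record = logging.makeLogRecord(attrdict)
+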
This function does nothing if the root logger already has handlers
+configured, unless the keyword argument force is set to True.
+
+
Note
+
This function should be called from the main thread
+before other threads are started. In versions of Python prior to
+2.7.1 and 3.2, if this function is called from multiple threads,
+it is possible (in rare circumstances) that a handler will be added
+to the root logger more than once, leading to unexpected results
+such as messages being duplicated in the log.
+
+
The following keyword arguments are supported.
+
+filename
+    Specifies that a FileHandler be created, using the specified filename,
+    rather than a StreamHandler.
+
+filemode
+    If filename is specified, open the file in this mode. Defaults to 'a'.
+
+format
+    Use the specified format string for the handler. Defaults to attributes
+    levelname, name and message separated by colons.
+
+datefmt
+    Use the specified date/time format, as accepted by time.strftime().
+
+style
+    If format is specified, use this style for the format string. One of
+    '%', '{' or '$' for printf-style, str.format() or string.Template
+    respectively. Defaults to '%'.
+
+level
+    Set the root logger level to the specified level.
+
+stream
+    Use the specified stream to initialize the StreamHandler. Note that this
+    argument is incompatible with filename - if both are present, a
+    ValueError is raised.
+
+handlers
+    If specified, this should be an iterable of already created handlers to
+    add to the root logger. Any handlers which don’t already have a formatter
+    set will be assigned the default formatter created in this function. Note
+    that this argument is incompatible with filename or stream - if both are
+    present, a ValueError is raised.
+
+force
+    If this keyword argument is specified as true, any existing handlers
+    attached to the root logger are removed and closed, before carrying out
+    the configuration as specified by the other arguments.
+
+encoding
+    If this keyword argument is specified along with filename, its value is
+    used when the FileHandler is created, and thus used when opening the
+    output file.
+
+errors
+    If this keyword argument is specified along with filename, its value is
+    used when the FileHandler is created, and thus used when opening the
+    output file. If not specified, the value ‘backslashreplace’ is used. Note
+    that if None is specified, it will be passed as such to open(), which
+    means that it will be treated the same as passing ‘errors’.
+
Changed in version 3.2: The style argument was added.
+
+
+
Changed in version 3.3: The handlers argument was added. Additional checks were added to
+catch situations where incompatible arguments are specified (e.g.
+handlers together with stream or filename, or stream
+together with filename).
+
+
+
Changed in version 3.8: The force argument was added.
+
+
+
Changed in version 3.9: The encoding and errors arguments were added.
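+
+For example, a minimal sketch of a typical one-call setup:
+
+import logging
+
+logging.basicConfig(filename="app.log", filemode="w", level=logging.DEBUG,
+                    format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logging.info("application started")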
Informs the logging system to perform an orderly shutdown by flushing and
+closing all handlers. This should be called at application exit and no
+further use of the logging system should be made after this call.
+
When the logging module is imported, it registers this function as an exit
+handler (see atexit), so normally there’s no need to do that
+manually.
Tells the logging system to use the class klass when instantiating a logger.
+The class should define __init__() such that only a name argument is
+required, and the __init__() should call Logger.__init__(). This
+function is typically called before any loggers are instantiated by applications
+which need to use custom logger behavior. After this call, as at any other
+time, do not instantiate loggers directly using the subclass: continue to use
+the logging.getLogger() API to get your loggers.
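+
+For example, a minimal subclass that satisfies these requirements:
+
+class MyLogger(logging.Logger):
+    def __init__(self, name):
+        super().__init__(name)  # Logger.__init__() must be called
+
+logging.setLoggerClass(MyLogger)
+logger = logging.getLogger("app")  # an instance of MyLogger
+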
Set a callable which is used to create a LogRecord.
+
+Parameters:
+    factory – The factory callable to be used to instantiate a log record.
+
New in version 3.2: This function has been provided, along with getLogRecordFactory(), to
+allow developers more control over how the LogRecord representing
+a logging event is constructed.
A “handler of last resort” is available through this attribute. This
+is a StreamHandler writing to sys.stderr with a level of
+WARNING, and is used to handle logging events in the absence of any
+logging configuration. The end result is to just print the message to
+sys.stderr. This replaces the earlier error message saying that
+“no handlers could be found for logger XYZ”. If you need the earlier
+behaviour for some reason, lastResort can be set to None.
This function is used to turn the capture of warnings by logging on and
+off.
+
If capture is True, warnings issued by the warnings module will
+be redirected to the logging system. Specifically, a warning will be
+formatted using warnings.formatwarning() and the resulting string
+logged to a logger named 'py.warnings' with a severity of WARNING.
+
If capture is False, the redirection of warnings to the logging system
+will stop, and warnings will be redirected to their original destinations
+(i.e. those in effect before captureWarnings(True) was called).
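+
+For example:
+
+import logging
+import warnings
+
+logging.captureWarnings(True)
+warnings.warn("this warning is routed to the 'py.warnings' logger")
+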
This is the original source for the logging package. The version of the
+package available from this site is suitable for use with Python 1.5.2, 2.1.x
+and 2.2.x, which do not include the logging package in the standard
+library.
+
\ No newline at end of file
diff --git a/applications/ColossalQA/data/tests/test.md b/applications/ColossalQA/data/tests/test.md
new file mode 100644
index 000000000000..20d3c612fef5
--- /dev/null
+++ b/applications/ColossalQA/data/tests/test.md
@@ -0,0 +1,78 @@
+# README Format File for Testing
+
+
+## Table of Contents
+
+- [Table of Contents](#table-of-contents)
+- [Install](#install)
+- [How to Use](#how-to-use)
+- Examples
+ - [Local Chinese Retrieval QA + Chat](examples/retrieval_conversation_zh.py)
+ - [Local English Retrieval QA + Chat](examples/retrieval_conversation_en.py)
+ - [Local Bi-lingual Retrieval QA + Chat](examples/retrieval_conversation_universal.py)
+ - [Experimental AI Agent Based on Chatgpt + Chat](examples/conversation_agent_chatgpt.py)
+
+**As Colossal-AI is undergoing some major updates, this project will be actively maintained to stay in line with the Colossal-AI project.**
+
+## Install
+
+Install colossalqa
+```bash
+# python==3.8.17
+cd ColossalAI/applications/ColossalQA
+pip install -e .
+```
+
+To use the vllm server, please refer to the official guide [here](https://github.com/vllm-project/vllm/tree/main) for installation instructions. Then run the following command from another terminal.
+```bash
+cd ./vllm/entrypoints
+python api_server.py --host localhost --port $PORT_NUMBER --model $PATH_TO_MODEL --swap-space $SWAP_SPACE_IN_GB
+```
+
+## How to Use
+
+### Collect your data
+
+For the ChatGPT based agent, we support document retrieval and simple SQL search.
+If you want to run the demo locally, we provide a document retrieval based conversation system built upon langchain. It accepts a wide range of documents.
+
+Read the comments under ./colossalqa/data_loader for more details.
+
+### Serving
+We currently use vllm for serving; it will be replaced with Colossal-AI inference when that is ready. Please refer to the class VllmLLM.
+
+### Run the script
+
+We provide scripts for a Chinese document retrieval based conversation system, an English document retrieval based conversation system, a bi-lingual document retrieval based conversation system and an experimental AI agent with document retrieval and SQL query functionality.
+
+To run the bi-lingual scripts, set the following environment variables before running the script.
+```bash
+export ZH_MODEL_PATH=XXX
+export ZH_MODEL_NAME=chatglm2
+export EN_MODEL_PATH=XXX
+export EN_MODEL_NAME=llama
+python retrieval_conversation_universal.py
+```
+
+To run retrieval_conversation_en.py, set the following environment variables.
+```bash
+export EN_MODEL_PATH=XXX
+export EN_MODEL_NAME=llama
+python retrieval_conversation_en.py
+```
+
+To run retrieval_conversation_zh.py, set the following environment variables.
+```bash
+export ZH_MODEL_PATH=XXX
+export ZH_MODEL_NAME=chatglm2
+python retrieval_conversation_zh.py
+```
+
+It will ask you to provide the path to your data during the execution of the script. You can also pass a glob path to load multiple files at once. If csv files are provided, please use ',' as the delimiter and '"' as the quotation mark. There are no other formatting constraints for document type files. For table type files, we use pandas; please refer to [Pandas-Input/Output](https://pandas.pydata.org/pandas-docs/stable/reference/io.html) for file format details.
+
+## The Plan
+
+- [x] build document retrieval QA tool
+- [x] Add long + short term memory
+- [x] Add demo for AI agent with SQL query
+- [x] Add custom retriever for fast construction and retrieval (with incremental mode)
diff --git a/applications/ColossalQA/data/tests/test.txt b/applications/ColossalQA/data/tests/test.txt
new file mode 100644
index 000000000000..ff5bf2dc7742
--- /dev/null
+++ b/applications/ColossalQA/data/tests/test.txt
@@ -0,0 +1,38 @@
+Your Name
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit
+ 123 Your Street
+Your City, ST 12345
+(123) 456-7890
+no_reply@example.com
+ EXPERIENCE
+Company, Location — Job Title
+MONTH 20XX - PRESENT
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh.
+Company, Location — Job Title
+MONTH 20XX - MONTH 20XX
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh.
+Company, Location — Job Title
+MONTH 20XX - MONTH 20XX
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh.
+EDUCATION
+School Name, Location — Degree
+MONTH 20XX - MONTH 20XX
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy nibh euismod tincidunt ut laoreet dolore.
+School Name, Location — Degree
+MONTH 20XX - MONTH 20XX
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam.
+PROJECTS
+Project Name — Detail
+Lorem ipsum dolor sit amet, consectetuer adipiscing elit.
+ SKILLS
+* Lorem ipsum dolor sit amet.
+* Consectetuer adipiscing elit.
+* Sed diam nonummy nibh euismod tincidunt.
+* Laoreet dolore magna aliquam erat volutpat.
+AWARDS
+Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy
+Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat.
+Lorem ipsum dolor sit amet Consectetuer adipiscing elit, Sed diam nonummy
+Nibh euismod tincidunt ut laoreet dolore magna aliquam erat volutpat.
+LANGUAGES
+Lorem ipsum, Dolor sit amet, Consectetuer
\ No newline at end of file
diff --git a/applications/ColossalQA/examples/conversation_agent_chatgpt.py b/applications/ColossalQA/examples/conversation_agent_chatgpt.py
new file mode 100644
index 000000000000..cb5c1a31d03e
--- /dev/null
+++ b/applications/ColossalQA/examples/conversation_agent_chatgpt.py
@@ -0,0 +1,125 @@
+"""
+Script for the multilingual conversation based experimental AI agent
+We used ChatGPT as the language model
+You need openai api key to run this script
+"""
+
+import argparse
+import os
+
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.data_loader.table_dataloader import TableLoader
+from langchain import LLMChain
+from langchain.agents import Tool, ZeroShotAgent
+from langchain.agents.agent import AgentExecutor
+from langchain.agents.agent_toolkits import create_retriever_tool
+from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.llms import OpenAI
+from langchain.memory import ChatMessageHistory, ConversationBufferMemory
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.utilities import SQLDatabase
+from langchain.vectorstores import Chroma
+from langchain_experimental.sql import SQLDatabaseChain
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Experimental AI agent powered by ChatGPT")
+ parser.add_argument("--open_ai_key_path", type=str, default=None, help="path to the plain text open_ai_key file")
+
+ args = parser.parse_args()
+
+    # Setup openai key
+    # Set env var OPENAI_API_KEY or load from a file
+    if args.open_ai_key_path:
+        os.environ["OPENAI_API_KEY"] = open(args.open_ai_key_path).read().strip()
+
+ # Load data served on sql
+ print("Select files for constructing sql database")
+ tools = []
+
+ llm = OpenAI(temperature=0.0)
+
+ while True:
+ file = input("Select a file to load or press Enter to exit:")
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+
+ table_loader = TableLoader(
+ [[file, data_name.replace(" ", "_")]], sql_path=f"sqlite:///{data_name.replace(' ', '_')}.db"
+ )
+ sql_path = table_loader.get_sql_path()
+
+ # Create sql database
+ db = SQLDatabase.from_uri(sql_path)
+ print(db.get_table_info())
+
+ db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
+ name = f"Query the SQL database regarding {data_name}"
+ description = (
+ f"useful for when you need to answer questions based on data stored on a SQL database regarding {data_name}"
+ )
+ tools.append(
+ Tool(
+ name=name,
+ func=db_chain.run,
+ description=description,
+ )
+ )
+ print(f"Added sql dataset\n\tname={name}\n\tdescription:{description}")
+
+ # VectorDB
+ embedding = OpenAIEmbeddings()
+
+ # Load data serve on sql
+ print("Select files for constructing retriever")
+ while True:
+ file = input("Select a file to load or press Enter to exit:")
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+ retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
+ splits = text_splitter.split_documents(retriever_data)
+
+ # Create vector store
+ vectordb = Chroma.from_documents(documents=splits, embedding=embedding)
+ # Create retriever
+ retriever = vectordb.as_retriever(
+ search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.5, "k": 5}
+ )
+ # Add to tool chain
+ name = f"Searches and returns documents regarding {data_name}."
+ tools.append(create_retriever_tool(retriever, data_name, name))
+
+    prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools. If none of the tools can be used to answer the question, do not share an uncertain answer unless you think answering the question doesn't need any background information; in that case, try to answer the question directly."""
+ suffix = """You are provided with the following background knowledge:
+    Begin!
+
+ {chat_history}
+ Question: {input}
+ {agent_scratchpad}"""
+
+ prompt = ZeroShotAgent.create_prompt(
+ tools,
+ prefix=prefix,
+ suffix=suffix,
+ input_variables=["input", "chat_history", "agent_scratchpad"],
+ )
+
+ memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=ChatMessageHistory())
+
+ llm_chain = LLMChain(llm=OpenAI(temperature=0.7), prompt=prompt)
+ agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
+ agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory)
+
+ while True:
+ user_input = input("User: ")
+ if " end " in user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = agent_chain.run(user_input)
+ print(f"Agent: {agent_response}")
+    if "table_loader" in locals():
+        table_loader.sql_engine.dispose()
diff --git a/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py
new file mode 100644
index 000000000000..00b920d274bc
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_conversation_chatgpt.py
@@ -0,0 +1,131 @@
+"""
+Multilingual retrieval based conversation system backed by ChatGPT
+"""
+
+import argparse
+import os
+
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.retriever import CustomRetriever
+from langchain import LLMChain
+from langchain.chains import RetrievalQA
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.llms import OpenAI
+from langchain.prompts.prompt import PromptTemplate
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Multilingual retrieval based conversation system backed by ChatGPT")
+    parser.add_argument("--open_ai_key_path", type=str, default=None, help="path to the plain text open_ai_key file")
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing sql files for indexing"
+    )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.sql_file_path):
+ os.makedirs(args.sql_file_path)
+
+    # Setup openai key
+    # Set env var OPENAI_API_KEY or load from a file
+    if args.open_ai_key_path:
+        os.environ["OPENAI_API_KEY"] = open(args.open_ai_key_path).read().strip()
+
+ llm = OpenAI(temperature=0.6)
+
+ information_retriever = CustomRetriever(k=3, sql_file_path=args.sql_file_path, verbose=True)
+ # VectorDB
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # Define memory with summarization ability
+ memory = ConversationBufferWithSummary(llm=llm)
+
+ # Load data to vector store
+ print("Select files for constructing retriever")
+ documents = []
+ while True:
+        file = input("Enter a file path or press Enter directly without input to exit:").strip()
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+ retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=0)
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+ # Create retriever
+ information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ prompt_template = """Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+    If the answer cannot be inferred based on the given context, please don't share false information.
+ Use the context and chat history to respond to the human's input at the end or carry on the conversation. You should generate one response only. No following up is needed.
+
+ context:
+ {context}
+
+ chat history
+ {chat_history}
+
+ Human: {question}
+ Assistant:"""
+
+ prompt_template_disambiguate = """You are a helpful, respectful and honest assistant. You always follow the instruction.
+ Please replace any ambiguous references in the given sentence with the specific names or entities mentioned in the chat history or just output the original sentence if no chat history is provided or if the sentence doesn't contain ambiguous references. Your output should be the disambiguated sentence itself (in the same line as "disambiguated sentence:") and contain nothing else.
+
+ Here is an example:
+ Chat history:
+ Human: I have a friend, Mike. Do you know him?
+ Assistant: Yes, I know a person named Mike
+
+ sentence: What's his favorite food?
+ disambiguated sentence: What's Mike's favorite food?
+ END OF EXAMPLE
+
+ Chat history:
+ {chat_history}
+
+ sentence: {input}
+ disambiguated sentence:"""
+
+ PROMPT = PromptTemplate(template=prompt_template, input_variables=["question", "chat_history", "context"])
+
+ memory.initiate_document_retrieval_chain(
+ llm,
+ PROMPT,
+ information_retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+
+ PROMPT_DISAMBIGUATE = PromptTemplate(
+ template=prompt_template_disambiguate, input_variables=["chat_history", "input"]
+ )
+
+ llm_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=information_retriever,
+ chain_type_kwargs={"prompt": PROMPT, "memory": memory},
+ )
+ llm_chain_disambiguate = LLMChain(llm=llm, prompt=PROMPT_DISAMBIGUATE)
+
+ def disambiguity(input):
+ out = llm_chain_disambiguate.run({"input": input, "chat_history": memory.buffer})
+ return out.split("\n")[0]
+
+ information_retriever.set_rephrase_handler(disambiguity)
+
+ while True:
+ user_input = input("User: ")
+ if " end " in user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = llm_chain.run(user_input)
+ agent_response = agent_response.split("\n")[0]
+ print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en.py b/applications/ColossalQA/examples/retrieval_conversation_en.py
new file mode 100644
index 000000000000..e0fe46ae6322
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_conversation_en.py
@@ -0,0 +1,119 @@
+"""
+Script for English retrieval based conversation system backed by LLaMa2
+"""
+import argparse
+import os
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.prompt.prompt import (
+ EN_RETRIEVAL_QA_REJECTION_ANSWER,
+ EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ PROMPT_DISAMBIGUATE_EN,
+ PROMPT_RETRIEVAL_QA_EN,
+)
+from colossalqa.retriever import CustomRetriever
+from langchain import LLMChain
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser(description="English retrieval based conversation system backed by LLaMa2")
+ parser.add_argument("--model_path", type=str, default=None, help="path to the model")
+ parser.add_argument("--model_name", type=str, default=None, help="name of the model")
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing sql files for indexing"
+    )
+
+ args = parser.parse_args()
+ if not os.path.exists(args.sql_file_path):
+ os.makedirs(args.sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(args.model_name, args.model_path)
+ llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Define the retriever
+ information_retriever = CustomRetriever(k=3, sql_file_path=args.sql_file_path, verbose=True)
+
+ # Setup embedding model locally
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # Define memory with summarization ability
+ memory = ConversationBufferWithSummary(
+ llm=llm, max_tokens=2000, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
+ )
+
+ # Define the chain to preprocess the input
+ # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
+ llm_chain_disambiguate = LLMChain(
+ llm=llm, prompt=PROMPT_DISAMBIGUATE_EN, llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True}
+ )
+
+ def disambiguity(input):
+ out = llm_chain_disambiguate.run(input=input, chat_history=memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ # Load data to vector store
+ print("Select files for constructing retriever")
+ documents = []
+ while True:
+        file = input("Enter a file path or press Enter directly without input to exit:").strip()
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+        separator = input(
+            "Enter a separator to force separating text into chunks; if no separator is given, the default separator is '\\n\\n'. Note that "
+            + "we use a neural text splitter to split texts into chunks; the separator only serves as a delimiter to force splitting long passages into "
+            + "chunks before passing them to the neural network. Press ENTER directly to skip:"
+        )
+ separator = separator if separator != "" else "\n\n"
+ retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+ # Create retriever
+ information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ # Set document retrieval chain, we need this chain to calculate prompt length
+ memory.initiate_document_retrieval_chain(
+ llm,
+ PROMPT_RETRIEVAL_QA_EN,
+ information_retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+
+ # Define retrieval chain
+ retrieval_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=information_retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": memory},
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
+ )
+ # Set disambiguity handler
+ information_retriever.set_rephrase_handler(disambiguity)
+
+ # Start conversation
+ while True:
+ user_input = input("User: ")
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = retrieval_chain.run(
+ query=user_input,
+ stop=["Human: "],
+ rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER,
+ )
+ agent_response = agent_response.split("\n")[0]
+ print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
new file mode 100644
index 000000000000..d98a75592372
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_conversation_en_customer_service.py
@@ -0,0 +1,149 @@
+"""
+Script for English retrieval based conversation system backed by LLaMa2
+"""
+import argparse
+import json
+import os
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.prompt.prompt import (
+ EN_RETRIEVAL_QA_REJECTION_ANSWER,
+ EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ PROMPT_DISAMBIGUATE_EN,
+ PROMPT_RETRIEVAL_QA_EN,
+)
+from colossalqa.retriever import CustomRetriever
+from langchain import LLMChain
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser(description="English retrieval based conversation system backed by LLaMa2")
+ parser.add_argument("--model_path", type=str, default=None, help="path to the model")
+ parser.add_argument("--model_name", type=str, default=None, help="name of the model")
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing sql files for indexing"
+    )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.sql_file_path):
+ os.makedirs(args.sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(args.model_name, args.model_path)
+ llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Define the retriever
+ information_retriever = CustomRetriever(k=3, sql_file_path=args.sql_file_path, verbose=True)
+
+ # Setup embedding model locally
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # Define memory with summarization ability
+ memory = ConversationBufferWithSummary(
+ llm=llm, max_tokens=2000, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
+ )
+
+ # Define the chain to preprocess the input
+ # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
+ llm_chain_disambiguate = LLMChain(
+ llm=llm, prompt=PROMPT_DISAMBIGUATE_EN, llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True}
+ )
+
+ def disambiguity(input):
+ out = llm_chain_disambiguate.run(input=input, chat_history=memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ # Load data to vector store
+ print("Select files for constructing retriever")
+ documents = []
+
+ # preprocess data
+ if not os.path.exists("../data/data_sample/custom_service_preprocessed.json"):
+ if not os.path.exists("../data/data_sample/custom_service.json"):
+ raise ValueError(
+ "custom_service.json not found, please download the data from HuggingFace Datasets: qgyd2021/e_commerce_customer_service"
+ )
+ data = json.load(open("../data/data_sample/custom_service.json", "r", encoding="utf8"))
+ preprocessed = []
+ for row in data["rows"]:
+ preprocessed.append({"key": row["row"]["query"], "value": row["row"]["response"]})
+ data = {}
+ data["data"] = preprocessed
+ with open("../data/data_sample/custom_service_preprocessed.json", "w", encoding="utf8") as f:
+ json.dump(data, f, ensure_ascii=False)
+
+    # Define a metadata function which is used to format the prompt with the value in metadata instead of the key;
+    # the latter is langchain's default behavior
+ def metadata_func(data_sample, additional_fields):
+ """
+ metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
+ object extracted by the jq_schema and the default metadata and returns
+ a dict of the updated metadata.
+
+ To use key-value format, the metadata_func should be defined as follows:
+ metadata = {'value': 'a string to be used to format the prompt', 'is_key_value_mapping': True}
+ """
+ metadata = {}
+ metadata["value"] = f"Question: {data_sample['key']}\nAnswer:{data_sample['value']}"
+ metadata["is_key_value_mapping"] = True
+ assert "value" not in additional_fields
+ assert "is_key_value_mapping" not in additional_fields
+ metadata.update(additional_fields)
+ return metadata
+
+ retriever_data = DocumentLoader(
+ [["../data/data_sample/custom_service_preprocessed.json", "CustomerServiceDemo"]],
+ content_key="key",
+ metadata_func=metadata_func,
+ ).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+
+ # Create retriever
+ information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ # Set document retrieval chain, we need this chain to calculate prompt length
+ memory.initiate_document_retrieval_chain(
+ llm,
+ PROMPT_RETRIEVAL_QA_EN,
+ information_retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+
+ # Define retrieval chain
+ retrieval_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=information_retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_EN, "memory": memory},
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
+ )
+ # Set disambiguity handler
+ information_retriever.set_rephrase_handler(disambiguity)
+ # Start conversation
+ while True:
+ user_input = input("User: ")
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = retrieval_chain.run(
+ query=user_input,
+ stop=["Human: "],
+ rejection_trigger_keywrods=EN_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_answer=EN_RETRIEVAL_QA_REJECTION_ANSWER,
+ )
+ agent_response = agent_response.split("\n")[0]
+ print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/retrieval_conversation_universal.py b/applications/ColossalQA/examples/retrieval_conversation_universal.py
new file mode 100644
index 000000000000..361aa9833d27
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_conversation_universal.py
@@ -0,0 +1,22 @@
+import argparse
+from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
+
+if __name__ == '__main__':
+ # Parse arguments
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--en_model_path', type=str, default=None)
+ parser.add_argument('--zh_model_path', type=str, default=None)
+ parser.add_argument('--zh_model_name', type=str, default=None)
+ parser.add_argument('--en_model_name', type=str, default=None)
+    parser.add_argument('--sql_file_path', type=str, default=None, help='path to an empty folder for storing sql files for indexing')
+ args = parser.parse_args()
+
+    # Document paths will be asked for at runtime
+ session = UniversalRetrievalConversation(files_en=None,
+ files_zh=None,
+ zh_model_path=args.zh_model_path, en_model_path=args.en_model_path,
+ zh_model_name=args.zh_model_name, en_model_name=args.en_model_name,
+ sql_file_path=args.sql_file_path
+ )
+ session.start_test_session()
+
\ No newline at end of file
diff --git a/applications/ColossalQA/examples/retrieval_conversation_zh.py b/applications/ColossalQA/examples/retrieval_conversation_zh.py
new file mode 100644
index 000000000000..cbbbefad7c7b
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_conversation_zh.py
@@ -0,0 +1,113 @@
+"""
+Script for Chinese retrieval based conversation system backed by ChatGLM2
+"""
+import argparse
+import os
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.prompt.prompt import (
+ PROMPT_DISAMBIGUATE_ZH,
+ PROMPT_RETRIEVAL_QA_ZH,
+ SUMMARY_PROMPT_ZH,
+ ZH_RETRIEVAL_QA_REJECTION_ANSWER,
+ ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+)
+from colossalqa.retriever import CustomRetriever
+from colossalqa.text_splitter import ChineseTextSplitter
+from langchain import LLMChain
+from langchain.embeddings import HuggingFaceEmbeddings
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system backed by ChatGLM2")
+ parser.add_argument("--model_path", type=str, default=None, help="path to the model")
+ parser.add_argument("--model_name", type=str, default=None, help="name of the model")
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing sql files for indexing"
+    )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.sql_file_path):
+ os.makedirs(args.sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(args.model_name, args.model_path)
+ llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Setup embedding model locally
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+ # Define the retriever
+ information_retriever = CustomRetriever(k=3, sql_file_path=args.sql_file_path, verbose=True)
+
+ # Define memory with summarization ability
+ memory = ConversationBufferWithSummary(
+ llm=llm,
+ prompt=SUMMARY_PROMPT_ZH,
+ human_prefix="用户",
+ ai_prefix="Assistant",
+ max_tokens=2000,
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True},
+ )
+
+ # Define the chain to preprocess the input
+ # Disambiguate the input. e.g. "What is the capital of that country?" -> "What is the capital of France?"
+ llm_chain_disambiguate = LLMChain(
+ llm=llm, prompt=PROMPT_DISAMBIGUATE_ZH, llm_kwargs={"max_new_tokens": 30, "temperature": 0.6, "do_sample": True}
+ )
+
+ def disambiguity(input: str):
+ out = llm_chain_disambiguate.run(input=input, chat_history=memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ # Load data to vector store
+ print("Select files for constructing retriever")
+ documents = []
+ while True:
+        file = input("Enter a file path or press Enter directly without input to exit:").strip()
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+ retriever_data = DocumentLoader([[file, data_name.replace(" ", "_")]]).all_data
+
+ # Split
+ text_splitter = ChineseTextSplitter()
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+ # Create retriever
+ information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ # Set document retrieval chain, we need this chain to calculate prompt length
+ memory.initiate_document_retrieval_chain(llm, PROMPT_RETRIEVAL_QA_ZH, information_retriever)
+
+ # Define retrieval chain
+ llm_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ verbose=False,
+ chain_type="stuff",
+ retriever=information_retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_QA_ZH, "memory": memory},
+ llm_kwargs={"max_new_tokens": 150, "temperature": 0.6, "do_sample": True},
+ )
+
+ # Set disambiguity handler
+ information_retriever.set_rephrase_handler(disambiguity)
+
+ # Start conversation
+ while True:
+ user_input = input("User: ")
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response = llm_chain.run(
+ query=user_input,
+ stop=["答案>"],
+ doc_prefix="支持文档",
+ rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
+ )
+ print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py
new file mode 100644
index 000000000000..adb6544941f0
--- /dev/null
+++ b/applications/ColossalQA/examples/retrieval_intent_classification_zh_customer_service.py
@@ -0,0 +1,97 @@
+"""
+Script for Chinese retrieval based intent classification for customer service
+"""
+import argparse
+import os
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH
+from colossalqa.retriever import CustomRetriever
+from colossalqa.text_splitter import ChineseTextSplitter
+from langchain.embeddings import HuggingFaceEmbeddings
+
+if __name__ == "__main__":
+ # Parse arguments
+    parser = argparse.ArgumentParser(description="Chinese retrieval based intent classification for customer service")
+ parser.add_argument("--model_path", type=str, default=None, help="path to the model")
+ parser.add_argument("--model_name", type=str, default=None, help="name of the model")
+    parser.add_argument(
+        "--sql_file_path", type=str, default=None, help="path to an empty folder for storing sql files for indexing"
+    )
+
+ args = parser.parse_args()
+
+ if not os.path.exists(args.sql_file_path):
+ os.makedirs(args.sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(args.model_name, args.model_path)
+ llm = ColossalLLM(n=1, api=colossal_api)
+
+ # Define the retriever
+ information_retriever = CustomRetriever(k=2, sql_file_path=args.sql_file_path, verbose=True)
+
+ # Setup embedding model locally
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # Load data to vector store
+ print("Select files for constructing retriever")
+ documents = []
+
+    # Define a metadata function which is used to format the prompt with the value in metadata instead of the key;
+    # the latter is langchain's default behavior
+ def metadata_func(data_sample, additional_fields):
+ """
+ metadata_func (Callable[Dict, Dict]): A function that takes in the JSON
+ object extracted by the jq_schema and the default metadata and returns
+ a dict of the updated metadata.
+
+ To use key-value format, the metadata_func should be defined as follows:
+ metadata = {'value': 'a string to be used to format the prompt', 'is_key_value_mapping': True}
+ """
+ metadata = {}
+ metadata["value"] = f"Question: {data_sample['key']}\nAnswer:{data_sample['value']}"
+ metadata["is_key_value_mapping"] = True
+ assert "value" not in additional_fields
+ assert "is_key_value_mapping" not in additional_fields
+ metadata.update(additional_fields)
+ return metadata
+
+ retriever_data = DocumentLoader(
+ [["../data/data_sample/custom_service_classification.json", "CustomerServiceDemo"]],
+ content_key="key",
+ metadata_func=metadata_func,
+ ).all_data
+
+ # Split
+ text_splitter = ChineseTextSplitter()
+ splits = text_splitter.split_documents(retriever_data)
+ documents.extend(splits)
+
+ # Create retriever
+ information_retriever.add_documents(docs=documents, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ # Define retrieval chain
+ retrieval_chain = RetrievalQA.from_chain_type(
+ llm=llm,
+ verbose=True,
+ chain_type="stuff",
+ retriever=information_retriever,
+ chain_type_kwargs={"prompt": PROMPT_RETRIEVAL_CLASSIFICATION_USE_CASE_ZH},
+ llm_kwargs={"max_new_tokens": 50, "temperature": 0.75, "do_sample": True},
+ )
+
+ # Start conversation
+ while True:
+ user_input = input("User: ")
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+        # Use inputs similar to the keys in custom_service_classification.json
+ agent_response = retrieval_chain.run(query=user_input, stop=["Human: "])
+ agent_response = agent_response.split("\n")[0]
+ print(f"Agent: {agent_response}")
diff --git a/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
new file mode 100644
index 000000000000..0ad547c0093a
--- /dev/null
+++ b/applications/ColossalQA/examples/webui_demo/RAG_ChatBot.py
@@ -0,0 +1,184 @@
+from typing import Dict, Tuple
+
+from colossalqa.chain.retrieval_qa.base import RetrievalQA
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.mylogging import get_logger
+from colossalqa.prompt.prompt import (
+ PROMPT_DISAMBIGUATE_ZH,
+ PROMPT_RETRIEVAL_QA_ZH,
+ SUMMARY_PROMPT_ZH,
+ ZH_RETRIEVAL_QA_REJECTION_ANSWER,
+ ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+)
+from colossalqa.retriever import CustomRetriever
+from colossalqa.text_splitter import ChineseTextSplitter
+from langchain import LLMChain
+from langchain.embeddings import HuggingFaceEmbeddings
+
+logger = get_logger()
+
+DEFAULT_RAG_CFG = {
+ "retri_top_k": 3,
+ "retri_kb_file_path": "./",
+ "verbose": True,
+ "mem_summary_prompt": SUMMARY_PROMPT_ZH,
+ "mem_human_prefix": "用户",
+ "mem_ai_prefix": "Assistant",
+ "mem_max_tokens": 2000,
+ "mem_llm_kwargs": {"max_new_tokens": 50, "temperature": 1, "do_sample": True},
+ "disambig_prompt": PROMPT_DISAMBIGUATE_ZH,
+ "disambig_llm_kwargs": {"max_new_tokens": 30, "temperature": 1, "do_sample": True},
+ "embed_model_name_or_path": "moka-ai/m3e-base",
+ "embed_model_device": {"device": "cpu"},
+ "gen_llm_kwargs": {"max_new_tokens": 100, "temperature": 1, "do_sample": True},
+ "gen_qa_prompt": PROMPT_RETRIEVAL_QA_ZH,
+}
+
+
+class RAG_ChatBot:
+ def __init__(
+ self,
+ llm,
+ rag_config,
+ ) -> None:
+ self.llm = llm
+ self.rag_config = rag_config
+ self.set_embed_model(**self.rag_config)
+ self.set_text_splitter(**self.rag_config)
+ self.set_memory(**self.rag_config)
+ self.set_info_retriever(**self.rag_config)
+ self.set_rag_chain(**self.rag_config)
+ if self.rag_config.get("disambig_prompt", None):
+ self.set_disambig_retriv(**self.rag_config)
+
+ def set_embed_model(self, **kwargs):
+ self.embed_model = HuggingFaceEmbeddings(
+ model_name=kwargs["embed_model_name_or_path"],
+ model_kwargs=kwargs["embed_model_device"],
+ encode_kwargs={"normalize_embeddings": False},
+ )
+
+ def set_text_splitter(self, **kwargs):
+ # Initialize text_splitter
+ self.text_splitter = ChineseTextSplitter()
+
+ def set_memory(self, **kwargs):
+ params = {"llm_kwargs": kwargs["mem_llm_kwargs"]} if kwargs.get("mem_llm_kwargs", None) else {}
+ # Initialize memory with summarization ability
+ self.memory = ConversationBufferWithSummary(
+ llm=self.llm,
+ prompt=kwargs["mem_summary_prompt"],
+ human_prefix=kwargs["mem_human_prefix"],
+ ai_prefix=kwargs["mem_ai_prefix"],
+ max_tokens=kwargs["mem_max_tokens"],
+ **params,
+ )
+
+ def set_info_retriever(self, **kwargs):
+ self.info_retriever = CustomRetriever(
+ k=kwargs["retri_top_k"], sql_file_path=kwargs["retri_kb_file_path"], verbose=kwargs["verbose"]
+ )
+
+ def set_rag_chain(self, **kwargs):
+ params = {"llm_kwargs": kwargs["gen_llm_kwargs"]} if kwargs.get("gen_llm_kwargs", None) else {}
+ self.rag_chain = RetrievalQA.from_chain_type(
+ llm=self.llm,
+ verbose=kwargs["verbose"],
+ chain_type="stuff",
+ retriever=self.info_retriever,
+ chain_type_kwargs={"prompt": kwargs["gen_qa_prompt"], "memory": self.memory},
+ **params,
+ )
+
+ def split_docs(self, documents):
+ doc_splits = self.text_splitter.split_documents(documents)
+ return doc_splits
+
+ def set_disambig_retriv(self, **kwargs):
+ params = {"llm_kwargs": kwargs["disambig_llm_kwargs"]} if kwargs.get("disambig_llm_kwargs", None) else {}
+ self.llm_chain_disambiguate = LLMChain(llm=self.llm, prompt=kwargs["disambig_prompt"], **params)
+
+ def disambiguity(input: str):
+ out = self.llm_chain_disambiguate.run(input=input, chat_history=self.memory.buffer, stop=["\n"])
+ return out.split("\n")[0]
+
+ self.info_retriever.set_rephrase_handler(disambiguity)
+
+ def load_doc_from_console(self, json_parse_args: Dict = {}):
+ documents = []
+ print("Select files for constructing Chinese retriever")
+ while True:
+ file = input("Enter a file path or press Enter directly without input to exit:").strip()
+ if file == "":
+ break
+ data_name = input("Enter a short description of the data:")
+ docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
+ documents.extend(docs)
+ self.documents = documents
+ self.split_docs_and_add_to_mem(**self.rag_config)
+
+ def load_doc_from_files(self, files, data_name="default_kb", json_parse_args: Dict = {}):
+ documents = []
+ for file in files:
+ docs = DocumentLoader([[file, data_name.replace(" ", "_")]], **json_parse_args).all_data
+ documents.extend(docs)
+ self.documents = documents
+ self.split_docs_and_add_to_mem(**self.rag_config)
+
+ def split_docs_and_add_to_mem(self, **kwargs):
+ self.doc_splits = self.split_docs(self.documents)
+ self.info_retriever.add_documents(
+ docs=self.doc_splits, cleanup="incremental", mode="by_source", embedding=self.embed_model
+ )
+ self.memory.initiate_document_retrieval_chain(self.llm, kwargs["gen_qa_prompt"], self.info_retriever)
+
+ def reset_config(self, rag_config):
+ self.rag_config = rag_config
+ self.set_embed_model(**self.rag_config)
+ self.set_text_splitter(**self.rag_config)
+ self.set_memory(**self.rag_config)
+ self.set_info_retriever(**self.rag_config)
+ self.set_rag_chain(**self.rag_config)
+ if self.rag_config.get("disambig_prompt", None):
+ self.set_disambig_retriv(**self.rag_config)
+
+ def run(self, user_input: str, memory: ConversationBufferWithSummary) -> Tuple[str, ConversationBufferWithSummary]:
+ if memory:
+ memory.buffered_history.messages = memory.buffered_history.messages
+ memory.summarized_history_temp.messages = memory.summarized_history_temp.messages
+ result = self.rag_chain.run(
+ query=user_input,
+ stop=[memory.human_prefix + ": "],
+ rejection_trigger_keywrods=ZH_RETRIEVAL_QA_TRIGGER_KEYWORDS,
+ rejection_answer=ZH_RETRIEVAL_QA_REJECTION_ANSWER,
+ )
+ return result.split("\n")[0], memory
+
+ def start_test_session(self):
+ """
+ Simple session for testing purpose
+ """
+ while True:
+ user_input = input("User: ")
+ if "END" == user_input:
+ print("Agent: Happy to chat with you :)")
+ break
+ agent_response, self.memory = self.run(user_input, self.memory)
+ print(f"Agent: {agent_response}")
+
+
+if __name__ == "__main__":
+    # Initialize a LangChain LLM (here we use ChatGPT as an example)
+ from langchain.llms import OpenAI
+
+ llm = OpenAI(openai_api_key="YOUR_OPENAI_API_KEY")
+
+ # ChatGPT cannot control temperature, do_sample, etc., so unset the local generation kwargs
+ DEFAULT_RAG_CFG["mem_llm_kwargs"] = None
+ DEFAULT_RAG_CFG["disambig_llm_kwargs"] = None
+ DEFAULT_RAG_CFG["gen_llm_kwargs"] = None
+
+ rag = RAG_ChatBot(llm, DEFAULT_RAG_CFG)
+ rag.load_doc_from_console()
+ rag.start_test_session()
diff --git a/applications/ColossalQA/examples/webui_demo/README.md b/applications/ColossalQA/examples/webui_demo/README.md
new file mode 100644
index 000000000000..15ce6b5b71be
--- /dev/null
+++ b/applications/ColossalQA/examples/webui_demo/README.md
@@ -0,0 +1,37 @@
+# ColossalQA WebUI Demo
+
+This demo provides a simple WebUI for ColossalQA, enabling you to upload your files as a knowledge base and interact with them through a chat interface in your browser.
+
+`server.py` initializes the backend RAG chain, which can be backed by various language models (e.g., ChatGPT, Huawei Pangu, ChatGLM2), while `webui.py` launches a Gradio-based chatbot interface.
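+
+For reference, the front end and back end communicate over a simple JSON protocol. The following minimal sketch (field names follow `webui.py` and `server.py`; the file path and question are placeholders) shows both request types:
+
+```python
+import json
+
+import requests
+
+URL = "http://localhost:13666"  # default port used by server.py
+HEADERS = {"Content-type": "application/json"}
+
+# First request: upload files and initialize the backend RAG chain.
+init_request = {
+    "conversation_ready": False,
+    "docs": ["/path/to/your/file.txt"],
+    "embed_model_name": "m3e",
+    "llm_name": "chatgpt",
+}
+print(requests.post(URL, data=json.dumps(init_request), headers=HEADERS).json())
+
+# Subsequent requests: send user input once the conversation is ready.
+chat_request = {"conversation_ready": True, "user_input": "which company runs business in hotel industry?"}
+print(requests.post(URL, data=json.dumps(chat_request), headers=HEADERS).json()["response"])
+```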
+
+# Usage
+
+## Installation
+
+First, install the necessary dependencies for ColossalQA:
+
+```sh
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI/applications/ColossalQA/
+pip install -e .
+```
+
+## Configure the RAG Chain
+
+Customize the RAG Chain settings, such as the embedding model (default: moka-ai/m3e-base) and the language model, in the `start_colossal_qa.sh` script.
+
+For API-based language models (like ChatGPT or Huawei Pangu), provide your API key for authentication. For locally-run models, indicate the path to the model's checkpoint file.
+
+To customize the prompts used in the RAG Chain, take a look at `RAG_ChatBot.py` and modify them there.
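+
+For example, a configuration for a locally hosted ChatGLM2 checkpoint might look like the following sketch (the checkpoint path is a placeholder):
+
+```sh
+# in start_colossal_qa.sh
+export EMB_MODEL="m3e"
+export CHAT_LLM="chatglm2"
+export CHAT_LLM_PATH="PATH_TO_LOCAL_CHECKPOINT/chatglm2-6b"
+```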
+
+## Run WebUI Demo
+
+Execute the following command to start the demo:
+
+```sh
+bash start_colossal_qa.sh
+```
+
+After launching the script, you can upload files and engage with the chatbot through your web browser.
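+
+To verify that the backend is responding, you can also query it directly from the command line (this assumes the RAG chain has already been initialized through the WebUI and the server is on its default port):
+
+```sh
+curl -X POST http://localhost:13666 \
+    -H "Content-type: application/json" \
+    -d '{"conversation_ready": true, "user_input": "which company runs business in hotel industry?"}'
+```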
\ No newline at end of file
diff --git a/applications/ColossalQA/examples/webui_demo/img/avatar_ai.png b/applications/ColossalQA/examples/webui_demo/img/avatar_ai.png
new file mode 100644
index 000000000000..031daa405b74
Binary files /dev/null and b/applications/ColossalQA/examples/webui_demo/img/avatar_ai.png differ
diff --git a/applications/ColossalQA/examples/webui_demo/img/avatar_user.png b/applications/ColossalQA/examples/webui_demo/img/avatar_user.png
new file mode 100644
index 000000000000..b5a8014b87c3
Binary files /dev/null and b/applications/ColossalQA/examples/webui_demo/img/avatar_user.png differ
diff --git a/applications/ColossalQA/examples/webui_demo/server.py b/applications/ColossalQA/examples/webui_demo/server.py
new file mode 100644
index 000000000000..c3147594fc89
--- /dev/null
+++ b/applications/ColossalQA/examples/webui_demo/server.py
@@ -0,0 +1,117 @@
+import argparse
+import json
+import os
+import random
+import string
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from RAG_ChatBot import RAG_ChatBot, DEFAULT_RAG_CFG
+
+# Map embed_model_name (passed from the front end) to the actual path on the backend server
+EMBED_MODEL_DICT = {
+ "m3e": os.environ.get("EMB_MODEL_PATH", DEFAULT_RAG_CFG["embed_model_name_or_path"])
+}
+# Map LLM_name (passed from the front end) to the actual path on the backend server
+LLM_DICT = {
+ "chatglm2": os.environ.get("CHAT_LLM_PATH", "THUDM/chatglm-6b"),
+ "pangu": "Pangu_API",
+ "chatgpt": "OpenAI_API"
+}
+
+def randomword(length):
+ letters = string.ascii_lowercase
+ return "".join(random.choice(letters) for _ in range(length))
+
+class ColossalQAServerRequestHandler(BaseHTTPRequestHandler):
+ chatbot = None
+ def _set_response(self):
+ """
+ set http header for response
+ """
+ self.send_response(200)
+ self.send_header("Content-type", "application/json")
+ self.end_headers()
+
+ def do_POST(self):
+ content_length = int(self.headers["Content-Length"])
+ post_data = self.rfile.read(content_length)
+ received_json = json.loads(post_data.decode("utf-8"))
+ print(received_json)
+ # conversation_ready is False (user's first request): upload files and initialize the RAG chain
+ if received_json["conversation_ready"] is False:
+ self.rag_config = DEFAULT_RAG_CFG.copy()
+ try:
+ assert received_json["embed_model_name"] in EMBED_MODEL_DICT
+ assert received_json["llm_name"] in LLM_DICT
+ self.docs_files = received_json["docs"]
+ embed_model_name, llm_name = received_json["embed_model_name"], received_json["llm_name"]
+
+ # Find the embed_model/llm ckpt path on the back end server.
+ embed_model_path, llm_path = EMBED_MODEL_DICT[embed_model_name], LLM_DICT[llm_name]
+ self.rag_config["embed_model_name_or_path"] = embed_model_path
+
+ # Create the storage path for knowledge base files
+ self.rag_config["retri_kb_file_path"] = os.path.join(os.environ["TMP"], "colossalqa_kb/"+randomword(20))
+ if not os.path.exists(self.rag_config["retri_kb_file_path"]):
+ os.makedirs(self.rag_config["retri_kb_file_path"])
+
+ if (embed_model_path is not None) and (llm_path is not None):
+ # ---- Initialize the LLM and QA chatbot here ----
+ print("Initializing LLM...")
+ if llm_path == "Pangu_API":
+ from colossalqa.local.pangu_llm import Pangu
+ self.llm = Pangu(id=1)
+ self.llm.set_auth_config() # verify user's auth info here
+ self.rag_config["mem_llm_kwargs"] = None
+ self.rag_config["disambig_llm_kwargs"] = None
+ self.rag_config["gen_llm_kwargs"] = None
+ elif llm_path == "OpenAI_API":
+ from langchain.llms import OpenAI
+ self.llm = OpenAI()
+ self.rag_config["mem_llm_kwargs"] = None
+ self.rag_config["disambig_llm_kwargs"] = None
+ self.rag_config["gen_llm_kwargs"] = None
+ else:
+ # ** (For testing only) **
+ # In practice, all LLMs run on a cloud platform and are accessed via API instead of running locally.
+ # Initialize the model from model_path using ColossalLLM.
+ self.rag_config["mem_llm_kwargs"] = {"max_new_tokens": 50, "temperature": 1, "do_sample": True}
+ self.rag_config["disambig_llm_kwargs"] = {"max_new_tokens": 30, "temperature": 1, "do_sample": True}
+ self.rag_config["gen_llm_kwargs"] = {"max_new_tokens": 100, "temperature": 1, "do_sample": True}
+ self.colossal_api = ColossalAPI(llm_name, llm_path)
+ self.llm = ColossalLLM(n=1, api=self.colossal_api)
+
+ print(f"Initializing RAG Chain...")
+ print("RAG_CONFIG: ", self.rag_config)
+ self.__class__.chatbot = RAG_ChatBot(self.llm, self.rag_config)
+ print("Loading Files....\n", self.docs_files)
+ self.__class__.chatbot.load_doc_from_files(self.docs_files)
+ # -----------------------------------------------------------------------------------
+ res = {"response": f"文件上传完成,模型初始化完成,让我们开始对话吧!(后端模型:{llm_name})", "error": "", "conversation_ready": True}
+ except Exception as e:
+ res = {"response": "文件上传或模型初始化有误,无法开始对话。",
+ "error": f"Error in File Uploading and/or RAG initialization. Error details: {e}",
+ "conversation_ready": False}
+ # conversation_ready is True: Chatbot and docs are all set. Ready to chat.
+ else:
+ user_input = received_json["user_input"]
+ chatbot_response, self.__class__.chatbot.memory = self.__class__.chatbot.run(user_input, self.__class__.chatbot.memory)
+ res = {"response": chatbot_response, "error": "", "conversation_ready": True}
+ self._set_response()
+ self.wfile.write(json.dumps(res).encode("utf-8"))
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Chinese retrieval based conversation system")
+ parser.add_argument("--port", type=int, default=13666, help="port on localhost to start the server")
+ args = parser.parse_args()
+ server_address = ("localhost", args.port)
+ httpd = HTTPServer(server_address, ColossalQAServerRequestHandler)
+ print(f"Starting server on port {args.port}...")
+ httpd.serve_forever()
+
diff --git a/applications/ColossalQA/examples/webui_demo/start_colossal_qa.sh b/applications/ColossalQA/examples/webui_demo/start_colossal_qa.sh
new file mode 100755
index 000000000000..c9c7b71c3e90
--- /dev/null
+++ b/applications/ColossalQA/examples/webui_demo/start_colossal_qa.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+cleanup() {
+ echo "Caught Signal ... cleaning up."
+ pkill -P $$ # kill all subprocess of this script
+ exit 1 # exit script
+}
+# 'cleanup' is triggered on receiving a SIGINT (Ctrl+C) or SIGTERM (kill) signal
+trap cleanup INT TERM
+
+# Disable your proxy
+# unset HTTP_PROXY HTTPS_PROXY http_proxy https_proxy
+
+# Path to store the knowledge base (home directory by default)
+export TMP=$HOME
+
+# Use m3e as embedding model
+export EMB_MODEL="m3e" # moka-ai/m3e-base model will be download automatically
+# export EMB_MODEL_PATH="PATH_TO_LOCAL_CHECKPOINT/m3e-base" # you can also specify the local path to embedding model
+
+# Choose a backend LLM
+# - ChatGLM2
+# export CHAT_LLM="chatglm2"
+# export CHAT_LLM_PATH="PATH_TO_LOCAL_CHECKPOINT/chatglm2-6b"
+
+# - ChatGPT
+export CHAT_LLM="chatgpt"
+# Auth info for OpenAI API
+export OPENAI_API_KEY="YOUR_OPENAI_API_KEY"
+
+# - Pangu
+# export CHAT_LLM="pangu"
+# # Auth info for Pangu API
+# export URL=""
+# export USERNAME=""
+# export PASSWORD=""
+# export DOMAIN_NAME=""
+
+# Run server.py and webui.py in the background
+python server.py &
+python webui.py &
+
+# Wait for all processes to finish
+wait
diff --git a/applications/ColossalQA/examples/webui_demo/webui.py b/applications/ColossalQA/examples/webui_demo/webui.py
new file mode 100644
index 000000000000..2d2910b5adce
--- /dev/null
+++ b/applications/ColossalQA/examples/webui_demo/webui.py
@@ -0,0 +1,102 @@
+import json
+import os
+import gradio as gr
+import requests
+
+RAG_STATE = {"conversation_ready": False, # Conversation is not ready until files are uploaded and RAG chain is initialized
+ "embed_model_name": os.environ.get("EMB_MODEL", "m3e"),
+ "llm_name": os.environ.get("CHAT_LLM", "chatgpt")}
+URL = "http://localhost:13666"
+
+def get_response(client_data, URL):
+ headers = {"Content-type": "application/json"}
+ print(f"Sending request to server url: {URL}")
+ response = requests.post(URL, data=json.dumps(client_data), headers=headers)
+ response = json.loads(response.content)
+ return response
+
+def add_text(history, text):
+ history = history + [(text, None)]
+ return history, gr.update(value=None, interactive=True)
+
+def add_file(history, files):
+ global RAG_STATE
+ RAG_STATE["conversation_ready"] = False # after adding new files, reset the ChatBot
+ RAG_STATE["upload_files"]=[file.name for file in files]
+ files_string = "\n".join([os.path.basename(path) for path in RAG_STATE["upload_files"]])
+ print(files_string)
+ history = history + [(files_string, None)]
+ return history
+
+def bot(history):
+ print(history)
+ global RAG_STATE
+ if not RAG_STATE["conversation_ready"]:
+ # Upload files and initialize models
+ client_data = {
+ "docs": RAG_STATE["upload_files"],
+ "embed_model_name": RAG_STATE["embed_model_name"], # Select embedding model name here
+ "llm_name": RAG_STATE["llm_name"], # Select LLM model name here. ["pangu", "chatglm2"]
+ "conversation_ready": RAG_STATE["conversation_ready"]
+ }
+ else:
+ client_data = {}
+ client_data["conversation_ready"] = RAG_STATE["conversation_ready"]
+ client_data["user_input"] = history[-1][0].strip()
+
+ response = get_response(client_data, URL) # TODO: make the request async so users do not wait too long during model initialization
+ print(response)
+ if response["error"] != "":
+ raise gr.Error(response["error"])
+
+ RAG_STATE["conversation_ready"] = response["conversation_ready"]
+ history[-1][1] = response["response"]
+ yield history
+
+
+CSS = """
+.contain { display: flex; flex-direction: column; height: 100vh }
+#component-0 { height: 100%; }
+#chatbot { flex-grow: 1; }
+"""
+
+header_html = """
+
+
ColossalQA
+
ColossalQA
+
+"""
+
+with gr.Blocks(css=CSS) as demo:
+ html = gr.HTML(header_html)
+ chatbot = gr.Chatbot(
+ [],
+ elem_id="chatbot",
+ bubble_full_width=False,
+ avatar_images=(
+ (os.path.join(os.path.dirname(__file__), "img/avatar_user.png")),
+ (os.path.join(os.path.dirname(__file__), "img/avatar_ai.png")),
+ ),
+ )
+
+ with gr.Row():
+ txt = gr.Textbox(
+ scale=4,
+ show_label=False,
+ placeholder="Enter text and press enter, or upload an image",
+ container=True,
+ autofocus=True,
+ )
+ btn = gr.UploadButton("📁", file_types=["file"], file_count="multiple")
+
+ txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(bot, chatbot, chatbot)
+ # Clear the original textbox
+ txt_msg.then(lambda: gr.update(value=None, interactive=True), None, [txt], queue=False)
+ # Click the upload button to: 1. upload files 2. send the config to the backend and initialize the model 3. get "conversation_ready" = True/False in the response
+ file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(bot, chatbot, chatbot)
+
+
+
+if __name__ == "__main__":
+ demo.queue()
+ demo.launch(share=True) # share=True generates a public link for the demo
diff --git a/applications/ColossalQA/pytest.ini b/applications/ColossalQA/pytest.ini
new file mode 100644
index 000000000000..9e84349f2285
--- /dev/null
+++ b/applications/ColossalQA/pytest.ini
@@ -0,0 +1,4 @@
+[pytest]
+markers =
+ dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
+ largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
\ No newline at end of file
diff --git a/applications/ColossalQA/requirements.txt b/applications/ColossalQA/requirements.txt
new file mode 100644
index 000000000000..2b86cfd0ed57
--- /dev/null
+++ b/applications/ColossalQA/requirements.txt
@@ -0,0 +1,22 @@
+transformers>=4.20.1
+tqdm==4.66.1
+datasets==2.13.0
+torch<2.0.0, >=1.12.1
+langchain==0.0.330
+langchain-experimental==0.0.37
+tokenizers==0.13.3
+modelscope==1.9.0
+sentencepiece==0.1.99
+gpustat==1.1.1
+sqlalchemy==2.0.20
+pytest==7.4.2
+# coati: install from ../Chat
+sentence-transformers==2.2.2
+chromadb==0.4.9
+openai==0.28.0 # used for ChatGPT; install directly from the openai repo if needed
+tiktoken==0.5.1
+unstructured==0.10.14
+pypdf==3.16.0
+jq==1.6.0
+gradio==3.44.4
+Requests==2.31.0
diff --git a/applications/ColossalQA/setup.py b/applications/ColossalQA/setup.py
new file mode 100644
index 000000000000..b1085942d6d6
--- /dev/null
+++ b/applications/ColossalQA/setup.py
@@ -0,0 +1,38 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, "r") as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open("README.md", encoding="utf-8") as f:
+ return f.read()
+
+
+def fetch_version():
+ with open("version.txt", "r") as f:
+ return f.read().strip()
+
+
+setup(
+ name="colossalqa",
+ version=fetch_version(),
+ packages=find_packages(exclude=("tests", "*.egg-info", "data", "examples")),
+ description="Colossal-AI powered retrieval QA",
+ long_description=fetch_readme(),
+ long_description_content_type="text/markdown",
+ license="Apache Software License 2.0",
+ url="https://github.com/hpcaitech/Coati",
+ install_requires=fetch_requirements("requirements.txt"),
+ python_requires=">=3.6",
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Environment :: GPU :: NVIDIA CUDA",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: System :: Distributed Computing",
+ ],
+)
diff --git a/applications/ColossalQA/tests/__init__.py b/applications/ColossalQA/tests/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/applications/ColossalQA/tests/test_document_loader.py b/applications/ColossalQA/tests/test_document_loader.py
new file mode 100644
index 000000000000..163b0d2cca22
--- /dev/null
+++ b/applications/ColossalQA/tests/test_document_loader.py
@@ -0,0 +1,21 @@
+import os
+from colossalqa.data_loader.document_loader import DocumentLoader
+
+
+def test_add_document():
+ PATH = os.environ.get('TEST_DOCUMENT_LOADER_DATA_PATH')
+ files = [[PATH, 'all data']]
+ document_loader = DocumentLoader(files)
+ documents = document_loader.all_data
+ all_files = []
+ for doc in documents:
+ assert isinstance(doc.page_content, str)
+ if doc.metadata['source'] not in all_files:
+ all_files.append(doc.metadata['source'])
+ print(all_files)
+ assert len(all_files) == 6
+
+
+if __name__=='__main__':
+ test_add_document()
+
diff --git a/applications/ColossalQA/tests/test_memory.py b/applications/ColossalQA/tests/test_memory.py
new file mode 100644
index 000000000000..a364eda5e582
--- /dev/null
+++ b/applications/ColossalQA/tests/test_memory.py
@@ -0,0 +1,117 @@
+import os
+
+from colossalqa.data_loader.document_loader import DocumentLoader
+from colossalqa.local.llm import ColossalAPI, ColossalLLM
+from colossalqa.memory import ConversationBufferWithSummary
+from colossalqa.prompt.prompt import PROMPT_RETRIEVAL_QA_ZH
+from colossalqa.retriever import CustomRetriever
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+
+def test_memory_long():
+ model_path = os.environ.get("EN_MODEL_PATH")
+ data_path = os.environ.get("TEST_DATA_PATH_EN")
+ model_name = os.environ.get("EN_MODEL_NAME")
+ sql_file_path = os.environ.get("SQL_FILE_PATH")
+
+ if not os.path.exists(sql_file_path):
+ os.makedirs(sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(model_name, model_path)
+ llm = ColossalLLM(n=4, api=colossal_api)
+ memory = ConversationBufferWithSummary(
+ llm=llm, max_tokens=600, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
+ )
+ retriever_data = DocumentLoader([[data_path, "company information"]]).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+ splits = text_splitter.split_documents(retriever_data)
+
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # Create retriever
+ information_retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
+ information_retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ memory.initiate_document_retrieval_chain(
+ llm,
+ PROMPT_RETRIEVAL_QA_ZH,
+ information_retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+
+ # This keeps the prompt length (excluding dialogues) constant
+ docs = information_retriever.get_relevant_documents("this is a test input.")
+ prompt_length = memory.chain.prompt_length(docs, **{"question": "this is a test input.", "chat_history": ""})
+ remain = 600 - prompt_length
+ have_summarization_flag = False
+ for _ in range(40):
+ chat_history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
+ "chat_history"
+ ]
+
+ assert memory.get_conversation_length() <= remain
+ memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})
+ if "A summarization of historical conversation:" in chat_history:
+ have_summarization_flag = True
+ assert have_summarization_flag
+
+
+def test_memory_short():
+ model_path = os.environ.get("EN_MODEL_PATH")
+ data_path = os.environ.get("TEST_DATA_PATH_EN")
+ model_name = os.environ.get("EN_MODEL_NAME")
+ sql_file_path = os.environ.get("SQL_FILE_PATH")
+
+ if not os.path.exists(sql_file_path):
+ os.makedirs(sql_file_path)
+
+ colossal_api = ColossalAPI.get_api(model_name, model_path)
+ llm = ColossalLLM(n=4, api=colossal_api)
+ memory = ConversationBufferWithSummary(
+ llm=llm, llm_kwargs={"max_new_tokens": 50, "temperature": 0.6, "do_sample": True}
+ )
+ retriever_data = DocumentLoader([[data_path, "company information"]]).all_data
+
+ # Split
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
+ splits = text_splitter.split_documents(retriever_data)
+
+ embedding = HuggingFaceEmbeddings(
+ model_name="moka-ai/m3e-base", model_kwargs={"device": "cpu"}, encode_kwargs={"normalize_embeddings": False}
+ )
+
+ # create retriever
+ information_retriever = CustomRetriever(k=3, sql_file_path=sql_file_path)
+ information_retriever.add_documents(docs=splits, cleanup="incremental", mode="by_source", embedding=embedding)
+
+ memory.initiate_document_retrieval_chain(
+ llm,
+ PROMPT_RETRIEVAL_QA_ZH,
+ information_retriever,
+ chain_type_kwargs={
+ "chat_history": "",
+ },
+ )
+
+ # This keeps the prompt length (excluding dialogues) constant
+ docs = information_retriever.get_relevant_documents("this is a test input.", return_scores=True)
+
+ for i in range(4):
+ chat_history = memory.load_memory_variables({"question": "this is a test input.", "input_documents": docs})[
+ "chat_history"
+ ]
+ assert chat_history.count("Assistant: this is a test output.") == i
+ assert chat_history.count("Human: this is a test input.") == i
+ memory.save_context({"question": "this is a test input."}, {"output": "this is a test output."})
+
+
+if __name__ == "__main__":
+ test_memory_short()
+ test_memory_long()
diff --git a/applications/ColossalQA/tests/test_retrieval_qa.py b/applications/ColossalQA/tests/test_retrieval_qa.py
new file mode 100644
index 000000000000..76867b562e8f
--- /dev/null
+++ b/applications/ColossalQA/tests/test_retrieval_qa.py
@@ -0,0 +1,62 @@
+import os
+
+from colossalqa.retrieval_conversation_universal import UniversalRetrievalConversation
+
+
+def test_en_retrievalQA():
+ data_path_en = os.environ.get('TEST_DATA_PATH_EN')
+ data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
+ en_model_path = os.environ.get('EN_MODEL_PATH')
+ zh_model_path = os.environ.get('ZH_MODEL_PATH')
+ zh_model_name = os.environ.get('ZH_MODEL_NAME')
+ en_model_name = os.environ.get('EN_MODEL_NAME')
+ sql_file_path = os.environ.get('SQL_FILE_PATH')
+ qa_session = UniversalRetrievalConversation(files_en=[{
+ 'data_path': data_path_en,
+ 'name': 'company information',
+ 'separator': '\n'
+ }],
+ files_zh=[{
+ 'data_path': data_path_zh,
+ 'name': 'company information',
+ 'separator': '\n'
+ }],
+ zh_model_path=zh_model_path,
+ en_model_path=en_model_path,
+ zh_model_name=zh_model_name,
+ en_model_name=en_model_name,
+ sql_file_path=sql_file_path)
+ ans = qa_session.run("which company runs business in hotel industry?", which_language='en')
+ print(ans)
+
+
+def test_zh_retrievalQA():
+ data_path_en = os.environ.get('TEST_DATA_PATH_EN')
+ data_path_zh = os.environ.get('TEST_DATA_PATH_ZH')
+ en_model_path = os.environ.get('EN_MODEL_PATH')
+ zh_model_path = os.environ.get('ZH_MODEL_PATH')
+ zh_model_name = os.environ.get('ZH_MODEL_NAME')
+ en_model_name = os.environ.get('EN_MODEL_NAME')
+ sql_file_path = os.environ.get('SQL_FILE_PATH')
+ qa_session = UniversalRetrievalConversation(files_en=[{
+ 'data_path': data_path_en,
+ 'name': 'company information',
+ 'separator': '\n'
+ }],
+ files_zh=[{
+ 'data_path': data_path_zh,
+ 'name': 'company information',
+ 'separator': '\n'
+ }],
+ zh_model_path=zh_model_path,
+ en_model_path=en_model_path,
+ zh_model_name=zh_model_name,
+ en_model_name=en_model_name,
+ sql_file_path=sql_file_path)
+ ans = qa_session.run("哪家公司在经营酒店业务?", which_language='zh')
+ print(ans)
+
+
+if __name__ == "__main__":
+ test_en_retrievalQA()
+ test_zh_retrievalQA()
diff --git a/applications/ColossalQA/tests/test_text_splitter.py b/applications/ColossalQA/tests/test_text_splitter.py
new file mode 100644
index 000000000000..90755a9bcb1c
--- /dev/null
+++ b/applications/ColossalQA/tests/test_text_splitter.py
@@ -0,0 +1,11 @@
+from colossalqa.text_splitter.chinese_text_splitter import ChineseTextSplitter
+
+
+def test_text_splitter():
+ # unit test
+ splitter = ChineseTextSplitter(chunk_size=30, chunk_overlap=0)
+ out = splitter.split_text(
+ "移动端语音唤醒模型,检测关键词为“小云小云”。模型主体为4层FSMN结构,使用CTC训练准则,参数量750K,适用于移动端设备运行。模型输入为Fbank特征,输出为基于char建模的中文全集token预测,测试工具根据每一帧的预测数据进行后处理得到输入音频的实时检测结果。模型训练采用“basetrain + finetune”的模式,basetrain过程使用大量内部移动端数据,在此基础上,使用1万条设备端录制安静场景“小云小云”数据进行微调,得到最终面向业务的模型。后续用户可在basetrain模型基础上,使用其他关键词数据进行微调,得到新的语音唤醒模型,但暂时未开放模型finetune功能。"
+ )
+ print(len(out))
+ assert len(out) == 4 # ChineseTextSplitter does not break sentences, so the actual chunk size may exceed 30
diff --git a/applications/ColossalQA/version.txt b/applications/ColossalQA/version.txt
new file mode 100644
index 000000000000..8a9ecc2ea99d
--- /dev/null
+++ b/applications/ColossalQA/version.txt
@@ -0,0 +1 @@
+0.0.1
\ No newline at end of file
diff --git a/applications/README.md b/applications/README.md
index f5078e06a73b..92096e5593d0 100644
--- a/applications/README.md
+++ b/applications/README.md
@@ -8,6 +8,7 @@ The list of applications include:
- [X] [ColossalEval](./ColossalEval): Evaluation Pipeline for LLMs.
- [X] [ColossalChat](./Chat/README.md): Replication of ChatGPT with RLHF.
- [X] [FastFold](https://github.com/hpcaitech/FastFold): Optimizing AlphaFold (Biomedicine) Training and Inference on GPU Clusters.
+- [X] [ColossalQA](./ColossalQA/README.md): Document Retrieval Conversation System.
> Please note that the `Chatbot` application is migrated from the original `ChatGPT` folder.
diff --git a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
index 79661a44424f..439d13dcfc11 100644
--- a/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py
@@ -8,6 +8,7 @@
from torch import Tensor
from colossalai.logging import get_dist_logger
+from colossalai.utils.device import get_current_device
__all__ = ["BaseGradScaler"]
@@ -22,7 +23,7 @@ class BaseGradScaler(ABC):
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
- self._scale = torch.cuda.FloatTensor([initial_scale])
+ self._scale = torch.tensor([initial_scale], device=get_current_device(), dtype=torch.float)
self._verbose = verbose
if self._verbose:
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 65133a4b3712..86ba919ee696 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -5,6 +5,8 @@
import torch
+from colossalai.utils.device import get_current_device
+
from .base_grad_scaler import BaseGradScaler
__all__ = ["DynamicGradScaler"]
@@ -37,12 +39,12 @@ def __init__(
):
super().__init__(initial_scale, verbose)
if min_scale:
- self._min_scale = torch.cuda.FloatTensor([min_scale])
+ self._min_scale = torch.tensor([min_scale], device=get_current_device(), dtype=torch.float)
else:
self._min_scale = None
if max_scale:
- self._max_scale = torch.cuda.FloatTensor([max_scale])
+ self._max_scale = torch.tensor([max_scale], device=get_current_device(), dtype=torch.float)
else:
self._max_scale = None
@@ -115,7 +117,7 @@ def state_dict(self):
return state_dict
def load_state_dict(self, state_dict):
- self._scale = state_dict["scale"].cuda(torch.cuda.current_device())
+ self._scale = state_dict["scale"].to(get_current_device())
self._growth_factor = state_dict["growth_factor"]
self._backoff_factor = state_dict["backoff_factor"]
self._hysteresis = state_dict["hysteresis"]
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
index a6b4904f2617..a6628e29c2bc 100644
--- a/colossalai/auto_parallel/offload/solver.py
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -11,7 +11,7 @@
import torch
from torch.fx.node import Node
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from .region import Region
from .training_simulator import AsynTrainingSimulator, SynTrainingSimulator, TrainingSimulator
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
index 7dce6e6da33e..443c4094c0e1 100644
--- a/colossalai/booster/mixed_precision/fp16_torch.py
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -6,6 +6,7 @@
from torch.optim import Optimizer
from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils.device import autocast
from .mixed_precision_base import MixedPrecision
@@ -88,7 +89,7 @@ def __init__(self, module: nn.Module):
super().__init__(module)
def forward(self, *args, **kwargs):
- with torch.cuda.amp.autocast():
+ with autocast():
return self.module(*args, **kwargs)
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 20a931b816ea..261080dc9d20 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -5,7 +5,9 @@
from typing import Callable, Iterator, List, Optional, Tuple
import torch
+import torch.distributed as dist
import torch.nn as nn
+from torch.distributed.distributed_c10d import _get_default_group
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -19,9 +21,11 @@
save_state_dict,
save_state_dict_shards,
)
-from colossalai.cluster import DistCoordinator
+from colossalai.cluster import DistCoordinator, ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats
@@ -32,6 +36,26 @@
SUPPORTED_PRECISION = ["fp16", "bf16"]
PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
+ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
+
+
+def get_param_info(optim: Optimizer):
+ # Back up necessary parameter information for future use, namely:
+ # 1. A mapping from integer param_id to param32 shape.
+
+ if optim is None:
+ return {}
+ param_info = {"id2shape": {}}
+ start_index = 0
+ for group in optim.param_groups:
+ for param_id, param in enumerate(group["params"], start_index):
+ original_shape = param.shape if isinstance(param, torch.Tensor) else None
+ param_info["id2shape"][param_id] = original_shape
+
+ start_index += len(group["params"])
+
+ return param_info
+
class GeminiCheckpointIO(GeneralCheckpointIO):
def __init__(self) -> None:
@@ -150,24 +174,24 @@ def save_sharded_optimizer(
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
+ index_file.append_meta_data("param_groups", param_group_file)
# Store the information of param groups to param_group_file.
- index_file.append_meta_data("param_groups", param_group_file)
- group_file_path = os.path.join(checkpoint, param_group_file)
- param_groups = optimizer.get_param_groups_for_saving()
- torch.save(param_groups, group_file_path)
+ if self.coordinator.is_master():
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ param_groups = optimizer.get_param_groups_for_saving()
+ torch.save(param_groups, group_file_path)
# States are broken into shards within max_shard_size.
state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)
# Save shards of optimizer states.
- is_master = self.coordinator.is_master()
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
- is_master=is_master,
+ is_master=self.coordinator.is_master(),
use_safetensors=False,
)
@@ -284,6 +308,16 @@ class GeminiPlugin(DPPluginBase):
max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
norm_type (float, optional): norm_type used for `clip_grad_norm`.
+ tp_size (int, optional): If 'tp_size' is greater than 1, the tensor parallelism strategy (implemented in Shardformer) is used, and 'tp_size' determines the size of the tensor parallel process group. Defaults to 1.
+ extra_dp_size (int, optional): If 'extra_dp_size' is greater than 1, another group is created to run a DDP-like strategy. Defaults to 1.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ Currently the optimizations include fused normalization, flash attention and JIT.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+ enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
"""
@@ -317,10 +351,20 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
+ tp_size: int = 1,
+ extra_dp_size: int = 1,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_overlap: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
+ if IS_NPU_AVAILABLE:
+ assert placement_policy == "static", "NPU only supports static placement policy"
self.gemini_config = dict(
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
@@ -355,8 +399,41 @@ def __init__(
max_norm=max_norm,
norm_type=norm_type,
)
+ self.enable_tensor_parallelism = tp_size > 1
+ self.enable_all_optimization = enable_all_optimization
+ self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
+ self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_overlap = enable_sequence_overlap
self.verbose = verbose
+ self.tp_size = tp_size
+ self.extra_dp_size = extra_dp_size
+ world_size = dist.get_world_size()
+ self.zero_size = world_size // (self.tp_size * self.extra_dp_size)
+ assert (
+ world_size == (self.tp_size * self.extra_dp_size) * self.zero_size
+ ), f"The global group size can't be evenly divided by the subgroup size."
+
+ self.pg_mesh = ProcessGroupMesh(self.zero_size, self.extra_dp_size, self.tp_size)
+ self.zero_group = (
+ self.pg_mesh.get_group_along_axis(ZERO_AXIS) if self.zero_size < world_size else _get_default_group()
+ )
+ self.extra_dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS) if self.extra_dp_size > 1 else None
+ self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS) if self.tp_size > 1 else None
+
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ enable_tensor_parallelism=self.enable_tensor_parallelism,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=self.enable_sequence_parallelism,
+ enable_sequence_overlap=self.enable_sequence_overlap,
+ )
+
def support_no_sync(self) -> bool:
return False
@@ -370,7 +447,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ["cuda"]
+ return ["cuda", "npu"]
def configure(
self,
@@ -380,6 +457,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ optimizer_params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -391,11 +469,27 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
# wrap the model with Gemini
- model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
+ if self.enable_tensor_parallelism:
+ shardformer = ShardFormer(self.shard_config)
+ model, _ = shardformer.optimize(model)
+
+ model = GeminiDDP(
+ model,
+ **self.gemini_config,
+ zero_group=self.zero_group,
+ extra_dp_group=self.extra_dp_group,
+ verbose=self.verbose,
+ )
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
- optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+ optimizer,
+ model,
+ **self.zero_optim_config,
+ **self.optim_kwargs,
+ tp_group=self.tp_group,
+ optimizer_params_info=optimizer_params_info,
+ verbose=self.verbose,
)
return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 72c3ec46ae75..bbc36ceab2ec 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1,6 +1,6 @@
import ctypes
import random
-from contextlib import nullcontext
+from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
@@ -25,9 +25,11 @@
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
from colossalai.tensor.d_tensor.api import is_distributed_tensor
from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.utils.device import get_current_device
from .pp_plugin_base import PipelinePluginBase
@@ -47,12 +49,17 @@ def __init__(
precision: str,
shard_config: ShardConfig,
dp_group: ProcessGroup,
+ tp_group: ProcessGroup,
use_ddp: bool,
ddp_config: dict,
custom_policy: Policy,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
+ self.shard_config = shard_config
self.dp_group = dp_group
+ self.tp_group = tp_group
+ self.use_ddp = use_ddp
+ self.require_grad_sync = True
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@@ -75,7 +82,7 @@ def __init__(
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
- module = module.cuda()
+ module = module.to(get_current_device())
# setting input type cast when using mixed precision
self.convert_fn = None
@@ -98,19 +105,75 @@ def sync_shared_params(self):
dist.all_reduce(param.grad, group=group)
dist.barrier()
- def no_sync(self) -> Iterator[None]:
- # no sync grads across data parallel
- return nullcontext()
+ @contextmanager
+ def no_sync(self):
+ r"""
+ A context manager to disable automatic gradient synchronization (all-reduce) and allow manual synchronization
+ when 'no_sync' is active. Alternatively, synchronization will occur in the first forward-backward pass
+ when exiting the context.
+ """
+
+ # Store the current value of 'require_grad_sync' to restore it later.
+ old_require_grad_sync = self.require_grad_sync
+ # Disable automatic gradient synchronization.
+ self.require_grad_sync = False
+ try:
+ if self.use_ddp:
+ # If the module is wrapped with DDP, disable its gradient synchronization as well.
+ with self.module.no_sync():
+ yield
+ else:
+ yield
+ finally:
+ # Restore the original value of 'require_grad_sync'.
+ self.require_grad_sync = old_require_grad_sync
+
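+ # A minimal gradient-accumulation sketch using `no_sync` (hypothetical usage; `model` is a
+ # HybridParallelModule and `optimizer` one of the wrapped optimizers defined below):
+ #
+ #     with model.no_sync():
+ #         for inputs, labels in micro_batches[:-1]:
+ #             optimizer.backward(criterion(model(inputs), labels))
+ #     inputs, labels = micro_batches[-1]
+ #     optimizer.backward(criterion(model(inputs), labels))  # gradients are synchronized here
+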
+ def sync_dp_grads(self):
+ r"""
+ Synchronize gradients across data parallelism (DP) if the DP group size is greater than 1.
+ This function performs an all-reduce operation to combine gradients from different devices in the DP group.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
- def sync_grads(self):
- # sync grad across data parallel
+ # Check if the DP group size is 1, meaning no synchronization is needed.
if self.dp_group.size() == 1:
return
+
+ # Iterate through the model's parameters and perform gradient synchronization.
for p in self.module.parameters():
if p.grad is not None:
+ # Perform all-reduce to combine gradients from different devices.
dist.all_reduce(p.grad, group=self.dp_group)
+ # Normalize the gradient by dividing it by the DP group size.
p.grad.div_(self.dp_group.size())
+ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
+ r"""
+ Synchronize gradients that are partially derived within sequence parallelism
+ if sequence parallelism is enabled. Gradients can be provided explicitly or extracted
+ from the module.
+
+ Args:
+ grads (Optional[List[torch.Tensor]]): A list of gradient tensors to synchronize. If not
+ provided, gradients will be extracted from the model.
+
+ Returns:
+ None
+ """
+
+ if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
+ if grads is not None:
+ # Synchronize provided gradient tensors across the tensor parallelism group.
+ SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, grads=grads)
+ else:
+ # Synchronize gradients from the model across the tensor parallelism group.
+ SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_group, model=self.module)
+
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
@@ -166,7 +229,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(
self,
optim: Optimizer,
- model: Module,
+ model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
max_norm: float = 0,
@@ -176,13 +239,69 @@ def __init__(
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
+ self.model = model
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.max_norm = max_norm
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
+ self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+ self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
super().__init__(optim)
+ def backward(self, loss: Tensor, *args, **kwargs):
+ r"""
+ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+ This method performs backward pass for gradient computation. If sequence parallelism is enabled
+ and gradient synchronization is required, it will synchronize gradients that are partially derived
+ within sequence parallelism across tp parallelism groups.
+
+ Args:
+ loss (Tensor): The loss tensor to compute gradients with respect to.
+ *args: Additional positional arguments to be passed to the superclass backward method.
+ **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+ Returns:
+ None
+ """
+
+ # Call the superclass backward method to compute gradients.
+ super().backward(loss, *args, **kwargs)
+
+ if self.model.require_grad_sync:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self.model.sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ """
+ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+ This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+ If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+ gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+ Args:
+ tensor (Tensor): The input tensor for which gradients are computed.
+ grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+ Returns:
+ None
+ """
+
+ # Call the superclass backward method to compute gradients.
+ super().backward_by_grad(tensor, grad)
+
+ if self.model.require_grad_sync:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self.model.sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
def step(self, *args, **kwargs):
r"""
Perform an optimization step.
@@ -220,8 +339,6 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
if len(param_gradient_pairs) == 0:
return 0.0
- tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
- pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
# gradients used for norm calculation.
@@ -229,10 +346,10 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
- if tp_size > 1:
+ total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
+ if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
- if pp_size > 1:
+ if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
else:
@@ -250,16 +367,16 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
- if tp_size > 1:
+ if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
- grad_norm_exponentiated /= tp_size
+ grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
- if pp_size > 1:
+ if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_shared_param = shared_param[self.stage_manager.stage]
@@ -268,11 +385,11 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
total_norm_exponentiated += grad_norm_exponentiated
- total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
- if tp_size > 1:
+ total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+ if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
- if pp_size > 1:
+ if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -314,7 +431,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
self,
optim: Optimizer,
- model: Module,
+ model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = "fp16",
@@ -329,11 +446,14 @@ def __init__(
tp_process_group: Optional[ProcessGroup] = None, # if using tp
pp_process_group: Optional[ProcessGroup] = None, # if using pp
):
+ self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
self.tp_pg = tp_process_group
self.pp_pg = pp_process_group
+ self.tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
+ self.pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(
@@ -349,6 +469,59 @@ def __init__(
max_norm=max_norm,
)
+ def backward(self, loss: Tensor, *args, **kwargs):
+ r"""
+ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+ This method performs backward pass for gradient computation. If sequence parallelism is enabled
+ and gradient synchronization is required, it will synchronize gradients that are partially derived
+ within sequence parallelism across tp parallelism groups.
+
+ Args:
+ loss (Tensor): The loss tensor to compute gradients with respect to.
+ *args: Additional positional arguments to be passed to the superclass backward method.
+ **kwargs: Additional keyword arguments to be passed to the superclass backward method.
+
+ Returns:
+ None
+ """
+
+ # Call the superclass backward method to compute gradients.
+ super().backward(loss, *args, **kwargs)
+
+ if self.model.require_grad_sync:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self.model.sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
+ def backward_by_grad(self, tensor: Tensor, grad: Tensor):
+ """
+ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+ This method performs a backward pass for gradient computation using a precomputed gradient tensor.
+ If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+ gradients that are partially derived within sequence parallelism across tp parallelism groups.
+
+ Args:
+ tensor (Tensor): The input tensor for which gradients are computed.
+ grad (Tensor): The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+ Returns:
+ None
+ """
+
+ # Call the superclass backward method to compute gradients.
+ super().backward_by_grad(tensor, grad)
+
+ if self.model.require_grad_sync:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self.model.sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_type: int = 2) -> int:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -363,8 +536,6 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
if len(param_gradient_pairs) == 0:
return 0.0
- tp_size = get_world_size(self.tp_pg) if self.tp_pg is not None else 1
- pp_size = get_world_size(self.pp_pg) if self.pp_pg is not None else 1
norm_type = float(norm_type)
if norm_type == inf:
@@ -372,11 +543,11 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
# so we need to calculate the norm of 'tp' and 'pp' gradients.
total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
- if tp_size > 1:
+ if self.tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
- if pp_size > 1:
+ if self.pp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.pp_pg)
total_norm = total_norm_cuda.item()
@@ -396,16 +567,16 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
# Consequently, there is no need to perform an 'all_reduce' operation for 'grad_norm'.
# However, we still perform the 'all_reduce' operation for the sake of good coding practices.
# To ensure mathematical equivalence, we divide the 'grad_norm' by 'tp_size.'
- if tp_size > 1:
+ if self.tp_size > 1:
param_for_grad = grad_to_param_mapping[id(grad)]
if not is_distributed_tensor(param_for_grad):
- grad_norm_exponentiated /= tp_size
+ grad_norm_exponentiated /= self.tp_size
# If 'pp_size' is greater than 1 and the gradient belongs to shared parameters,
# it means that this parameter is used in two different pipeline stages.
# To avoid redundant norm calculations, we divide the exponent of this norm by
# the number of shared stages.
- if pp_size > 1:
+ if self.pp_size > 1:
for shared_param in self.shared_params:
if self.stage_manager.stage in shared_param:
stage_working_shared_param = shared_param[self.stage_manager.stage]
@@ -415,11 +586,11 @@ def _compute_grad_norm(self, param_gradient_pairs: List[Tuple[Tensor]], norm_typ
total_norm_exponentiated += grad_norm_exponentiated
- total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
- if tp_size > 1:
+ total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
+ if self.tp_size > 1:
# compute norm in tp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)
- if pp_size > 1:
+ if self.pp_size > 1:
# compute norm in pp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.pp_pg)
@@ -433,7 +604,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
def __init__(
self,
optimizer: Optimizer,
- model: Module,
+ model: HybridParallelModule,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
@@ -455,6 +626,7 @@ def __init__(
pp_process_group: Optional[ProcessGroup] = None, # if using pp
forced_dtype: Optional[torch.dtype] = None,
):
+ self.model = model
self.param_info = param_info
self.stage_manager = model.stage_manager
self.shared_params = model.shared_params
@@ -483,6 +655,123 @@ def __init__(
forced_dtype=forced_dtype,
)
+ def sync_dp_grads(self):
+ r"""
+ Synchronize gradients in the data parallelism dimension.
+
+ This method wraps the existing `_sync_grad` method in order to explicitly synchronize gradients
+ in the data parallelism dimension. It is necessary due to the introduction of new parallel dimensions,
+ namely tp (tensor parallelism) and pp (pipeline parallelism). This ensures better code organization
+ and readability.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
+
+ # Call the superclass `_sync_grad` method to synchronize gradients.
+ super()._sync_grad()
+
+ def _sync_sp_grads(self):
+ r"""
+ Synchronize gradients that are partially derived within sequence parallelism.
+
+ This method is responsible for synchronizing gradients that are partially derived within sequence
+ parallelism across tp parallelism groups. It identifies such gradients and, if synchronization
+ is required, performs the all-reduce.
+
+ Args:
+ None
+
+ Returns:
+ None
+ """
+
+ def _get_all_working_grads() -> List[Tensor]:
+ """Retrieve all working gradients from different parameter groups."""
+ all_working_grads = []
+ for group_id in range(self.num_param_groups):
+ working_grads = self._grad_store.get_working_grads_by_group_id(group_id)
+ all_working_grads.extend(working_grads)
+ return all_working_grads
+
+ def _get_grads_to_sync(all_working_grads) -> Union[List[Tensor], None]:
+ """Identify gradients to be synchronized in the sequence parallelism."""
+ grads_to_sync = []
+ for grad in all_working_grads:
+ param_id_for_grad = self._grad_store.get_param_id_for_grad(grad)
+ param_for_grad = ctypes.cast(param_id_for_grad, ctypes.py_object).value
+ if SeqParallelUtils.is_sp_partial_derived_param(param_for_grad):
+ grads_to_sync.append(grad)
+
+ if len(grads_to_sync) > 0:
+ return grads_to_sync
+ else:
+ return None
+
+ # Get all working gradients and gradients to be synchronized.
+ all_working_grads = _get_all_working_grads()
+ grads_to_sync = _get_grads_to_sync(all_working_grads)
+
+ if self.require_grad_sync and grads_to_sync is not None:
+ # Synchronize sequence parallelism gradients if required.
+ SeqParallelUtils.allreduce_partial_data_grad(tp_group=self.tp_pg, grads=grads_to_sync)
+ else:
+ return
+
+ def backward(self, loss, retain_graph=False):
+ """
+ Backpropagate gradients through the model and optionally synchronize sequence parallelism gradients.
+
+ This method performs the backward pass for gradient computation based on a given loss tensor.
+ If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+ gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+ Args:
+ loss: The loss tensor to compute gradients with respect to.
+ retain_graph (bool): Whether to retain the computation graph.
+
+ Returns:
+ None
+ """
+ # Call the superclass backward method to compute gradients.
+ super().backward(loss, retain_graph)
+
+ if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self._sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
+ def backward_by_grad(self, tensor, grad):
+ """
+ Backpropagate gradients through the model using a precomputed gradient and optionally synchronize sequence parallelism gradients.
+
+ This method performs a backward pass for gradient computation based on a precomputed gradient tensor.
+ If sequence parallelism is enabled and gradient synchronization is required, it will synchronize
+ gradients that are partially derived within sequence parallelism across TP parallelism groups.
+
+ Args:
+ tensor: The input tensor for which gradients are computed.
+ grad: The precomputed gradient tensor to compute gradients with respect to the input tensor.
+
+ Returns:
+ None
+ """
+ # Call the superclass backward_by_grad method to compute gradients.
+ super().backward_by_grad(tensor, grad)
+
+ if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+ # If gradient synchronization is required, sync sequence parallelism gradients.
+ self._sync_sp_grads()
+ else:
+ # If gradient synchronization is not required, return.
+ return
+
def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
r"""
Compute and return the gradient norm for gradient clipping.
@@ -509,7 +798,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
# so we only need to calculate the norm 'tp' of 'pp' gradients.
total_norm = super()._compute_grad_norm(gradients, norm_type)
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
if tp_size > 1:
dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
@@ -548,7 +837,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
total_norm_exponentiated += grad_norm_exponentiated
- total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+ total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
if dp_size > 1:
# compute norm in dp process group
dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)
@@ -739,7 +1028,7 @@ def enable_pipeline_parallelism(self) -> bool:
return self.pp_size > 1
def supported_devices(self) -> List[str]:
- return ["cuda"]
+ return ["cuda", "npu"]
def supported_precisions(self) -> List[str]:
return ["fp16", "bf16", "fp32"]
@@ -768,7 +1057,14 @@ def configure(
if not isinstance(model, ModelWrapper):
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(
- model, self.precision, self.shard_config, self.dp_group, use_ddp, self.ddp_config, self.custom_policy
+ model,
+ precision=self.precision,
+ shard_config=self.shard_config,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ use_ddp=use_ddp,
+ ddp_config=self.ddp_config,
+ custom_policy=self.custom_policy,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
@@ -826,17 +1122,32 @@ def execute_pipeline(
return_outputs: bool = False,
) -> dict:
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
- # return loss or outputs if needed
+
+ # Create a context for gradient synchronization based on the optimizer type.
+ # If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
+ # This avoids redundant gradient reduction in pipeline parallelism (gradients from multiple microbatches
+ # should be reduced only once), so automatic reduction is disabled and manual reduction is performed instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
with ctx:
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
+
+ # Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
+
+ # Synchronize sequence parallelism gradients of the model.
+ model.sync_sp_grads()
+
+ # Check if the optimizer is a HybridParallelZeroOptimizer and synchronize data parallelism gradients if so.
+ # Otherwise, synchronize data parallelism gradients of the model.
+ # This is because these are two different forms of data parallelism.
if isinstance(optimizer, HybridParallelZeroOptimizer):
- optimizer.sync_grad()
+ optimizer.sync_dp_grads()
else:
- model.sync_grads()
+ model.sync_dp_grads()
+
return outputs
def prepare_dataloader(
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index dc78fe8c094c..89102820cd38 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -119,11 +119,12 @@ def save_sharded_optimizer(
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
+ index_file.append_meta_data("param_groups", param_group_file)
# Store the information of param groups to param_group_file.
- index_file.append_meta_data("param_groups", param_group_file)
- group_file_path = os.path.join(checkpoint, param_group_file)
- save_param_groups(state_dict, group_file_path)
+ if self.coordinator.is_master():
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
@@ -305,7 +306,7 @@ def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
- return ["cuda"]
+ return ["cuda", "npu"]
def configure(
self,
diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
new file mode 100644
index 000000000000..e976d0aaf014
--- /dev/null
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -0,0 +1,386 @@
+import random
+from types import MethodType
+from typing import Callable, Optional, OrderedDict, Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+from torch.nn import Module
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.booster.plugin.hybrid_parallel_plugin import (
+ HybridParallelAMPOptimizer,
+ HybridParallelModule,
+ HybridParallelNaiveOptimizer,
+ HybridParallelPlugin,
+ get_param_info,
+ init_pipeline_optimizer,
+)
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.moe import MoECheckpintIO
+from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.policies.base_policy import Policy
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+
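+# Axes of the 3D process group mesh below: pipeline, data and tensor parallelism, in that order.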
+PP_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
+
+
+class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ model: Module,
+ use_pipeline: bool,
+ param_info: OrderedDict,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
+ max_scale: int = 2**24,
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
+ tp_process_group: Optional[ProcessGroup] = None, # if using tp
+ pp_process_group: Optional[ProcessGroup] = None,
+ forced_dtype: Optional[torch.dtype] = None,
+ moe_extra_dp_process_group: Optional[ProcessGroup] = None,
+ ):
+ self.param_info = param_info
+ self.stage_manager = model.stage_manager
+ self.shared_params = model.shared_params
+ self.dp_pg = dp_process_group
+ self.tp_pg = tp_process_group
+ self.pp_pg = pp_process_group
+ if use_pipeline:
+ init_pipeline_optimizer(optimizer, model)
+ super().__init__(
+ optimizer=optimizer,
+ initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale,
+ clip_grad_norm=clip_grad_norm,
+ verbose=verbose,
+ reduce_bucket_size=reduce_bucket_size,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ partition_grad=partition_grad,
+ cpu_offload=cpu_offload,
+ dp_process_group=dp_process_group,
+ forced_dtype=forced_dtype,
+ moe_extra_dp_process_group=moe_extra_dp_process_group,
+ )
+
+
+class MoeHybridParallelPlugin(HybridParallelPlugin):
+ """
+ Plugin for Moe Hybrid Parallel Training.
+ Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
+ The sizes of tp and pp should be passed in by the user; the size of dp is then automatically calculated as dp_size = world_size / (tp_size * pp_size).
+
+ Example:
+ >>> from colossalai.booster import Booster
+ >>> from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+
+ >>> model, train_dataset, optimizer, criterion = ...
+ >>> plugin = MoeHybridParallelPlugin(tp_size=2, pp_size=2)
+
+ >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
+ >>> booster = Booster(plugin=plugin)
+ >>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
+
+ Args:
+ tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
+ pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
+ precision (str, optional): Specifies the precision of parameters during training.
+ Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained with 'fp32'.
+ Defaults to 'fp16'.
+ zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
+ When set to 0, ZeRO will not be used. Defaults to 0.
+ enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+ Currently all the optimization methods include fused normalization, flash attention and JIT.
+ Defaults to False.
+ enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+ enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+ enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
+ enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+ enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
+ num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
+ microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
+ Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
+ If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
+ initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
+ min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
+ growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
+ backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
+ growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
+ hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
+ max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
+ max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
+ broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
+ ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
+ zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
+ cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
+ communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
+ overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ tp_size: int,
+ pp_size: int,
+ extra_dp_size: int = 1,
+ precision: str = "fp16",
+ zero_stage: int = 0,
+ enable_all_optimization: bool = False,
+ enable_fused_normalization: bool = False,
+ enable_flash_attention: bool = False,
+ enable_jit_fused: bool = False,
+ enable_sequence_parallelism: bool = False,
+ enable_sequence_overlap: bool = False,
+ num_microbatches: Optional[int] = None,
+ microbatch_size: Optional[int] = None,
+ initial_scale: float = 2**16,
+ min_scale: float = 1,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ max_scale: float = 2**32,
+ max_norm: float = 0,
+ broadcast_buffers: bool = True,
+ ddp_bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False,
+ zero_bucket_size_in_m: int = 12,
+ cpu_offload: bool = False,
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = True,
+ use_ep_inside: bool = True,
+ custom_policy: Policy = None,
+ ) -> None:
+ assert (
+ dist.get_world_size() % (tp_size * pp_size) == 0
+ ), f"world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}"
+
+ if enable_sequence_parallelism:
+ assert tp_size > 1, "Tensor parallelism must be enabled when using sequence parallelism"
+
+ self.tp_size = tp_size
+ self.pp_size = pp_size
+ self.dp_size = dist.get_world_size() // (tp_size * pp_size)
+ self.precision = precision
+ self.zero_stage = zero_stage
+ self.cpu_offload = cpu_offload
+ self.enable_all_optimization = enable_all_optimization
+ self.enable_fused_normalization = enable_fused_normalization
+ self.enable_flash_attention = enable_flash_attention
+ self.enable_jit_fused = enable_jit_fused
+ self.enable_sequence_parallelism = enable_sequence_parallelism
+ # we change pg mesh to (pp, dp, tp) for better moe performance
+ self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
+
+ # Sync MoE params in the outer dp group, and sync other params in the global dp group.
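+ # e.g. with world_size 16, pp_size 2 and tp_size 2, dp_size is 4; extra_dp_size 2 then gives ep_size 2.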
+ if extra_dp_size > 1:
+ ep_size = self.dp_size // extra_dp_size
+ if use_ep_inside:
+ self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, extra_dp_size, ep_size)
+ self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(1)
+ if dist.get_rank() == 0:
+ print(f"Zero Parallel: pp {self.pp_size}, outer_dp {extra_dp_size}, inner_dp {ep_size}")
+ else:
+ self.pg_mesh_moe = ProcessGroupMesh(self.pp_size, ep_size, extra_dp_size)
+ self.moe_extra_dp_group = self.pg_mesh_moe.get_group_along_axis(2)
+ if dist.get_rank() == 0:
+ print(f"Zero Parallel: pp {self.pp_size}, outer_dp {ep_size}, inner_dp {extra_dp_size}")
+ else:
+ self.moe_extra_dp_group = None
+
+ self.stage_manager = None
+ self.schedule = None
+ self.custom_policy = custom_policy
+ assert zero_stage in (0, 1, 2)
+ if self.pp_size > 1:
+ assert (
+ num_microbatches is not None or microbatch_size is not None
+ ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
+ assert self.zero_stage <= 1, "zero stage must be 0 or 1 when using pipeline parallelism"
+ self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
+ self.schedule = OneForwardOneBackwardSchedule(
+ self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
+ )
+ self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+ self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+ self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
+ self.shard_config = ShardConfig(
+ tensor_parallel_process_group=self.tp_group,
+ pipeline_stage_manager=self.stage_manager,
+ enable_tensor_parallelism=self.tp_size > 1,
+ enable_all_optimization=self.enable_all_optimization,
+ enable_fused_normalization=self.enable_fused_normalization,
+ enable_flash_attention=self.enable_flash_attention,
+ enable_jit_fused=self.enable_jit_fused,
+ enable_sequence_parallelism=enable_sequence_parallelism,
+ enable_sequence_overlap=enable_sequence_overlap,
+ )
+ self.amp_config = dict(
+ initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ )
+
+ self.ddp_config = dict(
+ broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=ddp_bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph,
+ )
+
+ self.zero_config = dict(
+ reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
+ communication_dtype=communication_dtype,
+ overlap_communication=overlap_communication,
+ cpu_offload=cpu_offload,
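+ # ZeRO stage 2 additionally partitions gradients across data-parallel ranks.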
+ partition_grad=(self.zero_stage == 2),
+ )
+
+ self.max_norm = max_norm
+
+ def prepare_dataloader(
+ self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+ ):
+ r"""
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+
+ Args:
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
+ batch_size (int): The number of samples in each batch.
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
+ the batch size, then the last batch will be smaller, defaults to False.
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+ num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``; more details can be found in
+ the PyTorch ``DataLoader`` documentation.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """
+ _kwargs = kwargs.copy()
+ sampler = DistributedSampler(
+ dataset, num_replicas=self.pg_mesh.size(DP_AXIS), rank=self.pg_mesh.coordinate(DP_AXIS), shuffle=shuffle
+ )
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs,
+ )
+
+ def get_checkpoint_io(self) -> MoECheckpintIO:
+ self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+ return self.checkpoint_io
+
+ def configure(
+ self,
+ model: Module,
+ optimizer: Optional[Optimizer] = None,
+ criterion: Optional[Callable] = None,
+ dataloader: Optional[DataLoader] = None,
+ lr_scheduler: Optional[LRScheduler] = None,
+ ) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+ param_info = get_param_info(optimizer)
+ if not isinstance(model, ModelWrapper):
+ use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
+ model = HybridParallelModule(
+ module=model,
+ precision=self.precision,
+ shard_config=self.shard_config,
+ dp_group=self.dp_group,
+ tp_group=self.tp_group,
+ use_ddp=use_ddp,
+ ddp_config=self.ddp_config,
+ custom_policy=self.custom_policy,
+ )
+ if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
+ if self.zero_stage == 0:
+ if self.precision in ["fp16", "bf16"]:
+ optimizer = HybridParallelAMPOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ precision=self.precision,
+ max_norm=self.max_norm,
+ **self.amp_config,
+ )
+ else:
+ optimizer = HybridParallelNaiveOptimizer(
+ optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
+ )
+ else:
+ assert self.dp_size > 1, "ZeRO can only be used when the data parallel size is greater than 1."
+ assert self.precision != "fp32", "Please set precision to 'fp16' or 'bf16' when using ZeRO."
+ optimizer = HybridParallelZeroOptimizer(
+ optimizer,
+ model,
+ use_pipeline=self.enable_pipeline_parallelism,
+ param_info=param_info,
+ dp_process_group=self.dp_group,
+ tp_process_group=self.tp_group,
+ pp_process_group=self.pp_group,
+ moe_extra_dp_process_group=self.moe_extra_dp_group,
+ verbose=True,
+ clip_grad_norm=self.max_norm,
+ **self.zero_config,
+ **self.amp_config,
+ )
+ # inject update_master_params
+ model.update_master_params = MethodType(optimizer.update_master_params, model)
+
+ return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 779ff42d75a1..b7900bc0f217 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -1,4 +1,5 @@
import copy
+from functools import reduce
import logging
import os
from pathlib import Path
@@ -313,9 +314,13 @@ def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, s
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
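+    # Keys reported missing per loaded shard, plus parameter names absent from the weight map entirely.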
+ missing_keys = []
+ missing_file_keys = []
+
def _load(name: str):
if name not in weight_map:
- raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+ missing_file_keys.append(name)
+ return
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
@@ -324,7 +329,6 @@ def _load(name: str):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
- missing_keys = []
load_state_dict_into_model(
model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
@@ -357,6 +361,27 @@ def _load(name: str):
if self.verbose and self.coordinator.is_master():
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+ if len(missing_keys) == 0:
+ raise RuntimeError(
+ "No weigth is loaded into the model. Please check the checkpoint files and the model structure."
+ )
+
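+ # A key is truly missing only if every loaded shard reported it as missing.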
+ remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+ remain_keys = remain_keys.union(set(missing_file_keys))
+ if len(remain_keys) > 0:
+ if strict:
+ error_msgs = "Missing key(s) in state_dict: {}. ".format(
+ ", ".join('"{}"'.format(k) for k in missing_keys)
+ )
+ raise RuntimeError(
+ "Error(s) in loading state_dict for {}:\n\t{}".format(
+ self.__class__.__name__, "\n\t".join(error_msgs)
+ )
+ )
+ else:
+ if self.coordinator.is_master():
+ logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
+
def save_sharded_optimizer(
self,
optimizer: OptimizerWrapper,
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 06dab1fdb72a..e1800f29b0af 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -11,6 +11,7 @@
import torch.nn as nn
from packaging.version import Version
from torch.optim import Optimizer
+from torch.utils._pytree import tree_map
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
@@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
# Helper functions for saving state dict
# ======================================
-
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
@@ -303,6 +303,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
checkpoint_file_path (str): path to the checkpoint file.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
+ # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
+ state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith(
@@ -310,9 +313,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
), "safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
- safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+ safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
else:
- torch.save(state_dict, checkpoint_file_path)
+ torch.save(state_dict_cpu, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
diff --git a/colossalai/cluster/process_group_mesh.py b/colossalai/cluster/process_group_mesh.py
index 3885bc962561..7a3bde44869c 100644
--- a/colossalai/cluster/process_group_mesh.py
+++ b/colossalai/cluster/process_group_mesh.py
@@ -1,3 +1,4 @@
+import gc
import itertools
from functools import reduce
from operator import mul
@@ -44,6 +45,24 @@ def __init__(self, *size: int) -> None:
self._ranks_to_group: Dict[Tuple[int, ...], ProcessGroup] = {}
self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}
+ def __del__(self):
+ r"""
+ Destructor method for the ProcessGroupMesh class.
+
+ When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
+ cleaning up any process groups that were created during the lifetime of the object.
+
+ Note:
+ All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
+ when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
+ system resources.
+ """
+ for group in self._ranks_to_group.values():
+ dist.destroy_process_group(group)
+
+ # Trigger garbage collection to release the resources held by the destroyed process groups
+ gc.collect()
+
@property
def shape(self) -> Tuple[int, ...]:
return self._shape
@@ -206,3 +225,4 @@ def get_group_along_axis(
# no need to cache it explicitly, since it will be cached in `create_group_along_axis`
return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
return self._ranks_to_group[ranks_in_group]
+
\ No newline at end of file
diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py
index ab57301bb910..3e94b7cfe993 100644
--- a/colossalai/context/__init__.py
+++ b/colossalai/context/__init__.py
@@ -1,7 +1,5 @@
from .config import Config, ConfigException
-# from .moe_context import MOE_CONTEXT
-
__all__ = [
"Config",
"ConfigException",
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
deleted file mode 100644
index 066dfc7222e1..000000000000
--- a/colossalai/context/moe_context.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from typing import Tuple
-
-import torch
-import torch.distributed as dist
-
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.legacy.tensor import ProcessGroup
-
-
-def _check_sanity():
- from colossalai.legacy.core import global_context as gpc
-
- if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
- raise NotImplementedError("Moe is not compatible with tensor or " "pipeline parallel at present.")
-
-
-class MoeParallelInfo:
- """Moe parallelism information, storing parallel sizes and groups."""
-
- def __init__(self, ep_size: int, dp_size: int):
- _check_sanity()
- self.ep_size = ep_size
- self.dp_size = dp_size
- self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
- self.ep_group = self.pg.tp_process_group()
- self.dp_group = self.pg.dp_process_group()
-
-
-class MoeContext(metaclass=SingletonMeta):
- """MoE parallel context manager. This class manages different
- parallel groups in MoE context and MoE loss in training.
- """
-
- def __init__(self):
- self.world_size = 1
- # Users may want to set maximum expert parallel size smaller than the world size
- # since very low bandwidth across nodes may constrain the performance of MoE
- # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
- self.max_ep_size = 1
- self.min_dp_size = 1
- self.aux_loss = None
- self.use_kernel_optim = True
-
- self.has_setup = False
- self._parallel_info_dict = dict()
-
- @property
- def parallel_info_dict(self):
- return self._parallel_info_dict
-
- @property
- def is_initialized(self):
- return self.has_setup
-
- def setup(self, seed: int, use_kernel_optim: bool = True):
- assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
- _check_sanity()
- assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
-
- self.world_size = dist.get_world_size()
-
- from colossalai.legacy.core import global_context as gpc
-
- self.max_ep_size = gpc.config.get("max_ep_size", self.world_size)
- assert (
- self.world_size % self.max_ep_size == 0
- ), "Maximum expert parallel size must be a factor of the number of GPUs"
- self.min_dp_size = self.world_size // self.max_ep_size
-
- # Enabling kernel optimization may raise error in some cases
- # Users can close kernel optimization manually
- self.use_kernel_optim = use_kernel_optim
-
- from .random import moe_set_seed
-
- moe_set_seed(seed)
- self.has_setup = True
-
- def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
- """Calculate the Data Parallel Group and Expert Parallel Group.
-
- Parameters
- ----------
- num_experts : int
- The number experts
-
- Returns
- -------
- int, MoeParallelInfo
- number of local experts, the MoeParallelInfo of the current ep_size
- """
-
- gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
- lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
-
- assert gt_flag or lt_flag, (
- "Automatic experts placement dose not not support expert number"
- " is not a multiple of ep size or vice versa."
- )
-
- # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
- # there are multiple experts in each GPU and each GPU has different experts
- # So it's data parallel size is 1
- # Otherwise, there is only one expert in each GPU
- # The data parallel size should be calculated
- dp_size = 1 if gt_flag else self.max_ep_size // num_experts
- ep_size = self.max_ep_size // dp_size
-
- # Calculate the number of experts for each GPU
- num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
-
- # Don't forget to multiply minimum data parallel size
- dp_size *= self.min_dp_size
- if not (ep_size in self.parallel_info_dict):
- self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
-
- return num_local_experts, self.parallel_info_dict[ep_size]
-
- def set_kernel_not_use(self):
- self.use_kernel_optim = False
-
- def reset_loss(self):
- self.aux_loss = 0
-
- def add_loss(self, loss):
- self.aux_loss += loss
-
- def get_loss(self):
- return self.aux_loss
-
-
-MOE_CONTEXT = MoeContext()
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 72f199203a9d..3949590e8378 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -38,7 +38,7 @@ class DeviceMesh:
device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
"""
- _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+ _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}
def __init__(
self,
diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md
index d0c281e057b3..dfac7cfd9be9 100644
--- a/colossalai/inference/README.md
+++ b/colossalai/inference/README.md
@@ -1,6 +1,14 @@
# 🚀 Colossal-Inference
-## Table of contents
+
+## Table of Contents
+
+- [💡 Introduction](#introduction)
+- [🔗 Design](#design)
+- [🔨 Usage](#usage)
+ - [Quick start](#quick-start)
+ - [Example](#example)
+- [📊 Performance](#performance)
## Introduction
@@ -8,22 +16,23 @@
## Design
-Colossal Inference is composed of two main components:
+Colossal Inference is composed of three main components:
1. High performance kernels and ops: which are inspired from existing libraries and modified correspondingly.
2. Efficient memory management mechanism:which includes the key-value cache manager, allowing for zero memory waste during inference.
1. `cache manager`: serves as a memory manager to help manage the key-value cache, it integrates functions such as memory allocation, indexing and release.
2. `batch_infer_info`: holds all essential elements of a batch inference, which is updated every batch.
3. High-level inference engine combined with `Shardformer`: it allows our inference framework to easily invoke and utilize various parallel methods.
- 1. `engine.TPInferEngine`: it is a high level interface that integrates with shardformer, especially for multi-card (tensor parallel) inference:
+ 1. `HybridEngine`: it is a high-level interface that integrates with shardformer, especially for multi-card (tensor parallel, pipeline parallel) inference:
2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for llama inference. (in this case : llama)
3. `policies.llama.LlamaModelInferPolicy` : contains the policies for `llama` models, which is used to call `shardformer` and segmentate the model forward in tensor parallelism way.
-## Pipeline of inference:
+
+## Architecture of inference
In this section we discuss how the colossal inference works and integrates with the `Shardformer` . The details can be found in our codes.
-
+
## Roadmap of our implementation
@@ -34,11 +43,15 @@ In this section we discuss how the colossal inference works and integrates with
- [x] policy
- [x] context forward
- [x] token forward
-- [ ] Replace the kernels with `faster-transformer` in token-forward stage
-- [ ] Support all models
+ - [x] support flash-decoding
+- [x] Support all models
- [x] Llama
+ - [x] Llama-2
- [x] Bloom
- - [ ] Chatglm2
+ - [x] Chatglm2
+- [x] Quantization
+ - [x] GPTQ
+ - [x] SmoothQuant
- [ ] Benchmarking for all models
## Get started
@@ -51,23 +64,19 @@ pip install -e .
### Requirements
-dependencies
+Install dependencies.
```bash
-pytorch= 1.13.1 (gpu)
-cuda>= 11.6
-transformers= 4.30.2
-triton==2.0.0.dev20221202
-# for install vllm, please use this branch to install https://github.com/tiandiao123/vllm/tree/setup_branch
-vllm
-# for install flash-attention, please use commit hash: 67ae6fd74b4bc99c36b2ce524cf139c35663793c
-flash-attention
-
-# install lightllm since we depend on lightllm triton kernels
-git clone https://github.com/ModelTC/lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
-cd lightllm
-pip3 install -e .
+pip install -r requirements/requirements-infer.txt
+
+# if you want to use smoothquant quantization, please install torch-int
+git clone --recurse-submodules https://github.com/Guangxuan-Xiao/torch-int.git
+cd torch-int
+git checkout 65266db1eadba5ca78941b789803929e6e6c6856
+pip install -r requirements.txt
+source environment.sh
+bash build_cutlass.sh
+python setup.py install
```
### Docker
@@ -83,22 +92,60 @@ docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcai
cd /path/to/CollossalAI
pip install -e .
-# install lightllm
-git clone https://github.com/ModelTC/lightllm
-git checkout 28c1267cfca536b7b4f28e921e03de735b003039
-cd lightllm
-pip3 install -e .
-
-
```
-### Dive into fast-inference!
+## Usage
+### Quick start
example files are in
```bash
-cd colossalai.examples
-python xx
+cd ColossalAI/examples
+python hybrid_llama.py --path /path/to/model --tp_size 2 --pp_size 2 --batch_size 4 --max_input_size 32 --max_out_len 16 --micro_batch_size 2
+```
+
+
+
+### Example
+```python
+# import required modules
+from colossalai.inference import CaiInferEngine
+import colossalai
+import torch.distributed as dist
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+# launch the distributed environment
+colossalai.launch_from_torch(config={})
+
+# load original model and tokenizer
+model = LlamaForCausalLM.from_pretrained("/path/to/model")
+tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
+
+# generate token ids
+input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
+data = tokenizer(input, return_tensors='pt')
+
+# set parallel parameters
+tp_size = 2
+pp_size = 2
+max_output_len = 32
+micro_batch_size = 1
+
+# initialize the inference engine
+engine = CaiInferEngine(
+ tp_size=tp_size,
+ pp_size=pp_size,
+ model=model,
+ max_output_len=max_output_len,
+ micro_batch_size=micro_batch_size,
+)
+
+# inference
+output = engine.generate(data)
+
+# get results
+if dist.get_rank() == 0:
+ assert len(output[0]) == max_output_len, f"{len(output)}, {max_output_len}"
+
```
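+
+To turn the generated token ids back into text, a minimal sketch like the following should work (assuming `output` is a list of token-id lists, as described above):
+
+```python
+# decode each generated sequence back to text (sketch; assumes the output layout above)
+if dist.get_rank() == 0:
+    for token_ids in output:
+        print(tokenizer.decode(token_ids, skip_special_tokens=True))
+```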
## Performance
@@ -113,7 +160,9 @@ For various models, experiments were conducted using multiple batch sizes under
Currently the stats below are calculated based on A100 (single GPU), and we calculate token latency based on average values of context-forward and decoding forward process, which means we combine both of processes to calculate token generation times. We are actively developing new features and methods to further optimize the performance of LLM models. Please stay tuned.
-#### Llama
+### Tensor Parallelism Inference
+
+#### Llama
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
@@ -122,7 +171,7 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc

-### Bloom
+#### Bloom
| batch_size | 8 | 16 | 32 |
| :---------------------: | :----: | :----: | :----: |
@@ -131,4 +180,50 @@ Currently the stats below are calculated based on A100 (single GPU), and we calc

+
+### Pipeline Parallelism Inference
+We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the Hugging Face pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G. We set input length=1024, output length=128.
+
+
+#### A10 7b, fp16
+
+| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
+| :-------------------------: | :---: | :---:| :---: | :---: | :---: | :---: |
+| Pipeline Inference | 40.35 | 77.10| 139.03| 232.70| 257.81| OOM |
+| Hugging Face | 41.43 | 65.30| 91.93 | 114.62| OOM | OOM |
+
+
+
+
+#### A10 13b, fp16
+
+| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
+| :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
+| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
+
+
+
+
+#### A800 7b, fp16
+
+| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
+| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
+
+
+
+### Llama Quantization
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| auto-gptq | 199.20 | 232.56 | 253.26 |
+| smooth-quant | 142.28 | 222.96 | 300.59 |
+| colossal-gptq | 231.98 | 388.87 | 573.03 |
+
+
+
+
+
The results of more models are coming soon!
diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py
index 35891307e754..a95205efaa78 100644
--- a/colossalai/inference/__init__.py
+++ b/colossalai/inference/__init__.py
@@ -1,3 +1,4 @@
-from .pipeline import PPInferEngine
+from .engine import InferenceEngine
+from .engine.policies import BloomModelInferPolicy, ChatGLM2InferPolicy, LlamaModelInferPolicy
-__all__ = ["PPInferEngine"]
+__all__ = ["InferenceEngine", "LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy"]
diff --git a/colossalai/inference/engine/__init__.py b/colossalai/inference/engine/__init__.py
new file mode 100644
index 000000000000..6e60da695a22
--- /dev/null
+++ b/colossalai/inference/engine/__init__.py
@@ -0,0 +1,3 @@
+from .engine import InferenceEngine
+
+__all__ = ["InferenceEngine"]
diff --git a/colossalai/inference/engine/engine.py b/colossalai/inference/engine/engine.py
new file mode 100644
index 000000000000..61da5858aa86
--- /dev/null
+++ b/colossalai/inference/engine/engine.py
@@ -0,0 +1,195 @@
+from typing import Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from transformers.utils import logging
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.schedule.generate import GenerateSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from ..kv_cache import MemoryManager
+from .microbatch_manager import MicroBatchManager
+from .policies import model_policy_map
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+ "LlamaForCausalLM",
+ "BloomForCausalLM",
+ "LlamaGPTQForCausalLM",
+ "SmoothLlamaForCausalLM",
+ "ChatGLMForConditionalGeneration",
+]
+
+
+class InferenceEngine:
+ """
+ InferenceEngine is a class that handles pipeline-parallel inference.
+
+ Args:
+ tp_size (int): the size of tensor parallelism.
+ pp_size (int): the size of pipeline parallelism.
+ dtype (str): the data type of the model, should be one of 'fp16', 'fp32', 'bf16'.
+ model (`nn.Module`): the model, not yet in pipeline style; it will be modified with `ShardFormer`.
+ model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model. It will be determined by the model type if not provided.
+ micro_batch_size (int): the micro batch size. Only useful when `pp_size` > 1.
+ micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+ max_batch_size (int): the maximum batch size.
+ max_input_len (int): the maximum input length.
+ max_output_len (int): the maximum output length.
+ quant (str): the quantization method, should be one of 'smoothquant', 'gptq', None.
+ verbose (bool): whether to return the time cost of each step.
+
+ """
+
+ def __init__(
+ self,
+ tp_size: int = 1,
+ pp_size: int = 1,
+ dtype: str = "fp16",
+ model: nn.Module = None,
+ model_policy: Policy = None,
+ micro_batch_size: int = 1,
+ micro_batch_buffer_size: int = None,
+ max_batch_size: int = 4,
+ max_input_len: int = 32,
+ max_output_len: int = 32,
+ quant: str = None,
+ verbose: bool = False,
+ # TODO: implement early_stopping and various generation options
+ early_stopping: bool = False,
+ do_sample: bool = False,
+ num_beams: int = 1,
+ ) -> None:
+ if quant == "gptq":
+ from ..quant.gptq import GPTQManager
+
+ self.gptq_manager = GPTQManager(model.quantize_config, max_input_len=max_input_len)
+ model = model.model
+ elif quant == "smoothquant":
+ model = model.model
+
+ assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
+ assert (
+ tp_size * pp_size == dist.get_world_size()
+ ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+ assert model, "Model should be provided."
+ assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+
+ assert max_batch_size <= 64, "Max batch size exceeds the constraint"
+ assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
+ assert quant in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq', or None"
+ self.pp_size = pp_size
+ self.tp_size = tp_size
+ self.quant = quant
+
+ logger = logging.get_logger(__name__)
+ if quant == "smoothquant" and dtype != "fp32":
+ dtype = "fp32"
+ logger.warning_once("Warning: smoothquant only support fp32 and int8 mix precision. set dtype to fp32")
+
+ if dtype == "fp16":
+ self.dtype = torch.float16
+ model.half()
+ elif dtype == "bf16":
+ self.dtype = torch.bfloat16
+ model.to(torch.bfloat16)
+ else:
+ self.dtype = torch.float32
+
+ if model_policy is None:
+ model_policy = model_policy_map[model.config.model_type]()
+
+ # Init pg mesh
+ pg_mesh = ProcessGroupMesh(pp_size, tp_size)
+
+ stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, pp_size * tp_size > 1)
+ self.cache_manager_list = [
+ self._init_manager(model, max_batch_size, max_input_len, max_output_len)
+ for _ in range(micro_batch_buffer_size or pp_size)
+ ]
+ self.mb_manager = MicroBatchManager(
+ stage_manager.stage,
+ micro_batch_size,
+ micro_batch_buffer_size or pp_size,
+ max_input_len,
+ max_output_len,
+ self.cache_manager_list,
+ )
+ self.verbose = verbose
+ self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
+
+ self.model = self._shardformer(
+ model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS) if pp_size * tp_size > 1 else None
+ )
+ if quant == "gptq":
+ self.gptq_manager.post_init_gptq_buffer(self.model)
+
+ def generate(self, input_list: Union[list, dict]):
+ """
+ Args:
+ input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
+
+ Returns:
+ out (list): a list of output data, each element is a list of token.
+ timestamp (float): the time cost of the inference, only return when verbose is `True`.
+ """
+
+ out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
+ if self.verbose:
+ return out, timestamp
+ else:
+ return out
+
+ def _shardformer(self, model, model_policy, stage_manager, tp_group):
+ shardconfig = ShardConfig(
+ tensor_parallel_process_group=tp_group,
+ pipeline_stage_manager=stage_manager,
+ enable_tensor_parallelism=(self.tp_size > 1),
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ enable_sequence_parallelism=False,
+ extra_kwargs={"quant": self.quant},
+ )
+ shardformer = ShardFormer(shard_config=shardconfig)
+ shard_model, _ = shardformer.optimize(model, model_policy)
+ return shard_model.cuda()
+
+ def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
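+ # Reserve enough KV-cache slots for every sequence in a batch to reach its maximum total length.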
+ max_total_token_num = max_batch_size * (max_input_len + max_output_len)
+ if model.config.model_type == "llama":
+ head_dim = model.config.hidden_size // model.config.num_attention_heads
+ head_num = model.config.num_key_value_heads // self.tp_size
+ num_hidden_layers = (
+ model.config.num_hidden_layers
+ if hasattr(model.config, "num_hidden_layers")
+ else model.config.num_layers
+ )
+ layer_num = num_hidden_layers // self.pp_size
+ elif model.config.model_type == "bloom":
+ head_dim = model.config.hidden_size // model.config.n_head
+ head_num = model.config.n_head // self.tp_size
+ num_hidden_layers = model.config.n_layer
+ layer_num = num_hidden_layers // self.pp_size
+ elif model.config.model_type == "chatglm":
+ head_dim = model.config.hidden_size // model.config.num_attention_heads
+ if model.config.multi_query_attention:
+ head_num = model.config.multi_query_group_num // self.tp_size
+ else:
+ head_num = model.config.num_attention_heads // self.tp_size
+ num_hidden_layers = model.config.num_layers
+ layer_num = num_hidden_layers // self.pp_size
+ else:
+ raise NotImplementedError("Only support llama, bloom and chatglm model.")
+
+ if self.quant == "smoothquant":
+ cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
+ else:
+ cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
+ return cache_manager
diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/engine/microbatch_manager.py
similarity index 72%
rename from colossalai/inference/pipeline/microbatch_manager.py
rename to colossalai/inference/engine/microbatch_manager.py
index 49d1bf3f42cb..d698c89f9936 100644
--- a/colossalai/inference/pipeline/microbatch_manager.py
+++ b/colossalai/inference/engine/microbatch_manager.py
@@ -1,8 +1,10 @@
from enum import Enum
-from typing import Dict, Tuple
+from typing import Dict
import torch
+from ..kv_cache import BatchInferState, MemoryManager
+
__all__ = "MicroBatchManager"
@@ -27,21 +29,19 @@ class MicroBatchDescription:
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
- output_dict: Dict[str, torch.Tensor],
- new_length: int,
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
) -> None:
- assert output_dict.get("hidden_states") is not None
- self.mb_length = output_dict["hidden_states"].shape[-2]
- self.target_length = self.mb_length + new_length
- self.kv_cache = ()
-
- def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
- if output_dict is not None:
- self._update_kvcache(output_dict["past_key_values"])
+ self.mb_length = inputs_dict["input_ids"].shape[-1]
+ self.target_length = self.mb_length + max_output_len
+ self.infer_state = BatchInferState.init_from_batch(
+ batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
+ )
+ # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
- def _update_kvcache(self, kv_cache: Tuple):
- assert type(kv_cache) == tuple
- self.kv_cache = kv_cache
+ def update(self, *args, **kwargs):
+ pass
@property
def state(self):
@@ -75,22 +75,24 @@ class HeadMicroBatchDescription(MicroBatchDescription):
Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
- new_length (int): the new length of the input sequence.
"""
def __init__(
- self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+ self,
+ inputs_dict: Dict[str, torch.Tensor],
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
) -> None:
- super().__init__(inputs_dict, output_dict, new_length)
+ super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
assert inputs_dict is not None
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict["input_ids"]
self.attn_mask = inputs_dict["attention_mask"]
self.new_tokens = None
- def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
- super().update(output_dict, new_token)
+ def update(self, new_token: torch.Tensor = None):
if new_token is not None:
self._update_newtokens(new_token)
if self.state is not Status.DONE and new_token is not None:
@@ -125,16 +127,16 @@ class BodyMicroBatchDescription(MicroBatchDescription):
Args:
inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
- output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
"""
def __init__(
- self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+ self,
+ inputs_dict: Dict[str, torch.Tensor],
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
) -> None:
- super().__init__(inputs_dict, output_dict, new_length)
-
- def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
- super().update(output_dict, new_token)
+ super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
@property
def cur_length(self):
@@ -142,10 +144,7 @@ def cur_length(self):
When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
"""
- if len(self.kv_cache) == 0:
- return self.mb_length
- else:
- return self.kv_cache[0][0].shape[-2] + 1
+ return self.infer_state.seq_len.max().item()
class MicroBatchManager:
@@ -154,22 +153,41 @@ class MicroBatchManager:
Args:
stage (int): stage id of current stage.
- new_length (int): the new length of the input sequence.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
"""
- def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
+ def __init__(
+ self,
+ stage: int,
+ micro_batch_size: int,
+ micro_batch_buffer_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager_list: MemoryManager,
+ ):
self.stage = stage
- self.new_length = new_length
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.cache_manager_list = cache_manager_list
self.mb_descrption_buffer = {}
self.new_tokens_buffer = {}
self.idx = 0
- def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
+ def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
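+ # The first pipeline stage keeps the tokenized inputs; later stages only track per-microbatch state.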
+ if self.stage == 0:
+ self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
+ inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
+ )
+ else:
+ self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
+ inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
+ )
+
+ def step(self, new_token: torch.Tensor = None):
"""
Update the state if microbatch manager, 2 conditions.
1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
@@ -181,11 +199,7 @@ def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, ne
new_token (torch.Tensor): the new token generated by current stage.
"""
# Add descrption first if the descrption is None
- if inputs_dict is None and output_dict is None and new_token is None:
- return Status.PREFILL
- if self.mb_descrption_buffer.get(self.idx) is None:
- self._add_descrption(inputs_dict, output_dict)
- self.cur_descrption.update(output_dict, new_token)
+ self.cur_descrption.update(new_token)
return self.cur_state
def export_new_tokens(self):
@@ -204,16 +218,12 @@ def is_micro_batch_done(self):
def clear(self):
self.mb_descrption_buffer.clear()
+ for cache in self.cache_manager_list:
+ cache.free_all()
def next(self):
self.idx = (self.idx + 1) % self.buffer_size
- def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
- if self.stage == 0:
- self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
- else:
- self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
-
def _remove_descrption(self):
self.mb_descrption_buffer.pop(self.idx)
@@ -222,10 +232,10 @@ def cur_descrption(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx)
@property
- def cur_kv_cache(self):
+ def cur_infer_state(self):
if self.cur_descrption is None:
return None
- return self.cur_descrption.kv_cache
+ return self.cur_descrption.infer_state
@property
def cur_state(self):
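The refactor above replaces the old `new_length`/`output_dict` bookkeeping with explicit `max_input_len`/`max_output_len` limits and a per-slot cache manager. A minimal driving sketch, assuming a prebuilt list of `MemoryManager` instances and a `Status` enum with a `GENERATE` state (both assumptions for illustration; this is not the engine's actual wiring):

```python
import torch

cache_managers = ...  # e.g. [MemoryManager(...) for _ in range(4)]; constructor args omitted

mb_manager = MicroBatchManager(
    stage=0,                       # stage 0 builds HeadMicroBatchDescription objects
    micro_batch_size=2,
    micro_batch_buffer_size=4,     # normally equals the number of pipeline stages
    max_input_len=128,
    max_output_len=32,
    cache_manager_list=cache_managers,
)

inputs = {
    "input_ids": torch.ones(2, 16, dtype=torch.long),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
}
mb_manager.add_descrption(inputs)          # register the micro batch before stepping
state = mb_manager.step(new_token=None)    # prefill pass
while state == Status.GENERATE:            # Status value is an assumption
    new_token = torch.zeros(2, 1, dtype=torch.long)  # stand-in for sampled ids
    state = mb_manager.step(new_token)
tokens = mb_manager.export_new_tokens()
mb_manager.clear()                         # also frees every cache manager's slots
```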
diff --git a/colossalai/inference/engine/modeling/__init__.py b/colossalai/inference/engine/modeling/__init__.py
new file mode 100644
index 000000000000..8a9e9999d3c5
--- /dev/null
+++ b/colossalai/inference/engine/modeling/__init__.py
@@ -0,0 +1,5 @@
+from .bloom import BloomInferenceForwards
+from .chatglm2 import ChatGLM2InferenceForwards
+from .llama import LlamaInferenceForwards
+
+__all__ = ["LlamaInferenceForwards", "BloomInferenceForwards", "ChatGLM2InferenceForwards"]
diff --git a/colossalai/inference/tensor_parallel/modeling/_utils.py b/colossalai/inference/engine/modeling/_utils.py
similarity index 100%
rename from colossalai/inference/tensor_parallel/modeling/_utils.py
rename to colossalai/inference/engine/modeling/_utils.py
diff --git a/colossalai/inference/engine/modeling/bloom.py b/colossalai/inference/engine/modeling/bloom.py
new file mode 100644
index 000000000000..4c098d3e4c80
--- /dev/null
+++ b/colossalai/inference/engine/modeling/bloom.py
@@ -0,0 +1,452 @@
+import math
+import warnings
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.distributed as dist
+from torch.nn import functional as F
+from transformers.models.bloom.modeling_bloom import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BloomAttention,
+ BloomBlock,
+ BloomForCausalLM,
+ BloomModel,
+)
+from transformers.utils import logging
+
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+try:
+ from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_bloom_context_attention_fwd,
+ )
+
+ HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+ HAS_LIGHTLLM_KERNEL = False
+
+
+def generate_alibi(n_head, dtype=torch.float16):
+ """
+ This method is adapted from `_generate_alibi` function
+ in `lightllm/models/bloom/layer_weights/transformer_layer_weight.py`
+ of the ModelTC/lightllm GitHub repository.
+ This method is originally the `build_alibi_tensor` function
+ in `transformers/models/bloom/modeling_bloom.py`
+ of the huggingface/transformers GitHub repository.
+ """
+
+ def get_slopes_power_of_2(n):
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+ return [start * start**i for i in range(n)]
+
+ def get_slopes(n):
+ if math.log2(n).is_integer():
+ return get_slopes_power_of_2(n)
+ else:
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
+ slopes_power_of_2 = get_slopes_power_of_2(closest_power_of_2)
+ slopes_double = get_slopes(2 * closest_power_of_2)
+ slopes_combined = slopes_power_of_2 + slopes_double[0::2][: n - closest_power_of_2]
+ return slopes_combined
+
+ slopes = get_slopes(n_head)
+ return torch.tensor(slopes, dtype=dtype)
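+
+# Illustrative check (comment only): for a power-of-two head count n, the slopes form
+# the geometric sequence slope_i = 2^(-8 * (i + 1) / n). For example,
+#   generate_alibi(4) -> [2^-2, 2^-4, 2^-6, 2^-8] = [0.25, 0.0625, 0.015625, 0.00390625]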
+
+
+class BloomInferenceForwards:
+ """
+    This class serves as a micro library for bloom inference forwards.
+ We intend to replace the forward methods for BloomForCausalLM, BloomModel, BloomBlock, and BloomAttention,
+ as well as prepare_inputs_for_generation method for BloomForCausalLM.
+ For future improvement, we might want to skip replacing methods for BloomForCausalLM,
+ and call BloomModel.forward iteratively in TpInferEngine
+ """
+
+ @staticmethod
+ def bloom_for_causal_lm_forward(
+ self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = False,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ tp_group: Optional[dist.ProcessGroup] = None,
+ **deprecated_arguments,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ logger = logging.get_logger(__name__)
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+        # If this is the first stage and hidden_states is not None, go through lm_head first
+ if stage_manager.is_first_stage() and hidden_states is not None:
+ lm_logits = self.lm_head(hidden_states)
+ return {"logits": lm_logits}
+
+ outputs = BloomInferenceForwards.bloom_model_forward(
+ self.transformer,
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ tp_group=tp_group,
+ )
+
+ return outputs
+
+ @staticmethod
+ def bloom_model_forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ tp_group: Optional[dist.ProcessGroup] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ logger = logging.get_logger(__name__)
+
+ # add warnings here
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+ " passing `position_ids`.",
+ FutureWarning,
+ )
+ if len(deprecated_arguments) > 0:
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape batch_size x num_heads x N x N
+ # head_mask has shape n_layer x batch x num_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ # first stage
+ if stage_manager.is_first_stage():
+ # check inputs and inputs embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+ # other stage
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
+
+ if seq_length != 1:
+ # prefill stage
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ BatchInferState.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
+ else:
+ attention_mask = attention_mask.to(hidden_states.device)
+
+        # NOTE revise: we might want to store a single 1D alibi (length = num_heads) in the model,
+        # or store it in BatchInferState to avoid re-computation.
+        # When there are multiple process groups (e.g. dp together with tp), the tp group must be passed in here.
+ tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+ curr_tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
+ alibi = (
+ generate_alibi(self.num_heads * tp_size)
+ .contiguous()[curr_tp_rank * self.num_heads : (curr_tp_rank + 1) * self.num_heads]
+ .cuda()
+ )
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+
+ infer_state.decode_layer_id = 0
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ if past_key_values is None:
+ past_key_values = tuple([None] * (end_idx - start_idx + 1))
+
+ for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
+ block = self.h[idx]
+ outputs = block(
+ hidden_states,
+ layer_past=past_key_value,
+ attention_mask=causal_mask,
+ head_mask=head_mask[idx],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ infer_state=infer_state,
+ )
+
+ infer_state.decode_layer_id += 1
+ hidden_states = outputs[0]
+
+ if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+ hidden_states = self.ln_f(hidden_states)
+
+ # update indices
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+
+        # always return a dict for intermediate stages
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def bloom_block_forward(
+ self: BloomBlock,
+ hidden_states: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [batch_size, seq_length, hidden_size]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+
+ # Layer norm post the self attention.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+
+ # Self attention.
+ attn_outputs = self.self_attention(
+ layernorm_output,
+ residual,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ alibi=alibi,
+ head_mask=head_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ infer_state=infer_state,
+ )
+
+ attention_output = attn_outputs[0]
+
+ outputs = attn_outputs[1:]
+
+ layernorm_output = self.post_attention_layernorm(attention_output)
+
+ # Get residual
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = attention_output
+
+ # MLP.
+ output = self.mlp(layernorm_output, residual)
+
+ if use_cache:
+ outputs = (output,) + outputs
+ else:
+ outputs = (output,) + outputs[1:]
+
+ return outputs # hidden_states, present, attentions
+
+ @staticmethod
+ def bloom_attention_forward(
+ self: BloomAttention,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor,
+ alibi: torch.Tensor,
+ attention_mask: torch.Tensor,
+ layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+ batch_size, q_length, H, D_HEAD = query_layer.shape
+        k = key_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+        v = value_layer.reshape(-1, H, D_HEAD)  # batch_size * q_length, H, D_HEAD, q_length == 1
+
+ mem_manager = infer_state.cache_manager
+ layer_id = infer_state.decode_layer_id
+
+ if infer_state.is_context_stage:
+ # context process
+ max_input_len = q_length
+ b_start_loc = infer_state.start_loc
+ b_seq_len = infer_state.seq_len[:batch_size]
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ copy_kv_cache_to_dest(k, infer_state.context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.context_mem_index, mem_manager.value_buffer[layer_id])
+
+ # output = self.output[:batch_size*q_length, :, :]
+ output = torch.empty_like(q)
+
+ if HAS_LIGHTLLM_KERNEL:
+ lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
+ else:
+ bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+ else:
+ # query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
+ # need shape: batch_size, H, D_HEAD (q_length == 1), input q shape : (batch_size, q_length(1), H, D_HEAD)
+ assert q_length == 1, "for non-context process, we only support q_length == 1"
+ q = query_layer.reshape(-1, H, D_HEAD)
+
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(k)
+ cache_v.copy_(v)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+ # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_cache_to_dest(k, infer_state.decode_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(v, infer_state.decode_mem_index, mem_manager.value_buffer[layer_id])
+
+ b_start_loc = infer_state.start_loc
+ b_loc = infer_state.block_loc
+ b_seq_len = infer_state.seq_len
+ output = torch.empty_like(q)
+ token_attention_fwd(
+ q,
+ mem_manager.key_buffer[layer_id],
+ mem_manager.value_buffer[layer_id],
+ output,
+ b_loc,
+ b_start_loc,
+ b_seq_len,
+ infer_state.max_len_in_batch,
+ alibi,
+ )
+
+ context_layer = output.view(batch_size, q_length, H * D_HEAD)
+
+ # NOTE: always set present as none for now, instead of returning past key value to the next decoding,
+ # we create the past key value pair from the cache manager
+ present = None
+
+ # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
+ if self.pretraining_tp > 1 and self.slow_but_exact:
+ slices = self.hidden_size / self.pretraining_tp
+ output_tensor = torch.zeros_like(context_layer)
+ for i in range(self.pretraining_tp):
+ output_tensor = output_tensor + F.linear(
+ context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
+ self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
+ )
+ else:
+ output_tensor = self.dense(context_layer)
+
+ # dropout is not required here during inference
+ output_tensor = residual + output_tensor
+
+ outputs = (output_tensor, present)
+ assert output_attentions is False, "we do not support output_attentions at this time"
+
+ return outputs
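For tensor parallelism, `bloom_model_forward` generates the full ALiBi slope vector and keeps only the contiguous slice owned by the current rank. A small standalone sketch of that slicing (sizes are illustrative assumptions; `generate_alibi` is the function defined in the file above):

```python
import torch

tp_size, num_heads_per_rank = 2, 4
full = generate_alibi(num_heads_per_rank * tp_size, dtype=torch.float32)  # 8 slopes
per_rank = [
    full[r * num_heads_per_rank : (r + 1) * num_heads_per_rank] for r in range(tp_size)
]
# rank 0 gets the slopes for heads 0..3, rank 1 for heads 4..7, mirroring
# generate_alibi(self.num_heads * tp_size)[curr_tp_rank * num_heads : (curr_tp_rank + 1) * num_heads]
```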
diff --git a/colossalai/inference/engine/modeling/chatglm2.py b/colossalai/inference/engine/modeling/chatglm2.py
new file mode 100644
index 000000000000..56e777bb2b87
--- /dev/null
+++ b/colossalai/inference/engine/modeling/chatglm2.py
@@ -0,0 +1,492 @@
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.utils import logging
+
+from colossalai.inference.kv_cache import BatchInferState
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+ split_tensor_along_last_dim,
+)
+
+from ._utils import copy_kv_to_mem_cache
+
+try:
+ from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_llama2_context_attention_fwd,
+ )
+
+ HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+ print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+ HAS_LIGHTLLM_KERNEL = False
+
+
+def get_masks(self, input_ids, past_length, padding_mask=None):
+ batch_size, seq_length = input_ids.shape
+ full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+ full_attention_mask.tril_()
+ if past_length:
+ full_attention_mask = torch.cat(
+ (
+ torch.ones(batch_size, seq_length, past_length, device=input_ids.device),
+ full_attention_mask,
+ ),
+ dim=-1,
+ )
+
+ if padding_mask is not None:
+ full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
+ if not past_length and padding_mask is not None:
+ full_attention_mask -= padding_mask.unsqueeze(-1) - 1
+ full_attention_mask = (full_attention_mask < 0.5).bool()
+ full_attention_mask.unsqueeze_(1)
+ return full_attention_mask
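+
+
+# Illustrative mask (comment only): with batch_size=1, seq_length=3 and past_length=0,
+# get_masks returns (before the final unsqueeze) the boolean matrix
+#   [[False, True,  True ],
+#    [False, False, True ],
+#    [False, False, False]]
+# where True marks positions a token is NOT allowed to attend to (strict upper triangle).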
+
+
+def get_position_ids(batch_size, seq_length, device):
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ return position_ids
+
+
+class ChatGLM2InferenceForwards:
+ """
+    This class holds forwards for ChatGLM2 inference.
+    We intend to replace the forward methods of ChatGLMForConditionalGeneration, ChatGLMModel, GLMTransformer, GLMBlock, and SelfAttention.
+ """
+
+ @staticmethod
+ def chatglm_for_conditional_generation_forward(
+ self: ChatGLMForConditionalGeneration,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = True,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ return_last_logit: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ logger = logging.get_logger(__name__)
+
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+        # If this is the first stage and hidden_states is not None, go through lm_head first
+ if stage_manager.is_first_stage() and hidden_states is not None:
+ if return_last_logit:
+ hidden_states = hidden_states[-1:]
+ lm_logits = self.transformer.output_layer(hidden_states)
+ lm_logits = lm_logits.transpose(0, 1).contiguous()
+ return {"logits": lm_logits}
+
+ outputs = self.transformer(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ return outputs
+
+ @staticmethod
+ def chatglm_model_forward(
+ self: ChatGLMModel,
+ input_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.BoolTensor] = None,
+ full_attention_mask: Optional[torch.BoolTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+ if inputs_embeds is None:
+ inputs_embeds = self.embedding(input_ids)
+ if position_ids is None:
+ position_ids = get_position_ids(batch_size, seq_length, input_ids.device)
+ hidden_states = inputs_embeds
+ else:
+ assert hidden_states is not None, "hidden_states should not be None in non-first stage"
+ seq_length, batch_size, _ = hidden_states.shape
+ if position_ids is None:
+ position_ids = get_position_ids(batch_size, seq_length, hidden_states.device)
+
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
+
+ seq_length_with_past = seq_length + past_key_values_length
+
+ # prefill stage at first
+ if seq_length != 1:
+ infer_state.is_context_stage = True
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+                print(" *** Encountered non-contiguous allocation")
+ print(
+ f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
+ )
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ # related to rotary embedding
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+ if self.pre_seq_len is not None:
+ if past_key_values is None:
+ past_key_values = self.get_prompt(
+ batch_size=batch_size,
+ device=input_ids.device,
+ dtype=inputs_embeds.dtype,
+ )
+ if attention_mask is not None:
+ attention_mask = torch.cat(
+ [
+ attention_mask.new_ones((batch_size, self.pre_seq_len)),
+ attention_mask,
+ ],
+ dim=-1,
+ )
+ if full_attention_mask is None:
+ if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
+ full_attention_mask = get_masks(
+ self, input_ids, infer_state.cache_manager.past_key_values_length, padding_mask=attention_mask
+ )
+
+ # Run encoder.
+ hidden_states = self.encoder(
+ hidden_states,
+ full_attention_mask,
+ kv_caches=past_key_values,
+ use_cache=use_cache,
+ output_hidden_states=output_hidden_states,
+ infer_state=infer_state,
+ stage_manager=stage_manager,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ # update indices
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def chatglm_encoder_forward(
+ self: GLMTransformer,
+ hidden_states,
+ attention_mask,
+ kv_caches=None,
+ use_cache: Optional[bool] = True,
+ output_hidden_states: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+ infer_state.decode_layer_id = 0
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ if kv_caches is None:
+ kv_caches = tuple([None] * (end_idx - start_idx + 1))
+
+ for idx, kv_cache in zip(range(start_idx, end_idx), kv_caches):
+ layer = self.layers[idx]
+ layer_ret = layer(
+ hidden_states,
+ attention_mask,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+
+ hidden_states, _ = layer_ret
+
+ hidden_states = hidden_states.transpose(0, 1).contiguous()
+
+ if self.post_layer_norm and (stage_manager.is_last_stage() or stage_manager.num_stages == 1):
+ # Final layer norm.
+ hidden_states = self.final_layernorm(hidden_states)
+
+ return hidden_states
+
+ @staticmethod
+ def chatglm_glmblock_forward(
+ self: GLMBlock,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+ # hidden_states: [s, b, h]
+
+ # Layer norm at the beginning of the transformer layer.
+ layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ attention_output, kv_cache = self.self_attention(
+ layernorm_output,
+ attention_mask,
+ kv_cache=kv_cache,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ # Residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = hidden_states
+ layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
+ layernorm_input = residual + layernorm_input
+ # Layer norm post the self attention.
+ layernorm_output = self.post_attention_layernorm(layernorm_input)
+ # MLP.
+ mlp_output = self.mlp(layernorm_output)
+
+ # Second residual connection.
+ if self.apply_residual_connection_post_layernorm:
+ residual = layernorm_output
+ else:
+ residual = layernorm_input
+
+ output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
+ output = residual + output
+ return output, kv_cache
+
+ @staticmethod
+ def chatglm_flash_attn_kvcache_forward(
+ self: SelfAttention,
+ hidden_states,
+ attention_mask,
+ kv_cache=None,
+ use_cache=True,
+ infer_state: Optional[BatchInferState] = None,
+ ):
+        assert use_cache is True, "use_cache should be set to True when using this chatglm attention"
+        # hidden_states: originally [sq, b, h] --> here [b, sq, h]
+ batch_size = hidden_states.shape[0]
+ hidden_size = hidden_states.shape[-1]
+ # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
+ mixed_x_layer = self.query_key_value(hidden_states)
+ if self.multi_query_attention:
+ (query_layer, key_layer, value_layer) = mixed_x_layer.split(
+ [
+ self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
+ ],
+ dim=-1,
+ )
+ query_layer = query_layer.view(
+ query_layer.size()[:-1]
+ + (
+ self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ key_layer = key_layer.view(
+ key_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+ value_layer = value_layer.view(
+ value_layer.size()[:-1]
+ + (
+ self.num_multi_query_groups_per_partition,
+ self.hidden_size_per_attention_head,
+ )
+ )
+
+ else:
+ new_tensor_shape = mixed_x_layer.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+ # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
+ (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+
+ chatglm2_rotary_emb_fwd(
+ query_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head), cos, sin
+ )
+ if self.multi_query_attention:
+ chatglm2_rotary_emb_fwd(
+ key_layer.view(-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+ else:
+ chatglm2_rotary_emb_fwd(
+ key_layer.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ cos,
+ sin,
+ )
+
+        # reshape q, k, v to [bsz * seq_len, num_heads, head_dim]
+ query_layer = query_layer.reshape(
+ -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head
+ )
+ key_layer = key_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+ value_layer = value_layer.reshape(
+ -1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
+ )
+
+ if infer_state.is_context_stage:
+ # first token generation:
+ # copy key and value calculated in current step to memory manager
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+ attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
+
+            # run the (lightllm llama2) context flash-attention kernel for the prefill stage
+ lightllm_llama2_context_attention_fwd(
+ query_layer,
+ key_layer,
+ value_layer,
+ attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_layer)
+ cache_v.copy_(value_layer)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_layer,
+ value_layer,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+            # decode stage: second token and onwards
+ attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ : infer_state.decode_mem_end, :, :
+ ]
+
+ # ==================================
+ # core attention computation is replaced by triton kernel
+ # ==================================
+ Llama2TokenAttentionForwards.token_attn(
+ query_layer,
+ cache_k,
+ cache_v,
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+ # =================
+ # Output:[b,sq, h]
+ # =================
+ output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
+
+ return output, kv_cache
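The multi-query split in `chatglm_flash_attn_kvcache_forward` carves one fused projection into a full-width query and narrow key/value groups. A self-contained sketch of the shapes involved (head counts and dimensions are illustrative assumptions, loosely ChatGLM2-6B-like):

```python
import torch

num_heads, num_kv_groups, head_dim = 32, 2, 128
b, sq = 1, 4
mixed = torch.randn(b, sq, (num_heads + 2 * num_kv_groups) * head_dim)

q, k, v = mixed.split(
    [num_heads * head_dim, num_kv_groups * head_dim, num_kv_groups * head_dim], dim=-1
)
q = q.view(b, sq, num_heads, head_dim)      # [b, sq, np, hn]
k = k.view(b, sq, num_kv_groups, head_dim)  # [b, sq, ng, hn] -- shared across query heads
v = v.view(b, sq, num_kv_groups, head_dim)
```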
diff --git a/colossalai/inference/engine/modeling/llama.py b/colossalai/inference/engine/modeling/llama.py
new file mode 100644
index 000000000000..b7bc94d0eae0
--- /dev/null
+++ b/colossalai/inference/engine/modeling/llama.py
@@ -0,0 +1,492 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+from transformers.utils import logging
+
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from ._utils import copy_kv_to_mem_cache
+
+try:
+ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_llama2_context_attention_fwd,
+ )
+ from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_context_attention_fwd,
+ )
+ from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+ HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+ print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+ HAS_LIGHTLLM_KERNEL = False
+
+try:
+ from colossalai.kernel.triton.flash_decoding import token_flash_decoding
+ HAS_TRITON_FLASH_DECODING_KERNEL = True
+except ImportError:
+    print(
+        "no triton flash decoding support, please install lightllm from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8"
+    )
+ HAS_TRITON_FLASH_DECODING_KERNEL = False
+
+try:
+ from flash_attn import flash_attn_with_kvcache
+ HAS_FLASH_KERNEL = True
+except ImportError:
+    HAS_FLASH_KERNEL = False
+    print("please install flash attention from https://github.com/Dao-AILab/flash-attention")
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def llama_triton_context_attention(
+ query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
+):
+ if num_key_value_groups == 1:
+ if HAS_LIGHTLLM_KERNEL is False:
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ lightllm_context_attention_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model"
+ lightllm_llama2_context_attention_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1, q_head_num=-1, head_dim=-1):
+    if HAS_TRITON_FLASH_DECODING_KERNEL and q_head_num != -1 and head_dim != -1:
+        token_flash_decoding(
+            q=query_states,
+            o_tensor=attn_output,
+            infer_state=infer_state,
+            q_head_num=q_head_num,
+            head_dim=head_dim,
+            cache_k=infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+            cache_v=infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+        )
+        return
+
+ if num_key_value_groups == 1:
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ Llama2TokenAttentionForwards.token_attn(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+
+class LlamaInferenceForwards:
+ """
+ This class holds forwards for llama inference.
+    We intend to replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention for LlamaForCausalLM.
+ """
+
+ @staticmethod
+ def llama_causal_lm_forward(
+ self: LlamaForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+        # If this is the first stage and hidden_states is not None, go through lm_head first
+ if stage_manager.is_first_stage() and hidden_states is not None:
+ lm_logits = self.lm_head(hidden_states)
+ return {"logits": lm_logits}
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = LlamaInferenceForwards.llama_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+
+ return outputs
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ # retrieve input_ids and inputs_embeds
+ if stage_manager is None or stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ assert stage_manager is not None
+ assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
+
+        # NOTE: differentiate from the prefill stage
+        # block_loc requires a different value-assigning method for the two stages
+ if use_cache and seq_length != 1:
+ # NOTE assume prefill stage
+ # allocate memory block
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+ else:
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.repeat(batch_size, 1)
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
+ )
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
+
+ # decoder layers
+ infer_state.decode_layer_id = 0
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ if past_key_values is None:
+ past_key_values = tuple([None] * (end_idx - start_idx + 1))
+
+ for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
+ decoder_layer = self.layers[idx]
+ # NOTE: modify here for passing args to decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+ hidden_states = layer_outputs[0]
+
+ if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+ hidden_states = self.norm(hidden_states)
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def llama_decoder_layer_forward(
+ self: LlamaDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ @staticmethod
+ def llama_flash_attn_kvcache_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert use_cache is True, "use_cache should be set to True when using this llama attention"
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # NOTE might think about better way to handle transposed k and v
+ # key_states [bs, seq_len, num_heads, head_dim/embed_size_per_head]
+ # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+
+ # NOTE might want to revise
+ # need some way to record the length of past key values cache
+ # since we won't return past_key_value_cache right now
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+
+ llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+ llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+ value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+ # copy key and value calculated in current step to memory manager
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+ attn_output = torch.empty_like(query_states)
+
+ llama_triton_context_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state,
+ num_key_value_groups=self.num_key_value_groups,
+ )
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+ if HAS_LIGHTLLM_KERNEL:
+
+ attn_output = torch.empty_like(query_states)
+                llama_triton_token_attention(
+                    query_states=query_states,
+                    attn_output=attn_output,
+                    infer_state=infer_state,
+                    num_key_value_groups=self.num_key_value_groups,
+                    q_head_num=q_len * self.num_heads,
+                    head_dim=self.head_dim,
+                )
+            else:
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
+
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+ copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+ copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+
+ attn_output = flash_attn_with_kvcache(
+ q=query_states,
+ k_cache=copy_cache_k,
+ v_cache=copy_cache_v,
+ softmax_scale=1 / math.sqrt(self.head_dim),
+ causal=True,
+ )
+
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ # return past_key_value as None
+ return attn_output, None, None
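`rotate_half` and `apply_rotary_pos_emb` above appear to be kept as an eager-mode reference for the triton rotary kernel. A quick shape-level sketch with toy cos/sin tables (the real tables come from the model's rotary cache; `apply_rotary_pos_emb` is the function defined in this file):

```python
import torch

bs, n_heads, seq_len, dim = 1, 2, 4, 8
q = torch.randn(bs, n_heads, seq_len, dim)
k = torch.randn(bs, n_heads, seq_len, dim)
position_ids = torch.arange(seq_len).unsqueeze(0)                 # [bs, seq_len]

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(seq_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)                      # [seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                           # [seq_len, dim]
cos = emb.cos()[None, None, :, :]                                 # [1, 1, seq_len, dim]
sin = emb.sin()[None, None, :, :]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```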
diff --git a/colossalai/inference/engine/policies/__init__.py b/colossalai/inference/engine/policies/__init__.py
new file mode 100644
index 000000000000..269d1c57b276
--- /dev/null
+++ b/colossalai/inference/engine/policies/__init__.py
@@ -0,0 +1,11 @@
+from .bloom import BloomModelInferPolicy
+from .chatglm2 import ChatGLM2InferPolicy
+from .llama import LlamaModelInferPolicy
+
+model_policy_map = {
+ "llama": LlamaModelInferPolicy,
+ "bloom": BloomModelInferPolicy,
+ "chatglm": ChatGLM2InferPolicy,
+}
+
+__all__ = ["LlamaModelInferPolicy", "BloomModelInferPolicy", "ChatGLM2InferPolicy", "model_policy_map"]
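`model_policy_map` lets callers resolve an inference policy class from a model family string. A hypothetical lookup sketch (the engine wiring around it is an assumption):

```python
from colossalai.inference.engine.policies import model_policy_map

policy_cls = model_policy_map["llama"]   # -> LlamaModelInferPolicy
policy = policy_cls()                    # then handed to the shardformer/engine
```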
diff --git a/colossalai/inference/engine/policies/bloom.py b/colossalai/inference/engine/policies/bloom.py
new file mode 100644
index 000000000000..f35b50189e82
--- /dev/null
+++ b/colossalai/inference/engine/policies/bloom.py
@@ -0,0 +1,127 @@
+from functools import partial
+from typing import List
+
+import torch
+from torch.nn import LayerNorm, Module
+
+import colossalai.shardformer.layer as col_nn
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
+
+from ..modeling.bloom import BloomInferenceForwards
+
+try:
+ from colossalai.kernel.triton import layer_norm
+
+ HAS_TRITON_NORM = True
+except ImportError:
+ print("Some of our kernels require triton. You might want to install triton from https://github.com/openai/triton")
+ HAS_TRITON_NORM = False
+
+
+def get_triton_layernorm_forward():
+ if HAS_TRITON_NORM:
+
+ def _triton_layernorm_forward(self: LayerNorm, hidden_states: torch.Tensor):
+ return layer_norm(hidden_states, self.weight.data, self.bias, self.eps)
+
+ return _triton_layernorm_forward
+ else:
+ return None
+
+
+class BloomModelInferPolicy(BloomForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
+
+ policy = super().module_policy()
+ if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+ policy[BloomBlock] = ModulePolicyDescription(
+ attribute_replacement={
+ "self_attention.hidden_size": self.model.config.hidden_size
+ // self.shard_config.tensor_parallel_size,
+ "self_attention.split_size": self.model.config.hidden_size
+ // self.shard_config.tensor_parallel_size,
+ "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attention.query_key_value",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 3},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attention.attention_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ ],
+ )
+        # NOTE: switch the shard config into inference mode
+ self.shard_config._infer()
+
+        # Used by default: inference also goes through the pipeline-style forward, with the stage number simply set to 1.
+ self.set_pipeline_forward(
+ model_cls=BloomForCausalLM,
+ new_forward=partial(
+ BloomInferenceForwards.bloom_for_causal_lm_forward,
+ tp_group=self.shard_config.tensor_parallel_process_group,
+ ),
+ policy=policy,
+ )
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_model_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomModel)
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_block_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=BloomBlock)
+
+ method_replacement = {"forward": BloomInferenceForwards.bloom_attention_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=BloomAttention
+ )
+
+ if HAS_TRITON_NORM:
+ infer_method = get_triton_layernorm_forward()
+            method_replacement = {"forward": infer_method}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LayerNorm
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "BloomModel":
+ module = self.model
+ else:
+ module = self.model.transformer
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.word_embeddings)
+ held_layers.append(module.word_embeddings_layernorm)
+ held_layers.append(self.model.lm_head)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.h[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.ln_f)
+
+ return held_layers
diff --git a/colossalai/inference/engine/policies/chatglm2.py b/colossalai/inference/engine/policies/chatglm2.py
new file mode 100644
index 000000000000..3e1d94f4785c
--- /dev/null
+++ b/colossalai/inference/engine/policies/chatglm2.py
@@ -0,0 +1,89 @@
+from typing import List
+
+import torch.nn as nn
+
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
+ ChatGLMForConditionalGeneration,
+ ChatGLMModel,
+ GLMBlock,
+ GLMTransformer,
+ SelfAttention,
+)
+
+# import colossalai
+from colossalai.shardformer.policies.chatglm2 import ChatGLMModelPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.chatglm2 import ChatGLM2InferenceForwards
+
+try:
+    # Probe for the triton rmsnorm kernel, mirroring the llama policy.
+    from colossalai.kernel.triton import rmsnorm_forward  # noqa: F401
+
+    HAS_TRITON_RMSNORM = True
+except ImportError:
+    print("You might want to install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+
+class ChatGLM2InferPolicy(ChatGLMModelPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ self.shard_config._infer()
+
+ model_infer_forward = ChatGLM2InferenceForwards.chatglm_model_forward
+ method_replacement = {"forward": model_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=ChatGLMModel)
+
+ encoder_infer_forward = ChatGLM2InferenceForwards.chatglm_encoder_forward
+ method_replacement = {"forward": encoder_infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=GLMTransformer
+ )
+
+ encoder_layer_infer_forward = ChatGLM2InferenceForwards.chatglm_glmblock_forward
+ method_replacement = {"forward": encoder_layer_infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=GLMBlock)
+
+ attn_infer_forward = ChatGLM2InferenceForwards.chatglm_flash_attn_kvcache_forward
+ method_replacement = {"forward": attn_infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=SelfAttention
+ )
+ if self.shard_config.enable_tensor_parallelism:
+ policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
+ self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+ )
+ # for rmsnorm and others, we need to check the shape
+
+ self.set_pipeline_forward(
+ model_cls=ChatGLMForConditionalGeneration,
+ new_forward=ChatGLM2InferenceForwards.chatglm_for_conditional_generation_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[nn.Module]:
+ module = self.model.transformer
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embedding)
+ held_layers.append(module.output_layer)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.encoder.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ if module.encoder.post_layer_norm:
+ held_layers.append(module.encoder.final_layernorm)
+
+ # rotary_pos_emb is needed for all stages
+ held_layers.append(module.rotary_pos_emb)
+
+ return held_layers
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.transformer)
+ return self.model
diff --git a/colossalai/inference/engine/policies/llama.py b/colossalai/inference/engine/policies/llama.py
new file mode 100644
index 000000000000..11517d7e8a13
--- /dev/null
+++ b/colossalai/inference/engine/policies/llama.py
@@ -0,0 +1,206 @@
+from functools import partial
+from typing import List
+
+import torch
+from torch.nn import Module
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaRMSNorm,
+)
+
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+
+# import colossalai
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.llama import LlamaInferenceForwards
+
+try:
+ from colossalai.kernel.triton import rmsnorm_forward
+
+ HAS_TRITON_RMSNORM = True
+except ImportError:
+    print("You might want to install triton from https://github.com/openai/triton")
+ HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+ if HAS_TRITON_RMSNORM:
+
+ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+ return _triton_rmsnorm_forward
+ else:
+ return None
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ "self_attn.num_key_value_heads": self.model.config.num_key_value_heads
+ // self.shard_config.tensor_parallel_size,
+ }
+ if self.shard_config.extra_kwargs.get("quant", None) == "gptq":
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+ policy[LlamaDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ ],
+ )
+
+ elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant":
+ from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer
+ from colossalai.inference.quant.smoothquant.models.parallel_linear import (
+ ColW8A8BFP32OFP32Linear,
+ RowW8A8B8O8Linear,
+ RowW8A8BFP32O32LinearSiLU,
+ RowW8A8BFP32OFP32Linear,
+ )
+
+ policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=RowW8A8B8O8Linear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=RowW8A8B8O8Linear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=RowW8A8B8O8Linear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=ColW8A8BFP32OFP32Linear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=RowW8A8BFP32O32LinearSiLU,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=RowW8A8BFP32OFP32Linear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=ColW8A8BFP32OFP32Linear,
+ kwargs={"split_num": 1},
+ ),
+ ],
+ )
+ self.shard_config._infer()
+
+ infer_forward = LlamaInferenceForwards.llama_model_forward
+        method_replacement = {"forward": infer_forward}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+ infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
+        method_replacement = {"forward": infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+ )
+
+ infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
+        method_replacement = {"forward": infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaAttention
+ )
+
+        # Used by default: inference also goes through the pipeline-style forward, with the stage number simply set to 1.
+ self.set_pipeline_forward(
+ model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
+ )
+
+ infer_forward = None
+ if HAS_TRITON_RMSNORM:
+ infer_forward = get_triton_rmsnorm_forward()
+
+ if infer_forward is not None:
+            method_replacement = {"forward": infer_forward}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+ )
+
+ return policy
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.model)
+ return self.model
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "LlamaModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ held_layers.append(self.model.lm_head)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
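
All three `get_held_layers` implementations above lean on the base policy's `distribute_layers` and `get_stage_index`. A minimal standalone sketch of that partitioning logic, assuming the common even-split scheme (the real base-policy helpers may balance the remainder differently):

```python
from typing import List, Tuple


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Even split; any remainder is assumed to go to the later stages.
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - rem else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # Convert per-stage layer counts into a [start, end) slice of module.layers.
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


# e.g. 30 decoder layers over 4 stages -> [7, 7, 8, 8]; stage 2 holds layers [14, 22)
layers_per_stage = distribute_layers(30, 4)
print(get_stage_index(layers_per_stage, 2))  # (14, 22)
```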
diff --git a/colossalai/inference/kv_cache/__init__.py b/colossalai/inference/kv_cache/__init__.py
new file mode 100644
index 000000000000..5b6ca182efae
--- /dev/null
+++ b/colossalai/inference/kv_cache/__init__.py
@@ -0,0 +1,2 @@
+from .batch_infer_state import BatchInferState
+from .kvcache_manager import MemoryManager
diff --git a/colossalai/inference/kv_cache/batch_infer_state.py b/colossalai/inference/kv_cache/batch_infer_state.py
new file mode 100644
index 000000000000..f707a86df37e
--- /dev/null
+++ b/colossalai/inference/kv_cache/batch_infer_state.py
@@ -0,0 +1,118 @@
+# We might want to combine this with InferenceConfig in colossalai/ppinference/inference_config.py later.
+from dataclasses import dataclass
+
+import torch
+from transformers.tokenization_utils_base import BatchEncoding
+
+from .kvcache_manager import MemoryManager
+
+
+# adapted from: lightllm/server/router/model_infer/infer_batch.py
+@dataclass
+class BatchInferState:
+ r"""
+ Information to be passed and used for a batch of inputs during
+ a single model forward
+ """
+ batch_size: int
+ max_len_in_batch: int
+
+ cache_manager: MemoryManager = None
+
+ block_loc: torch.Tensor = None
+ start_loc: torch.Tensor = None
+ seq_len: torch.Tensor = None
+ past_key_values_len: int = None
+
+ is_context_stage: bool = False
+ context_mem_index: torch.Tensor = None
+ decode_is_contiguous: bool = None
+ decode_mem_start: int = None
+ decode_mem_end: int = None
+ decode_mem_index: torch.Tensor = None
+ decode_layer_id: int = None
+
+ device: torch.device = torch.device("cuda")
+
+ @property
+ def total_token_num(self):
+ # return self.batch_size * self.max_len_in_batch
+ assert self.seq_len is not None and self.seq_len.size(0) > 0
+ return int(torch.sum(self.seq_len))
+
+ def set_cache_manager(self, manager: MemoryManager):
+ self.cache_manager = manager
+
+ # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
+ @staticmethod
+ def init_block_loc(
+ b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
+ ):
+ """in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+ start_index = 0
+ seq_len_numpy = seq_len.cpu().numpy()
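+        # Each sequence is right-aligned within its row of b_loc: the last
+        # cur_seq_len columns of row i receive that sequence's newly allocated
+        # cache indexes.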
+ for i, cur_seq_len in enumerate(seq_len_numpy):
+ b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
+ start_index : start_index + cur_seq_len
+ ]
+ start_index += cur_seq_len
+ return
+
+ @classmethod
+ def init_from_batch(
+ cls,
+ batch: torch.Tensor,
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
+ ):
+ if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
+ raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")
+
+ input_ids_list = None
+ attention_mask = None
+
+ if isinstance(batch, (BatchEncoding, dict)):
+ input_ids_list = batch["input_ids"]
+ attention_mask = batch["attention_mask"]
+ else:
+ input_ids_list = batch
+ if isinstance(input_ids_list[0], int): # for a single input
+ input_ids_list = [input_ids_list]
+ attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+ batch_size = len(input_ids_list)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ start_index = 0
+
+ max_len_in_batch = -1
+ if isinstance(batch, (BatchEncoding, dict)):
+ for i, attn_mask in enumerate(attention_mask):
+ curr_seq_len = len(attn_mask)
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ else:
+ length = max(len(input_id) for input_id in input_ids_list)
+ for i, input_ids in enumerate(input_ids_list):
+ curr_seq_len = length
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")
+
+ return cls(
+ batch_size=batch_size,
+ max_len_in_batch=max_len_in_batch,
+ seq_len=seq_lengths.to("cuda"),
+ start_loc=seq_start_indexes.to("cuda"),
+ block_loc=block_loc,
+ decode_layer_id=0,
+ past_key_values_len=0,
+ is_context_stage=True,
+ cache_manager=cache_manager,
+ )
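
For context, a sketch of how `init_from_batch` might be driven from tokenizer output; the checkpoint name, model dimensions, and cache sizing are illustrative assumptions, and a CUDA device is required since the state tensors default to `cuda`:

```python
# Illustrative wiring only: the tokenizer, model dimensions, and cache sizing
# below are placeholder assumptions, not values from this PR.
import torch
from transformers import AutoTokenizer

from colossalai.inference.kv_cache import BatchInferState, MemoryManager

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
batch = tokenizer(["Hello, world", "KV-cache test input"], return_tensors="pt", padding=True)

max_input_len, max_output_len = 128, 64
cache_manager = MemoryManager(
    size=2 * (max_input_len + max_output_len),  # batch_size * max total tokens
    dtype=torch.float16,
    head_num=32,   # placeholder: llama-7b-like dimensions
    head_dim=128,
    layer_num=32,
)
state = BatchInferState.init_from_batch(batch, max_input_len, max_output_len, cache_manager)
print(state.batch_size, state.max_len_in_batch, state.total_token_num)
```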
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
new file mode 100644
index 000000000000..dda46a756cc3
--- /dev/null
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -0,0 +1,106 @@
+"""
+Referenced/modified from lightllm/common/mem_manager.py
+of the ModelTC/lightllm GitHub repository:
+https://github.com/ModelTC/lightllm/blob/050af3ce65edca617e2f30ec2479397d5bb248c9/lightllm/common/mem_manager.py
+We changed it slightly to fit our Colossal-AI shardformer TP-engine design.
+"""
+import torch
+from transformers.utils import logging
+
+
+class MemoryManager:
+ r"""
+ Manage token block indexes and allocate physical memory for key and value cache
+
+ Args:
+ size: maximum token number used as the size of key and value buffer
+ dtype: data type of cached key and value
+ head_num: number of heads the memory manager is responsible for
+ head_dim: embedded size per head
+ layer_num: the number of layers in the model
+ device: device used to store the key and value cache
+ """
+
+ def __init__(
+ self,
+ size: int,
+ dtype: torch.dtype,
+ head_num: int,
+ head_dim: int,
+ layer_num: int,
+ device: torch.device = torch.device("cuda"),
+ ):
+ self.logger = logging.get_logger(__name__)
+ self.available_size = size
+ self.max_len_in_batch = 0
+ self._init_mem_states(size, device)
+ self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
+
+ def _init_mem_states(self, size, device):
+ """Initialize tensors used to manage memory states"""
+ self.mem_state = torch.ones((size,), dtype=torch.bool, device=device)
+ self.mem_cum_sum = torch.empty((size,), dtype=torch.int32, device=device)
+ self.indexes = torch.arange(0, size, dtype=torch.long, device=device)
+
+ def _init_kv_buffers(self, size, device, dtype, head_num, head_dim, layer_num):
+ """Initialize key buffer and value buffer on specified device"""
+ self.key_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+ self.value_buffer = [
+ torch.empty((size, head_num, head_dim), dtype=dtype, device=device) for _ in range(layer_num)
+ ]
+
+ @torch.no_grad()
+ def alloc(self, required_size):
+ """allocate space of required_size by providing indexes representing available physical spaces"""
+ if required_size > self.available_size:
+ self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ select_index = torch.logical_and(self.mem_cum_sum <= required_size, self.mem_state == 1)
+ select_index = self.indexes[select_index]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ return select_index
+
+ @torch.no_grad()
+ def alloc_contiguous(self, required_size):
+ """allocate contiguous space of required_size"""
+ if required_size > self.available_size:
+ self.logger.warning(f"No enough cache: required_size {required_size} " f"left_size {self.available_size}")
+ return None
+ torch.cumsum(self.mem_state, dim=0, dtype=torch.int32, out=self.mem_cum_sum)
+ sum_size = len(self.mem_cum_sum)
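+        # Windowed-sum trick: loc_sums[i] counts the free slots in the window
+        # [i, i + required_size); a window is entirely free exactly when
+        # loc_sums[i] == required_size.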
+ loc_sums = (
+ self.mem_cum_sum[required_size - 1 :]
+ - self.mem_cum_sum[0 : sum_size - required_size + 1]
+ + self.mem_state[0 : sum_size - required_size + 1]
+ )
+ can_used_loc = self.indexes[0 : sum_size - required_size + 1][loc_sums == required_size]
+ if can_used_loc.shape[0] == 0:
+ self.logger.info(
+ f"No enough contiguous cache: required_size {required_size} " f"left_size {self.available_size}"
+ )
+ return None
+ start_loc = can_used_loc[0]
+ select_index = self.indexes[start_loc : start_loc + required_size]
+ self.mem_state[select_index] = 0
+ self.available_size -= len(select_index)
+ start = start_loc.item()
+ end = start + required_size
+ return select_index, start, end
+
+ @torch.no_grad()
+ def free(self, free_index):
+ """free memory by updating memory states based on given indexes"""
+ self.available_size += free_index.shape[0]
+ self.mem_state[free_index] = 1
+
+ @torch.no_grad()
+ def free_all(self):
+ """free all memory by updating memory states"""
+ self.available_size = len(self.mem_state)
+ self.mem_state[:] = 1
+ self.max_len_in_batch = 0
+ # self.logger.info("freed all space of memory manager")
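
The allocate/free cycle the manager exposes, as a short usage sketch (sizes are arbitrary placeholders; a GPU is assumed because the buffers default to `cuda`):

```python
import torch

from colossalai.inference.kv_cache import MemoryManager

manager = MemoryManager(size=1024, dtype=torch.float16, head_num=8, head_dim=64, layer_num=2)

# Prefill: take any 16 free slots (indexes into key_buffer/value_buffer rows).
prefill_idx = manager.alloc(16)

# Decode: prefer a contiguous run of slots so new KV entries can be written
# with a cheap slice; fall back to alloc() when the pool is fragmented.
contiguous = manager.alloc_contiguous(1)
if contiguous is not None:
    decode_idx, start, end = contiguous
else:
    decode_idx = manager.alloc(1)

# Hand slots back when a sequence finishes; free_all() resets the whole pool.
manager.free(prefill_idx)
manager.free(decode_idx)
manager.free_all()
```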
diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py
deleted file mode 100644
index 41af9f3ef948..000000000000
--- a/colossalai/inference/pipeline/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .engine import PPInferEngine
-
-__all__ = ["PPInferEngine"]
diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py
deleted file mode 100644
index 4f42385caf8f..000000000000
--- a/colossalai/inference/pipeline/engine.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import torch
-import torch.nn as nn
-
-from colossalai.cluster import ProcessGroupMesh
-from colossalai.pipeline.schedule.generate import GenerateSchedule
-from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.shardformer.policies.base_policy import Policy
-
-from .microbatch_manager import MicroBatchManager
-
-
-class PPInferEngine:
- """
- PPInferEngine is a class that handles the pipeline parallel inference.
-
- Args:
- pp_size (int): the number of pipeline stages.
- pp_model (`nn.Module`): the model already in pipeline parallelism style.
- model (`nn.Module`): the model not in pipeline style, and will be modified with `ShardFormer`.
- model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy to shardformer model.
- micro_batch_size (int): the micro batch size.
- micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
- new_length (int): the new length of the input sequence.
- early_stopping (bool): whether to stop early.
-
- Example:
-
- ```python
- from colossalai.ppinference import PPInferEngine
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
-
- model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
- # assume the model is infered with 4 pipeline stages
- inferengine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
-
- input = ["Hello, my dog is cute, and I like"]
- tokenized_input = tokenizer(input, return_tensors='pt')
- output = engine.inference([tokenized_input])
- ```
-
- """
-
- def __init__(
- self,
- pp_size: int,
- dtype: str = "fp16",
- pp_model: nn.Module = None,
- model: nn.Module = None,
- model_policy: Policy = None,
- new_length: int = 32,
- micro_batch_size: int = 1,
- micro_batch_buffer_size: int = None,
- verbose: bool = False,
- # TODO: implement early_stopping, and various gerneration options
- early_stopping: bool = False,
- do_sample: bool = False,
- num_beams: int = 1,
- ) -> None:
- assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
- self.pp_size = pp_size
- self.pg_mesh = ProcessGroupMesh(pp_size)
- self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
- self.mb_manager = MicroBatchManager(
- self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
- )
- self.verbose = verbose
- self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
-
- assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
- if dtype == "fp16":
- model.half()
- elif dtype == "bf16":
- model.to(torch.bfloat16)
- self.model = pp_model or self._shardformer(model, model_policy)
-
- def inference(self, input_list):
- out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
- if self.verbose:
- return out, timestamp
- else:
- return out
-
- def _shardformer(self, model, model_policy):
- shardconfig = ShardConfig(
- tensor_parallel_process_group=None,
- pipeline_stage_manager=self.stage_manager,
- enable_tensor_parallelism=False,
- enable_fused_normalization=False,
- enable_all_optimization=False,
- enable_flash_attention=False,
- enable_jit_fused=False,
- enable_sequence_parallelism=False,
- )
- shardformer = ShardFormer(shard_config=shardconfig)
- shard_model, _ = shardformer.optimize(model, model_policy)
- return shard_model.cuda()
diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py
deleted file mode 100644
index d2bfcb8b6842..000000000000
--- a/colossalai/inference/pipeline/modeling/gpt2.py
+++ /dev/null
@@ -1,280 +0,0 @@
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
-from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
-from transformers.utils import logging
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-
-class GPT2PipelineForwards:
- """
- This class serves as a micro library for forward function substitution of GPT2 models
- under pipeline setting.
- """
-
- @staticmethod
- def gpt2_model_forward(
- self: GPT2Model,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
- # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
- # Please refer to original code of transformers for more details.
- logger = logging.get_logger(__name__)
-
- # Preprocess passed in arguments
- if output_attentions:
- logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
- output_attentions = False
- if output_hidden_states:
- logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
- output_hidden_states = False
-
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if past_key_values is None:
- past_length = 0
- past_key_values = tuple([None] * len(self.h))
- else:
- past_length = past_key_values[0][0].size(-2)
-
- if stage_manager.is_first_stage():
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
- batch_size = input_ids.shape[0]
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- batch_size = inputs_embeds.shape[0]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
-
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if token_type_ids is not None:
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
- else:
- if hidden_states is None:
- raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
- input_shape = hidden_states.size()[:-1]
- batch_size, seq_length = input_shape[0], input_shape[1]
- device = hidden_states.device
-
- # GPT2Attention mask.
- if attention_mask is not None:
- if batch_size <= 0:
- raise ValueError("batch_size has to be defined and > 0")
- attention_mask = attention_mask.view(batch_size, -1)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- attention_mask = attention_mask[:, None, None, :]
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and the dtype's smallest value for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
- attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
-
- # If a 2D or 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- if self.config.add_cross_attention and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
- if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
- else:
- encoder_attention_mask = None
-
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape bsz x n_heads x N x N
- # head_mask has shape n_layer x batch x n_heads x N x N
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
- if stage_manager.is_first_stage():
- if position_ids is not None:
- position_ids = position_ids.view(-1, input_shape[-1])
- else:
- position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
- position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
-
- if inputs_embeds is None:
- inputs_embeds = self.wte(input_ids)
- position_embeds = self.wpe(position_ids)
- hidden_states = inputs_embeds + position_embeds
- if token_type_ids is not None:
- token_type_embeds = self.wte(token_type_ids)
- hidden_states = hidden_states + token_type_embeds
- hidden_states = self.drop(hidden_states)
-
- output_shape = input_shape + (hidden_states.size(-1),)
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- presents = () if use_cache else None
- all_self_attentions = () if output_attentions else None
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
- all_hidden_states = () if output_hidden_states else None
-
- # Going through held blocks.
- start_idx, end_idx = stage_index[0], stage_index[1]
- for i, layer_past in zip(range(start_idx, end_idx), past_key_values):
- block = self.h[i]
- # Model parallel
- if self.model_parallel:
- torch.cuda.set_device(hidden_states.device)
- # Ensure layer_past is on same device as hidden_states (might not be correct)
- if layer_past is not None:
- layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
- # Ensure that attention_mask is always on the same device as hidden_states
- if attention_mask is not None:
- attention_mask = attention_mask.to(hidden_states.device)
- if isinstance(head_mask, torch.Tensor):
- head_mask = head_mask.to(hidden_states.device)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, use_cache, output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(block),
- hidden_states,
- None,
- attention_mask,
- head_mask[i],
- encoder_hidden_states,
- encoder_attention_mask,
- )
- else:
- outputs = block(
- hidden_states,
- layer_past=layer_past,
- attention_mask=attention_mask,
- head_mask=head_mask[i],
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- )
-
- hidden_states = outputs[0]
- if use_cache is True:
- presents = presents + (outputs[1],)
-
- if output_attentions:
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
- if self.config.add_cross_attention:
- all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
-
- # Model Parallel: If it's the last layer for that device, put things on the next device
- if self.model_parallel:
- for k, v in self.device_map.items():
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
- if stage_manager.is_last_stage():
- hidden_states = self.ln_f(hidden_states)
-
- hidden_states = hidden_states.view(output_shape)
-
- # Add last hidden state
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- return {"hidden_states": hidden_states, "past_key_values": presents}
-
- @staticmethod
- def gpt2_lmhead_model_forward(
- self: GPT2LMHeadModel,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
- `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
- are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-
- This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
- Please refer to original code of transformers for more details.
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # If is first stage and after warmup, go throught lm_head first
- if stage_manager.is_first_stage() and hidden_states is not None:
- lm_logits = self.lm_head(hidden_states)
- return {"logits": lm_logits}
-
- # Not first stage or before warmup, go through gpt2 model
- outputs = GPT2PipelineForwards.gpt2_model_forward(
- self.transformer,
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- stage_manager=stage_manager,
- hidden_states=hidden_states,
- stage_index=stage_index,
- )
-
- return outputs
diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py
deleted file mode 100644
index f46e1fbdd7b3..000000000000
--- a/colossalai/inference/pipeline/modeling/llama.py
+++ /dev/null
@@ -1,229 +0,0 @@
-from typing import List, Optional
-
-import torch
-from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
-from transformers.utils import logging
-
-from colossalai.pipeline.stage_manager import PipelineStageManager
-
-
-class LlamaPipelineForwards:
- """
- This class serves as a micro library for forward function substitution of Llama models
- under pipeline setting.
- """
-
- def llama_model_forward(
- self: LlamaModel,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ):
- logger = logging.get_logger(__name__)
-
- # Preprocess passed in arguments
- if output_attentions:
- logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
- output_attentions = False
- if output_hidden_states:
- logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
- output_hidden_states = False
-
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if stage_manager.is_first_stage():
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- hidden_states = inputs_embeds
- else:
- input_shape = hidden_states.shape[:-1]
- batch_size, seq_length = input_shape
- device = hidden_states.device
-
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
- if past_key_values is not None:
- past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
- if position_ids is None:
- position_ids = torch.arange(
- past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
- )
- position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- # embed positions, for the first stage, hidden_states is the input embeddings,
- # for the other stages, hidden_states is the output of the previous stage
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
- )
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
- )
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = () if use_cache else None
-
- start_idx, end_idx = stage_index[0], stage_index[1]
- if past_key_values is None:
- past_key_values = tuple([None] * (end_idx - start_idx + 1))
-
- for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
- decoder_layer = self.layers[idx]
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- # past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if self.gradient_checkpointing and self.training:
-
- def create_custom_forward(module):
- def custom_forward(*inputs):
- # None for past_key_value
- return module(*inputs, output_attentions, None)
-
- return custom_forward
-
- layer_outputs = torch.utils.checkpoint.checkpoint(
- create_custom_forward(decoder_layer),
- hidden_states,
- attention_mask,
- position_ids,
- None,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- if stage_manager.is_last_stage():
- hidden_states = self.norm(hidden_states)
-
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- next_cache = next_decoder_cache if use_cache else None
-
- # always return dict for imediate stage
- return {"hidden_states": hidden_states, "past_key_values": next_cache}
-
- def llama_for_causal_lm_forward(
- self: LlamaForCausalLM,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- stage_manager: Optional[PipelineStageManager] = None,
- hidden_states: Optional[torch.FloatTensor] = None,
- stage_index: Optional[List[int]] = None,
- ):
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, LlamaForCausalLM
-
- >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
- >>> prompt = "Hey, are you consciours? Can you talk to me?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
-
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
- ```"""
- logger = logging.get_logger(__name__)
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if output_attentions:
- logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
- output_attentions = False
- if output_hidden_states:
- logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
- output_hidden_states = False
-
- # If is first stage and after warmup, go throught lm_head first
- if stage_manager.is_first_stage() and hidden_states is not None:
- lm_logits = self.lm_head(hidden_states)
- return {"logits": lm_logits}
-
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = LlamaPipelineForwards.llama_model_forward(
- self.model,
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- stage_manager=stage_manager,
- hidden_states=hidden_states,
- stage_index=stage_index,
- )
-
- return outputs
diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py
deleted file mode 100644
index 51e6425b113e..000000000000
--- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from functools import partial
-from typing import Callable, Dict, List
-
-from torch import Tensor, nn
-
-import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
-from colossalai.shardformer.policies.gpt2 import GPT2Policy
-
-from ..modeling.gpt2 import GPT2PipelineForwards
-
-
-class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
- def module_policy(self):
- from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
-
- module_policy = super().module_policy()
-
- if self.shard_config.enable_tensor_parallelism:
- addon_module = {
- GPT2LMHeadModel: ModulePolicyDescription(
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
- )
- ]
- )
- }
- module_policy.update(addon_module)
-
- if self.pipeline_stage_manager is not None:
- self.set_pipeline_forward(
- model_cls=GPT2LMHeadModel,
- new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
- policy=module_policy,
- )
- return module_policy
-
- def get_held_layers(self) -> List[nn.Module]:
- held_layers = super().get_held_layers()
- # make the tie weight lm_head and embedding in the same device to save memory
- # if self.pipeline_stage_manager.is_first_stage():
- if self.pipeline_stage_manager.is_first_stage():
- held_layers.append(self.model.lm_head)
- return held_layers
-
- def get_shared_params(self) -> List[Dict[int, Tensor]]:
- """The weights of wte and lm_head are shared."""
- module = self.model
- stage_manager = self.pipeline_stage_manager
- if stage_manager is not None:
- if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
- first_stage, last_stage = 0, stage_manager.num_stages - 1
- return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
- return []
-
- def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
- """If under pipeline parallel setting, replacing the original forward method of huggingface
- to customized forward method, and add this changing to policy."""
- if not self.pipeline_stage_manager:
- raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
- stage_manager = self.pipeline_stage_manager
- if self.model.__class__.__name__ == "GPT2Model":
- module = self.model
- else:
- module = self.model.transformer
-
- layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
- stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
- self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py
deleted file mode 100644
index 6e12ed61bf7b..000000000000
--- a/colossalai/inference/pipeline/policy/llama_ppinfer.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import List
-
-from torch.nn import Module
-
-from colossalai.shardformer.layer import Linear1D_Col
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
-from colossalai.shardformer.policies.llama import LlamaPolicy
-
-from ..modeling.llama import LlamaPipelineForwards
-
-
-class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
- def __init__(self) -> None:
- super().__init__()
-
- def module_policy(self):
- from transformers import LlamaForCausalLM
-
- policy = super().module_policy()
-
- if self.shard_config.enable_tensor_parallelism:
- # add a new item for casual lm
- new_item = {
- LlamaForCausalLM: ModulePolicyDescription(
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
- )
- ]
- )
- }
- policy.update(new_item)
-
- if self.pipeline_stage_manager:
- # set None as default
- self.set_pipeline_forward(
- model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
- )
-
- return policy
-
- def get_held_layers(self) -> List[Module]:
- """Get pipeline layers for current stage."""
- stage_manager = self.pipeline_stage_manager
- held_layers = super().get_held_layers()
- if stage_manager.is_first_stage():
- held_layers.append(self.model.lm_head)
- return held_layers
diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py
deleted file mode 100644
index c26aa4e40b71..000000000000
--- a/colossalai/inference/pipeline/utils.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import Set
-
-import torch.nn as nn
-
-from colossalai.shardformer._utils import getattr_, setattr_
-
-
-def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
- """
- Set all parameters and buffers of model to None
-
- Args:
- model (nn.Module): The model to set
- """
- for module_suffix in include:
- set_module = getattr_(model, module_suffix)
- for n, p in set_module.named_parameters():
- setattr_(set_module, n, None)
- for n, buf in set_module.named_buffers():
- setattr_(set_module, n, None)
- setattr_(model, module_suffix, None)
-
-
-def get_suffix_name(suffix: str, name: str):
- """
- Get the suffix name of the module, as `suffix.name` when name is string or `suffix[name]` when name is a digit,
- and 'name' when `suffix` is empty.
-
- Args:
- suffix (str): The suffix of the suffix module
- name (str): The name of the current module
- """
- point = "" if suffix is "" else "."
- suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
- return suffix_name
diff --git a/colossalai/inference/quant/__init__.py b/colossalai/inference/quant/__init__.py
new file mode 100644
index 000000000000..18e0de9cc9fc
--- /dev/null
+++ b/colossalai/inference/quant/__init__.py
@@ -0,0 +1 @@
+from .smoothquant.models.llama import SmoothLlamaForCausalLM
diff --git a/colossalai/inference/quant/gptq/__init__.py b/colossalai/inference/quant/gptq/__init__.py
index c035f397923a..4cf1fd658a41 100644
--- a/colossalai/inference/quant/gptq/__init__.py
+++ b/colossalai/inference/quant/gptq/__init__.py
@@ -2,3 +2,4 @@
if HAS_AUTO_GPTQ:
from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
+ from .gptq_manager import GPTQManager
diff --git a/colossalai/inference/quant/gptq/gptq_manager.py b/colossalai/inference/quant/gptq/gptq_manager.py
new file mode 100644
index 000000000000..2d352fbef2b9
--- /dev/null
+++ b/colossalai/inference/quant/gptq/gptq_manager.py
@@ -0,0 +1,61 @@
+import warnings
+
+import torch
+
+
+class GPTQManager:
+ def __init__(self, quant_config, max_input_len: int = 1):
+ self.max_dq_buffer_size = 1
+ self.max_inner_outer_dim = 1
+ self.bits = quant_config.bits
+ self.use_act_order = quant_config.desc_act
+        self.max_input_len = max_input_len
+ self.gptq_temp_state_buffer = None
+ self.gptq_temp_dq_buffer = None
+ self.quant_config = quant_config
+
+ def post_init_gptq_buffer(self, model: torch.nn.Module) -> None:
+ from .cai_gptq import CaiQuantLinear
+
+ HAS_GPTQ_CUDA = False
+ try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+ except ImportError:
+ warnings.warn("CUDA gptq is not installed")
+ HAS_GPTQ_CUDA = False
+
+ for name, submodule in model.named_modules():
+ if isinstance(submodule, CaiQuantLinear):
+ self.max_dq_buffer_size = max(self.max_dq_buffer_size, submodule.qweight.numel() * 8)
+
+ if self.use_act_order:
+ self.max_inner_outer_dim = max(
+ self.max_inner_outer_dim, submodule.infeatures, submodule.outfeatures
+ )
+ self.bits = submodule.bits
+ if not (HAS_GPTQ_CUDA and self.bits == 4):
+ return
+
+ max_input_len = 1
+ if self.use_act_order:
+ max_input_len = self.max_input_len
+ # The temp_state buffer is required to reorder X in the act-order case.
+ # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
+ self.gptq_temp_state_buffer = torch.zeros(
+ (max_input_len, self.max_inner_outer_dim), dtype=torch.float16, device=torch.cuda.current_device()
+ )
+ self.gptq_temp_dq_buffer = torch.zeros(
+ (1, self.max_dq_buffer_size), dtype=torch.float16, device=torch.cuda.current_device()
+ )
+
+ gptq_cuda.prepare_buffers(
+ torch.device(torch.cuda.current_device()), self.gptq_temp_state_buffer, self.gptq_temp_dq_buffer
+ )
+ # Using the default from exllama repo here.
+ matmul_recons_thd = 8
+ matmul_fused_remap = False
+ matmul_no_half2 = False
+ gptq_cuda.set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
+
+ torch.cuda.empty_cache()
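
The diff does not show how `GPTQManager` is invoked; a hedged sketch, assuming AutoGPTQ's `BaseQuantizeConfig` for the config object and a model whose linears have already been swapped for `CaiQuantLinear`:

```python
# Illustrative only: `sharded_model` stands in for a GPTQ-quantized model that
# has already been sharded (its Linear layers replaced with CaiQuantLinear).
from auto_gptq import BaseQuantizeConfig

from colossalai.inference.quant.gptq import GPTQManager

quant_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=True)
gptq_manager = GPTQManager(quant_config, max_input_len=2048)

# Size and register the exllama scratch buffers (temp_state for act-order
# reordering, temp_dq for cuBLAS dequantization during prefill):
gptq_manager.post_init_gptq_buffer(sharded_model)
```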
diff --git a/colossalai/inference/quant/smoothquant/models/__init__.py b/colossalai/inference/quant/smoothquant/models/__init__.py
index 77541d8610c5..1663028da138 100644
--- a/colossalai/inference/quant/smoothquant/models/__init__.py
+++ b/colossalai/inference/quant/smoothquant/models/__init__.py
@@ -4,9 +4,7 @@
HAS_TORCH_INT = True
except ImportError:
HAS_TORCH_INT = False
- raise ImportError(
- "Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
- )
+ print("Not install torch_int. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
if HAS_TORCH_INT:
from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
diff --git a/colossalai/inference/quant/smoothquant/models/base_model.py b/colossalai/inference/quant/smoothquant/models/base_model.py
index 6a1d96ecec8f..f3afe5d83bb0 100644
--- a/colossalai/inference/quant/smoothquant/models/base_model.py
+++ b/colossalai/inference/quant/smoothquant/models/base_model.py
@@ -9,7 +9,6 @@
from os.path import isdir, isfile, join
from typing import Dict, List, Optional, Union
-import accelerate
import numpy as np
import torch
import torch.nn as nn
@@ -21,8 +20,16 @@
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import PushToHubMixin, cached_file
-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
-from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
+from colossalai.inference.kv_cache import BatchInferState, MemoryManager
+
+try:
+ import accelerate
+
+ HAS_ACCELERATE = True
+except ImportError:
+ HAS_ACCELERATE = False
+ print("accelerate is not installed.")
+
SUPPORTED_MODELS = ["llama"]
@@ -87,7 +94,6 @@ def init_batch_state(self, max_output_len=256, **kwargs):
batch_infer_state.start_loc = seq_start_indexes.to("cuda")
batch_infer_state.block_loc = block_loc
batch_infer_state.decode_layer_id = 0
- batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
batch_infer_state.cache_manager.free_all()
diff --git a/colossalai/inference/quant/smoothquant/models/linear.py b/colossalai/inference/quant/smoothquant/models/linear.py
index 969c390a0849..03d994b32489 100644
--- a/colossalai/inference/quant/smoothquant/models/linear.py
+++ b/colossalai/inference/quant/smoothquant/models/linear.py
@@ -1,17 +1,25 @@
# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
import torch
-from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
-from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+ from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+ from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+ HAS_TORCH_INT = True
+except ImportError:
+ HAS_TORCH_INT = False
+    print("torch_int is not installed. Please install it from https://github.com/Guangxuan-Xiao/torch-int")
+
try:
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
smoothquant_cuda = SmoothquantBuilder().load()
HAS_SMOOTHQUANT_CUDA = True
-except ImportError:
+except Exception:
HAS_SMOOTHQUANT_CUDA = False
- raise ImportError("CUDA smoothquant linear is not installed")
+    print("The CUDA smoothquant linear kernel is not installed.")
class W8A8BFP32O32LinearSiLU(torch.nn.Module):
@@ -138,21 +146,23 @@ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
)
self.register_buffer(
"bias",
- torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+ torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False),
)
self.register_buffer("a", torch.tensor(alpha))
def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
- self.bias = self.bias.to(torch.float32)
+ if self.bias is not None:
+ self.bias = self.bias.to(torch.float32)
return self
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
- self.bias = self.bias.to(*args, **kwargs)
- self.bias = self.bias.to(torch.float32)
+ if self.bias is not None:
+ self.bias = self.bias.to(*args, **kwargs)
+ self.bias = self.bias.to(torch.float32)
return self
@torch.no_grad()
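The `_apply`/`to` overrides above exist because `module.half()` walks every buffer through `_apply`, which would silently cast the fp32 bias to fp16. A self-contained sketch of the trick (plain PyTorch; the class name is illustrative):

    import torch

    class KeepBiasFP32(torch.nn.Module):
        def __init__(self, out_features):
            super().__init__()
            self.register_buffer("bias", torch.zeros((1, out_features), dtype=torch.float32))

        def _apply(self, fn):
            super()._apply(fn)  # may cast the bias to half
            if self.bias is not None:
                self.bias = self.bias.to(torch.float32)  # force it back
            return self

    m = KeepBiasFP32(8).half()
    assert m.bias.dtype == torch.float32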
diff --git a/colossalai/inference/quant/smoothquant/models/llama.py b/colossalai/inference/quant/smoothquant/models/llama.py
index 4c3d6dcc0b23..bb74dc49d7af 100644
--- a/colossalai/inference/quant/smoothquant/models/llama.py
+++ b/colossalai/inference/quant/smoothquant/models/llama.py
@@ -8,7 +8,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
@@ -18,12 +17,11 @@
LlamaDecoderLayer,
LlamaMLP,
LlamaRotaryEmbedding,
- repeat_kv,
rotate_half,
)
from transformers.utils import add_start_docstrings_to_model_forward
-from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.kv_cache.batch_infer_state import BatchInferState
from colossalai.kernel.triton import (
copy_kv_cache_to_dest,
int8_rotary_embedding_fwd,
@@ -31,10 +29,31 @@
smooth_token_attention_fwd,
)
+try:
+ from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+
+ HAS_TORCH_INT = True
+except ImportError:
+ HAS_TORCH_INT = False
+    print("torch_int is not installed. Please install it from https://github.com/Guangxuan-Xiao/torch-int")
+
+
from .base_model import BaseSmoothForCausalLM
from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
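The `repeat_kv` helper is now vendored here instead of imported; it can be sanity-checked against `torch.repeat_interleave` directly (standalone PyTorch):

    import torch

    x = torch.randn(2, 4, 5, 8)  # (batch, num_kv_heads, seqlen, head_dim)
    out = x[:, :, None, :, :].expand(2, 4, 3, 5, 8).reshape(2, 12, 5, 8)
    assert torch.equal(out, torch.repeat_interleave(x, repeats=3, dim=1))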
class LLamaSmoothquantAttention(nn.Module):
def __init__(
self,
@@ -116,7 +135,6 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
def forward(
self,
hidden_states: torch.Tensor,
- rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -131,8 +149,7 @@ def forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- cos = rotary_emb[0]
- sin = rotary_emb[1]
+ cos, sin = infer_state.position_cos, infer_state.position_sin
int8_rotary_embedding_fwd(
query_states.view(-1, self.num_heads, self.head_dim),
@@ -149,12 +166,6 @@ def forward(
self.k_rotary_output_scale.item(),
)
- # NOTE might want to revise
- # need some way to record the length of past key values cache
- # since we won't return past_key_value_cache right now
- if infer_state.decode_layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_len # seq_len
-
def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
@@ -229,7 +240,7 @@ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index,
infer_state.block_loc,
infer_state.start_loc,
infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
)
attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
@@ -354,7 +365,6 @@ def pack(
def forward(
self,
hidden_states: torch.Tensor,
- rotary_emb: Tuple[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
@@ -384,7 +394,6 @@ def forward(
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
- rotary_emb=rotary_emb,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@@ -592,17 +601,13 @@ def llama_model_forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
infer_state = self.infer_state
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
- if past_key_values is not None:
- # NOT READY FOR PRIME TIME
- # dummy but work, revise it
- past_key_values_length = infer_state.cache_manager.past_key_values_length
- # past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
+ seq_length_with_past = seq_length + past_key_values_length
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@@ -623,9 +628,7 @@ def llama_model_forward(
infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
- print(
- f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
- )
+            print(f"   infer_state.max_len_in_batch: {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
@@ -662,15 +665,15 @@ def llama_model_forward(
raise NotImplementedError("not implement gradient_checkpointing and training options ")
if past_key_values_length == 0:
- position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
- position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
else:
- position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
- position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -685,7 +688,6 @@ def llama_model_forward(
layer_outputs = decoder_layer(
hidden_states,
- rotary_emb=(position_cos, position_sin),
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
@@ -713,6 +715,7 @@ def llama_model_forward(
infer_state.is_context_stage = False
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
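The hunk above derives the cached-token count from the inference stage instead of a mutable counter on the cache manager. A sketch of the bookkeeping (names mirror the patch, but the helper itself is illustrative):

    def past_kv_len(is_context_stage: bool, max_len_in_batch: int) -> int:
        # Prefill: nothing is cached yet. Decode: every token except the one
        # being generated this step is already in the KV cache.
        return 0 if is_context_stage else max_len_in_batch - 1

    assert past_kv_len(True, 7) == 0
    assert past_kv_len(False, 7) == 6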
diff --git a/colossalai/inference/quant/smoothquant/models/parallel_linear.py b/colossalai/inference/quant/smoothquant/models/parallel_linear.py
new file mode 100644
index 000000000000..962b687a1d05
--- /dev/null
+++ b/colossalai/inference/quant/smoothquant/models/parallel_linear.py
@@ -0,0 +1,264 @@
+from typing import List, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
+
+
+def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
+ qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
+ if smooth_linear.bias is not None:
+ bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)
+
+ smooth_split_out_features = para_linear.out_features // split_num
+
+ for i in range(split_num):
+ para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
+ tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
+ ]
+
+ if para_linear.bias is not None:
+ para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
+ :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
+ ]
+
+
+def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
+ qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)
+
+ smooth_split_in_features = para_linear.in_features // split_num
+
+ for i in range(split_num):
+ para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
+ :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
+ ]
+
+ if smooth_linear.bias is not None:
+ para_linear.bias.copy_(smooth_linear.bias)
+
+
+class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__(in_features, out_features, alpha, beta)
+ self.process_group = None
+ self.tp_size = 1
+ self.tp_rank = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ out_features = module.out_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+                f"The size of out_features: {out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+ )
+ linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
+ linear_1d.tp_size = tp_size
+ linear_1d.tp_rank = tp_rank
+ linear_1d.process_group = process_group
+ linear_1d.a = module.a.clone().detach()
+ linear_1d.b = module.b.clone().detach()
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+
+class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__(in_features, out_features, alpha, beta)
+ self.process_group = None
+ self.tp_size = 1
+ self.tp_rank = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features: {in_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+ )
+ linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
+ linear_1d.tp_size = tp_size
+ linear_1d.tp_rank = tp_rank
+ linear_1d.process_group = process_group
+        linear_1d.a = module.a.clone().detach()
+        linear_1d.b = module.b.clone().detach()
+
+ split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ if linear_1d.bias is not None:
+ linear_1d.bias = linear_1d.bias // tp_size
+
+ return linear_1d
+
+ @torch.no_grad()
+ def forward(self, x):
+ output = super().forward(x)
+ if self.tp_size > 1:
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+ return output
+
+
+class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__(in_features, out_features, alpha, beta)
+ self.process_group = None
+ self.tp_size = 1
+ self.tp_rank = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ out_features = module.out_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+                f"The size of out_features: {out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+ )
+ linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
+ linear_1d.tp_size = tp_size
+ linear_1d.tp_rank = tp_rank
+ linear_1d.process_group = process_group
+ linear_1d.a = module.a.clone().detach()
+
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+
+class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__(in_features, out_features, alpha, beta)
+ self.process_group = None
+ self.tp_size = 1
+ self.tp_rank = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ out_features = module.out_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if out_features < tp_size:
+ return module
+
+ if out_features % tp_size != 0:
+ raise ValueError(
+                f"The size of out_features: {out_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+ )
+ linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
+ linear_1d.tp_size = tp_size
+ linear_1d.tp_rank = tp_rank
+ linear_1d.process_group = process_group
+ linear_1d.a = module.a.clone().detach()
+
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+
+class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__(in_features, out_features, alpha, beta)
+ self.process_group = None
+ self.tp_size = 1
+ self.tp_rank = 0
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+ raise ValueError(
+                f"The size of in_features: {in_features} is not an integer multiple of tensor parallel size: {tp_size}!"
+ )
+ linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
+ linear_1d.tp_size = tp_size
+ linear_1d.tp_rank = tp_rank
+ linear_1d.process_group = process_group
+ linear_1d.a = module.a.clone().detach()
+
+ split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ if linear_1d.bias is not None:
+ linear_1d.bias = linear_1d.bias / tp_size
+
+ return linear_1d
+
+ @torch.no_grad()
+ def forward(self, x):
+ output = super().forward(x)
+ if self.tp_size > 1:
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+ return output
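The row-split classes shard the output dimension, so each rank holds complete partial columns and no communication is needed until results are concatenated; the column-split classes shard the input dimension, so each rank produces a partial sum that must be all-reduced, which is why only `ColW8A8B8O8Linear` and `ColW8A8BFP32OFP32Linear` override `forward`. A toy shape check of both splits (plain fp32 PyTorch, no quantization):

    import torch

    x = torch.randn(2, 8)
    w = torch.randn(6, 8)  # full weight: out_features=6, in_features=8
    full = x @ w.t()

    # Row split (output dim): ranks hold disjoint output rows; concat recovers full.
    assert torch.allclose(torch.cat([x @ w[:3].t(), x @ w[3:].t()], dim=-1), full, atol=1e-5)

    # Column split (input dim): ranks hold disjoint input columns; partials sum.
    partial = x[:, :4] @ w[:, :4].t() + x[:, 4:] @ w[:, 4:].t()
    assert torch.allclose(partial, full, atol=1e-5)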
diff --git a/colossalai/inference/tensor_parallel/batch_infer_state.py b/colossalai/inference/tensor_parallel/batch_infer_state.py
deleted file mode 100644
index de150311cc08..000000000000
--- a/colossalai/inference/tensor_parallel/batch_infer_state.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# might want to consider combine with InferenceConfig in colossalai/ppinference/inference_config.py later
-from dataclasses import dataclass
-
-import torch
-
-from .kvcache_manager import MemoryManager
-
-# adapted from: lightllm/server/router/model_infer/infer_batch.py
-@dataclass
-class BatchInferState:
- r"""
- Information to be passed and used for a batch of inputs during
- a single model forward
- """
- batch_size: int
- max_len_in_batch: int
-
- cache_manager: MemoryManager = None
-
- block_loc: torch.Tensor = None
- start_loc: torch.Tensor = None
- seq_len: torch.Tensor = None
- past_key_values_len: int = None
-
- is_context_stage: bool = False
- context_mem_index: torch.Tensor = None
- decode_is_contiguous: bool = None
- decode_mem_start: int = None
- decode_mem_end: int = None
- decode_mem_index: torch.Tensor = None
- decode_layer_id: int = None
-
- device: torch.device = torch.device("cuda")
-
- @property
- def total_token_num(self):
- # return self.batch_size * self.max_len_in_batch
- assert self.seq_len is not None and self.seq_len.size(0) > 0
- return int(torch.sum(self.seq_len))
-
- def set_cache_manager(self, manager: MemoryManager):
- self.cache_manager = manager
-
- # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
- @staticmethod
- def init_block_loc(
- b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
- ):
- """in-place update block loc mapping based on the sequence length of the inputs in current bath"""
- start_index = 0
- seq_len_numpy = seq_len.cpu().numpy()
- for i, cur_seq_len in enumerate(seq_len_numpy):
- b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
- start_index : start_index + cur_seq_len
- ]
- start_index += cur_seq_len
- return
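This file moves to `colossalai/inference/kv_cache/` (see the updated imports above) rather than disappearing. For reference, what its `init_block_loc` computes: each sequence's allocated cache slots are packed right-aligned into its row (standalone sketch):

    import torch

    b_loc = torch.zeros(2, 5, dtype=torch.long)  # (batch, max_len_in_batch)
    seq_len = torch.tensor([2, 3])
    alloc = torch.arange(100, 105)               # 5 allocated KV-cache slots
    start = 0
    for i, n in enumerate(seq_len.tolist()):
        b_loc[i, 5 - n:] = alloc[start:start + n]
        start += n
    # row 0 -> [0, 0, 0, 100, 101]; row 1 -> [0, 0, 102, 103, 104]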
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index aac57d34a2c1..25076b742c26 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -11,7 +11,7 @@
from colossalai.context import Config
from colossalai.logging import get_dist_logger
-from colossalai.utils import set_device, set_seed
+from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
def launch(
@@ -47,12 +47,15 @@ def launch(
if rank == 0:
warnings.warn("`config` is deprecated and will be removed soon.")
+ if IS_NPU_AVAILABLE and backend == "nccl":
+ backend = "hccl"
+
# init default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# set cuda device
- if torch.cuda.is_available():
+ if torch.cuda.is_available() or IS_NPU_AVAILABLE:
# if local rank is not given, calculate automatically
set_device(local_rank)
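With this change a single `launch()` call works on both GPU and NPU clusters; callers can keep passing the default backend. The fallback rule in isolation (illustrative helper, not part of the patch):

    def pick_backend(requested: str, npu_available: bool) -> str:
        # Ascend NPUs speak HCCL; transparently remap the NCCL default.
        return "hccl" if npu_available and requested == "nccl" else requested

    assert pick_backend("nccl", npu_available=True) == "hccl"
    assert pick_backend("gloo", npu_available=True) == "gloo"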
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam.h b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
index bf9b85997c78..db1f26d5f6da 100644
--- a/colossalai/kernel/cuda_native/csrc/cpu_adam.h
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam.h
@@ -142,6 +142,7 @@ class Adam_Optimizer {
}
}
+#if defined(__AVX512__) or defined(__AVX256__) or defined(__AVX2__)
inline void simd_load(bool is_half, float *ptr, __half *h_ptr,
AVX_Data &data) {
if (is_half) {
@@ -159,6 +160,7 @@ class Adam_Optimizer {
SIMD_STORE(ptr, data.data);
}
}
+#endif
void step(size_t step, float lr, float beta1, float beta2, float epsilon,
float weight_decay, bool bias_correction, torch::Tensor ¶ms,
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
new file mode 100644
index 000000000000..a715a2711576
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.cpp
@@ -0,0 +1,304 @@
+#include "cpu_adam_arm.h"
+
+void AdamOptimizer::Step_1(void *_params, void *grads, void *_exp_avg,
+ void *_exp_avg_sq, size_t _param_size,
+ at::ScalarType param_dtype,
+ at::ScalarType grad_dtype,
+ at::ScalarType exp_avg_dtype,
+ at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+ size_t rounded_size = 0;
+#if defined(__aarch64__)
+ rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
+#endif
+
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+ float step_size = -1 * _alpha / _bias_correction1;
+ float w_decay = -1 * _alpha * _weight_decay;
+
+#if defined(__aarch64__)
+ float32x4_t betta1_4 = simd_set(_betta1);
+ float32x4_t betta2_4 = simd_set(_betta2);
+ float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+ float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+ float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+ float32x4_t eps_4 = simd_set(_eps);
+ float32x4_t step_size_4 = simd_set(step_size);
+ float32x4_t weight_decay_4;
+ if (_weight_decay > 0) {
+ weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+ }
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += SIMD_WIDTH) {
+ float32x4_t grad_4 = simd_load_offset(grads, grad_dtype, i);
+ if (loss_scale > 0) {
+ float32x4_t loss_scale_vec = simd_set(loss_scale);
+ grad_4 = vdivq_f32(grad_4, loss_scale_vec);
+ }
+ float32x4_t momentum_4 = simd_load_offset(_exp_avg, exp_avg_dtype, i);
+ float32x4_t variance_4 =
+ simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i);
+ float32x4_t param_4 = simd_load_offset(_params, param_dtype, i);
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4 = vfmaq_f32(grad_4, param_4, weight_decay_4);
+ }
+ momentum_4 = vmulq_f32(momentum_4, betta1_4);
+ momentum_4 = vfmaq_f32(momentum_4, grad_4, betta1_minus1_4);
+ variance_4 = vmulq_f32(variance_4, betta2_4);
+ grad_4 = vmulq_f32(grad_4, grad_4);
+ variance_4 = vfmaq_f32(variance_4, grad_4, betta2_minus1_4);
+ grad_4 = vsqrtq_f32(variance_4);
+ grad_4 = vfmaq_f32(eps_4, grad_4, bias2_sqrt);
+ grad_4 = vdivq_f32(momentum_4, grad_4);
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4 = vfmaq_f32(param_4, param_4, weight_decay_4);
+ }
+ param_4 = vfmaq_f32(param_4, grad_4, step_size_4);
+ simd_store_offset(_params, param_dtype, param_4, i);
+ simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4, i);
+ simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4, i);
+ }
+ }
+#endif
+ if (_param_size > rounded_size) {
+ for (size_t t = rounded_size; t < _param_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > _param_size) copy_size = _param_size - t;
+ size_t offset = copy_size + t;
+
+#pragma omp parallel for
+ for (size_t k = t; k < offset; k++) {
+ float grad = scalar_load_offset(grads, grad_dtype, k);
+ if (loss_scale > 0) {
+ grad /= loss_scale;
+ }
+ float param = scalar_load_offset(_params, param_dtype, k);
+ float momentum = scalar_load_offset(_exp_avg, exp_avg_dtype, k);
+ float variance = scalar_load_offset(_exp_avg_sq, exp_avg_sq_dtype, k);
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad = param * _weight_decay + grad;
+ }
+ momentum = momentum * _betta1;
+ momentum = grad * betta1_minus1 + momentum;
+
+ variance = variance * _betta2;
+ grad = grad * grad;
+ variance = grad * betta2_minus1 + variance;
+
+ grad = sqrt(variance);
+ grad = grad * _bias_correction2 + _eps;
+ grad = momentum / grad;
+ if (_weight_decay > 0 && _adamw_mode) {
+ param += w_decay * param;
+ }
+ param = grad * step_size + param;
+
+ scalar_store_offset(_params, param_dtype, param, k);
+ scalar_store_offset(_exp_avg, exp_avg_dtype, momentum, k);
+ scalar_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance, k);
+ }
+ }
+ }
+}
+
+void AdamOptimizer::Step_4(void *_params, void *grads, void *_exp_avg,
+ void *_exp_avg_sq, size_t _param_size,
+ at::ScalarType param_dtype,
+ at::ScalarType grad_dtype,
+ at::ScalarType exp_avg_dtype,
+ at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+ size_t rounded_size = 0;
+#if defined(__aarch64__)
+ rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 4);
+#endif
+
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+ float step_size = -1 * _alpha / _bias_correction1;
+ float w_decay = -1 * _alpha * _weight_decay;
+
+#if defined(__aarch64__)
+ float32x4_t betta1_4 = simd_set(_betta1);
+ float32x4_t betta2_4 = simd_set(_betta2);
+ float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+ float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+ float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+ float32x4_t eps_4 = simd_set(_eps);
+ float32x4_t step_size_4 = simd_set(step_size);
+ float32x4_t weight_decay_4;
+ if (_weight_decay > 0) {
+ weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+ }
+
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += SIMD_WIDTH * 4) {
+ float32x4_t grad_4[4];
+ float32x4_t momentum_4[4];
+ float32x4_t variance_4[4];
+ float32x4_t param_4[4];
+#pragma unroll 4
+ for (int j = 0; j < 4; j++) {
+ grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
+ if (loss_scale > 0) {
+ float32x4_t loss_scale_vec = simd_set(loss_scale);
+ grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
+ }
+ momentum_4[j] =
+ simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
+ variance_4[j] =
+ simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
+ param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
+ }
+ momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
+ momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
+ variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
+ grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
+ variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
+ grad_4[j] = vsqrtq_f32(variance_4[j]);
+ grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
+ grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
+ }
+ param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
+ simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
+ simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
+ i + SIMD_WIDTH * j);
+ simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
+ i + SIMD_WIDTH * j);
+ }
+ }
+ }
+#endif
+ if (_param_size > rounded_size) {
+ Step_1(scalar_seek_offset(_params, param_dtype, rounded_size),
+ scalar_seek_offset(grads, grad_dtype, rounded_size),
+ scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
+ scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
+ (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
+ exp_avg_sq_dtype, loss_scale);
+ }
+}
+
+void AdamOptimizer::Step_8(void *_params, void *grads, void *_exp_avg,
+ void *_exp_avg_sq, size_t _param_size,
+ at::ScalarType param_dtype,
+ at::ScalarType grad_dtype,
+ at::ScalarType exp_avg_dtype,
+ at::ScalarType exp_avg_sq_dtype, float loss_scale) {
+ size_t rounded_size = 0;
+#if defined(__aarch64__)
+ rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * 8);
+#endif
+
+ float betta1_minus1 = 1 - _betta1;
+ float betta2_minus1 = 1 - _betta2;
+ float step_size = -1 * _alpha / _bias_correction1;
+ float w_decay = -1 * _alpha * _weight_decay;
+#if defined(__aarch64__)
+ float32x4_t betta1_4 = simd_set(_betta1);
+ float32x4_t betta2_4 = simd_set(_betta2);
+ float32x4_t betta1_minus1_4 = simd_set(betta1_minus1);
+ float32x4_t betta2_minus1_4 = simd_set(betta2_minus1);
+ float32x4_t bias2_sqrt = simd_set(_bias_correction2);
+ float32x4_t eps_4 = simd_set(_eps);
+ float32x4_t step_size_4 = simd_set(step_size);
+ float32x4_t weight_decay_4;
+ if (_weight_decay > 0) {
+ weight_decay_4 = _adamw_mode ? simd_set(w_decay) : simd_set(_weight_decay);
+ }
+
+ for (size_t t = 0; t < rounded_size; t += TILE) {
+ size_t copy_size = TILE;
+ if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+ size_t offset = copy_size + t;
+
+#pragma omp parallel for
+ for (size_t i = t; i < offset; i += SIMD_WIDTH * 8) {
+ float32x4_t grad_4[8];
+ float32x4_t momentum_4[8];
+ float32x4_t variance_4[8];
+ float32x4_t param_4[8];
+#pragma unroll 4
+ for (int j = 0; j < 8; j++) {
+ grad_4[j] = simd_load_offset(grads, grad_dtype, i + SIMD_WIDTH * j);
+ if (loss_scale > 0) {
+ float32x4_t loss_scale_vec = simd_set(loss_scale);
+ grad_4[j] = vdivq_f32(grad_4[j], loss_scale_vec);
+ }
+ momentum_4[j] =
+ simd_load_offset(_exp_avg, exp_avg_dtype, i + SIMD_WIDTH * j);
+ variance_4[j] =
+ simd_load_offset(_exp_avg_sq, exp_avg_sq_dtype, i + SIMD_WIDTH * j);
+ param_4[j] = simd_load_offset(_params, param_dtype, i + SIMD_WIDTH * j);
+ if (_weight_decay > 0 && !_adamw_mode) {
+ grad_4[j] = vfmaq_f32(grad_4[j], param_4[j], weight_decay_4);
+ }
+ momentum_4[j] = vmulq_f32(momentum_4[j], betta1_4);
+ momentum_4[j] = vfmaq_f32(momentum_4[j], grad_4[j], betta1_minus1_4);
+ variance_4[j] = vmulq_f32(variance_4[j], betta2_4);
+ grad_4[j] = vmulq_f32(grad_4[j], grad_4[j]);
+ variance_4[j] = vfmaq_f32(variance_4[j], grad_4[j], betta2_minus1_4);
+ grad_4[j] = vsqrtq_f32(variance_4[j]);
+ grad_4[j] = vfmaq_f32(eps_4, grad_4[j], bias2_sqrt);
+ grad_4[j] = vdivq_f32(momentum_4[j], grad_4[j]);
+ if (_weight_decay > 0 && _adamw_mode) {
+ param_4[j] = vfmaq_f32(param_4[j], param_4[j], weight_decay_4);
+ }
+ param_4[j] = vfmaq_f32(param_4[j], grad_4[j], step_size_4);
+ simd_store_offset(_params, param_dtype, param_4[j], i + SIMD_WIDTH * j);
+ simd_store_offset(_exp_avg, exp_avg_dtype, momentum_4[j],
+ i + SIMD_WIDTH * j);
+ simd_store_offset(_exp_avg_sq, exp_avg_sq_dtype, variance_4[j],
+ i + SIMD_WIDTH * j);
+ }
+ }
+ }
+#endif
+ if (_param_size > rounded_size) {
+ Step_4(scalar_seek_offset(_params, param_dtype, rounded_size),
+ scalar_seek_offset(grads, grad_dtype, rounded_size),
+ scalar_seek_offset(_exp_avg, exp_avg_dtype, rounded_size),
+ scalar_seek_offset(_exp_avg_sq, exp_avg_sq_dtype, rounded_size),
+ (_param_size - rounded_size), param_dtype, grad_dtype, exp_avg_dtype,
+ exp_avg_sq_dtype, loss_scale);
+ }
+}
+
+void AdamOptimizer::step(size_t step, float lr, float beta1, float beta2,
+ float epsilon, float weight_decay,
+ bool bias_correction, torch::Tensor ¶ms,
+ torch::Tensor &grads, torch::Tensor &exp_avg,
+ torch::Tensor &exp_avg_sq, float loss_scale) {
+ auto params_c = params.contiguous();
+ auto grads_c = grads.contiguous();
+ auto exp_avg_c = exp_avg.contiguous();
+ auto exp_avg_sq_c = exp_avg_sq.contiguous();
+
+ this->IncrementStep(step, beta1, beta2);
+ this->update_state(lr, epsilon, weight_decay, bias_correction);
+ this->Step_8(params_c.data_ptr(), grads_c.data_ptr(), exp_avg_c.data_ptr(),
+ exp_avg_sq_c.data_ptr(), params_c.numel(),
+ params_c.scalar_type(), grads_c.scalar_type(),
+ exp_avg_c.scalar_type(), exp_avg_sq_c.scalar_type(), loss_scale);
+}
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  py::class_<AdamOptimizer>(m, "CPUAdamOptimizer")
+      .def(py::init<float, float, float, float, float, bool>())
+ .def("step", &AdamOptimizer::step);
+}
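The scalar tail loop in `Step_1` is the readable reference for what the NEON lanes compute. The same update for one element in Python, using the patch's factored constants (illustrative sketch of classic Adam, AdamW mode off):

    import math

    def adam_scalar(param, grad, m, v, beta1, beta2, eps, lr, step):
        bias_corr1 = 1 - beta1 ** step
        bias_corr2 = 1 / math.sqrt(1 - beta2 ** step)
        step_size = -lr / bias_corr1
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad * grad
        denom = math.sqrt(v) * bias_corr2 + eps
        return param + step_size * (m / denom), m, v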
diff --git a/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
new file mode 100644
index 000000000000..c731850edc31
--- /dev/null
+++ b/colossalai/kernel/cuda_native/csrc/cpu_adam_arm.h
@@ -0,0 +1,201 @@
+#pragma once
+#include <ATen/ATen.h>
+#include <torch/extension.h>
+
+#include <cmath>
+
+#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
+#define TILE (128 * 1024 * 1024)
+
+#if defined(__aarch64__)
+#include <arm_neon.h>
+#define SIMD_WIDTH 4
+
+inline float32x4_t simd_load_offset(const void *ptr, at::ScalarType dtype,
+                                    size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float: {
+      auto ptr_f = reinterpret_cast<const float32_t *>(ptr);
+      return vld1q_f32(ptr_f + offset);
+    }
+    case at::ScalarType::Half: {
+      auto ptr_h = reinterpret_cast<const float16_t *>(ptr);
+      return vcvt_f32_f16(vld1_f16(ptr_h + offset));
+    }
+    // case at::ScalarType::BFloat16: {
+    //   auto ptr_b = reinterpret_cast<const bfloat16_t *>(ptr);
+    //   return vcvt_f32_bf16(vld1_bf16(ptr_b + offset));
+    // }
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+inline float32x4_t simd_load(void const *ptr, at::ScalarType dtype) {
+ return simd_load_offset(ptr, dtype, 0);
+}
+
+inline void simd_store_offset(void *ptr, at::ScalarType dtype, float32x4_t data,
+                              size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float: {
+      auto ptr_f = reinterpret_cast<float32_t *>(ptr);
+      vst1q_f32(ptr_f + offset, data);
+      break;
+    }
+    case at::ScalarType::Half: {
+      auto ptr_h = reinterpret_cast<float16_t *>(ptr);
+      vst1_f16(ptr_h + offset, vcvt_f16_f32(data));
+      break;
+    }
+    // case at::ScalarType::BFloat16: {
+    //   auto ptr_b = reinterpret_cast<bfloat16_t *>(ptr);
+    //   vst1_bf16(ptr_b + offset, vcvt_bf16_f32(data));
+    //   break;
+    // }
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void simd_store(void *ptr, at::ScalarType dtype, float32x4_t data) {
+ return simd_store_offset(ptr, dtype, data, 0);
+}
+
+inline float32x4_t simd_set(float value) {
+  auto val = static_cast<float32_t>(value);
+  return vdupq_n_f32(val);
+}
+
+#endif
+
+inline float scalar_load_offset(const void *ptr, at::ScalarType dtype,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return *(reinterpret_cast<const float *>(ptr) + offset);
+    case at::ScalarType::Half:
+      return static_cast<float>(
+          *(reinterpret_cast<const at::Half *>(ptr) + offset));
+    // case at::ScalarType::BFloat16:
+    //   return static_cast<float>(
+    //       *(reinterpret_cast<const at::BFloat16 *>(ptr) + offset));
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void scalar_store_offset(void *ptr, at::ScalarType dtype, float data,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      *(reinterpret_cast<float *>(ptr) + offset) = data;
+      break;
+    case at::ScalarType::Half:
+      *(reinterpret_cast<at::Half *>(ptr) + offset) = data;
+      break;
+    // case at::ScalarType::BFloat16:
+    //   *(reinterpret_cast<at::BFloat16 *>(ptr) + offset) = data;
+    //   break;
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+
+inline void *scalar_seek_offset(void *ptr, at::ScalarType dtype,
+                                size_t offset) {
+  switch (dtype) {
+    case at::ScalarType::Float:
+      return reinterpret_cast<float *>(ptr) + offset;
+    case at::ScalarType::Half:
+      return reinterpret_cast<at::Half *>(ptr) + offset;
+    // case at::ScalarType::BFloat16:
+    //   return reinterpret_cast<at::BFloat16 *>(ptr) + offset;
+    default:
+      AT_ERROR("Unsupported dtype");
+      break;
+  }
+}
+#define STEP(SPAN) \
+ void Step_##SPAN(void *_params, void *grads, void *_exp_avg, \
+ void *_exp_avg_sq, size_t _param_size, \
+ at::ScalarType param_dtype, at::ScalarType grad_dtype, \
+ at::ScalarType exp_avg_dtype, \
+ at::ScalarType exp_avg_sq_dtype, float loss_scale = -1);
+
+class AdamOptimizer {
+ private:
+ float _alpha;
+ float _betta1;
+ float _betta2;
+ float _eps;
+ float _weight_decay;
+
+ float _betta1_t;
+ float _betta2_t;
+ size_t _step;
+
+ float _bias_correction1;
+ float _bias_correction2;
+
+ bool _adamw_mode;
+
+ public:
+ AdamOptimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
+ float eps = 1e-8, float weight_decay = 0,
+ bool adamw_mode = true)
+ : _alpha(alpha),
+ _betta1(betta1),
+ _betta2(betta2),
+ _eps(eps),
+ _weight_decay(weight_decay),
+ _betta1_t(1.0),
+ _betta2_t(1.0),
+ _step(0),
+ _adamw_mode(adamw_mode) {}
+ ~AdamOptimizer() {}
+
+ STEP(1)
+ STEP(4)
+ STEP(8)
+ inline void IncrementStep(size_t step, float beta1, float beta2) {
+ if (beta1 != _betta1 || beta2 != _betta2) {
+ _step = step;
+ _betta1 = beta1;
+ _betta2 = beta2;
+ _betta1_t = std::pow(_betta1, step);
+ _betta2_t = std::pow(_betta2, step);
+ } else {
+ _step++;
+ if (_step != step) {
+ _betta1_t = std::pow(_betta1, step);
+ _betta2_t = std::pow(_betta2, step);
+ _step = step;
+ } else {
+ _betta1_t *= _betta1;
+ _betta2_t *= _betta2;
+ }
+ }
+ }
+ inline void update_state(float lr, float epsilon, float weight_decay,
+ bool bias_correction) {
+ _alpha = lr;
+ _eps = epsilon;
+ _weight_decay = weight_decay;
+
+ _bias_correction1 = 1.0f;
+ _bias_correction2 = 1.0f;
+ if (bias_correction == 1) {
+ _bias_correction1 = 1 - _betta1_t;
+ _bias_correction2 = 1 / sqrt(1 - _betta2_t);
+ }
+ }
+
+ void step(size_t step, float lr, float beta1, float beta2, float epsilon,
+ float weight_decay, bool bias_correction, torch::Tensor ¶ms,
+ torch::Tensor &grads, torch::Tensor &exp_avg,
+ torch::Tensor &exp_avg_sq, float loss_scale);
+};
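`IncrementStep` avoids calling `pow()` on every optimizer step: beta^t is carried incrementally and only recomputed when the step counter or the betas jump. The same idea in isolation (illustrative Python):

    class BetaPow:
        def __init__(self, beta):
            self.beta, self.beta_t, self.step = beta, 1.0, 0

        def advance(self, step):
            if step == self.step + 1:
                self.beta_t *= self.beta         # fast path: one multiply
            else:
                self.beta_t = self.beta ** step  # resync after a jump
            self.step = step
            return self.beta_t

    bp = BetaPow(0.9)
    assert abs(bp.advance(1) - 0.9) < 1e-12
    assert abs(bp.advance(2) - 0.81) < 1e-12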
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
index 9a6a8ebc3983..e5ac17308640 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cuda_util.cu
@@ -1,6 +1,6 @@
#include
#include
-
+#include
#include "cuda_util.h"
/* GPU function guard */
diff --git a/colossalai/kernel/cuda_native/mha/utils.py b/colossalai/kernel/cuda_native/mha/utils.py
index fe31921b961b..5f01e3ef327d 100644
--- a/colossalai/kernel/cuda_native/mha/utils.py
+++ b/colossalai/kernel/cuda_native/mha/utils.py
@@ -5,7 +5,7 @@
import torch.nn.functional as F
from einops import rearrange
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
class Unpad(torch.autograd.Function):
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index 1fe292289f3d..20da71d394bd 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -2,7 +2,6 @@
import triton
HAS_TRITON = True
-
except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")
diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py
index 1b4f6e44b0f2..3d9a23d2f5d2 100644
--- a/colossalai/kernel/triton/context_attention.py
+++ b/colossalai/kernel/triton/context_attention.py
@@ -5,7 +5,6 @@
try:
import triton
import triton.language as tl
-
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
@@ -16,127 +15,224 @@
this function is modified from
https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10
"""
+ if triton.__version__ < "2.1.0":
+ @triton.jit
+ def _context_flash_attention_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ B_Start_Loc,
+ B_Seqlen,
+ TMP,
+ alibi_ptr,
+ Out,
+ stride_qbs,
+ stride_qh,
+ stride_qd,
+ stride_kbs,
+ stride_kh,
+ stride_kd,
+ stride_vbs,
+ stride_vh,
+ stride_vd,
+ stride_obs,
+ stride_oh,
+ stride_od,
+ stride_tmp_b,
+ stride_tmp_h,
+ stride_tmp_s,
+        # suggested setup: 64, 128, 256, 512
+ BLOCK_M: tl.constexpr,
+ BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ ):
+ batch_id = tl.program_id(0)
+ cur_head = tl.program_id(1)
+ start_m = tl.program_id(2)
+
+ # initialize offsets
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
- @triton.jit
- def _context_flash_attention_kernel(
- Q,
- K,
- V,
- sm_scale,
- B_Start_Loc,
- B_Seqlen,
- TMP,
- alibi_ptr,
- Out,
- stride_qbs,
- stride_qh,
- stride_qd,
- stride_kbs,
- stride_kh,
- stride_kd,
- stride_vbs,
- stride_vh,
- stride_vd,
- stride_obs,
- stride_oh,
- stride_od,
- stride_tmp_b,
- stride_tmp_h,
- stride_tmp_s,
- # suggtest set-up 64, 128, 256, 512
- BLOCK_M: tl.constexpr,
- BLOCK_DMODEL: tl.constexpr,
- BLOCK_N: tl.constexpr,
- ):
- batch_id = tl.program_id(0)
- cur_head = tl.program_id(1)
- start_m = tl.program_id(2)
-
- # initialize offsets
- offs_n = tl.arange(0, BLOCK_N)
- offs_d = tl.arange(0, BLOCK_DMODEL)
- offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
-
- # get batch info
- cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
- cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
- block_start_loc = BLOCK_M * start_m
-
- load_p_ptrs = (
- Q
- + (cur_batch_start_index + offs_m[:, None]) * stride_qbs
- + cur_head * stride_qh
- + offs_d[None, :] * stride_qd
- )
- q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
-
- k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
- v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
- t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
-
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
- l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
- acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
-
- if alibi_ptr is not None:
- alibi_m = tl.load(alibi_ptr + cur_head)
-
- block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
-
- for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
- start_n = tl.multiple_of(start_n, BLOCK_N)
- k = tl.load(
- k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
- mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
- other=0.0,
+ # get batch info
+ cur_batch_seq_len = tl.load(B_Seqlen + batch_id)
+ cur_batch_start_index = tl.load(B_Start_Loc + batch_id)
+ block_start_loc = BLOCK_M * start_m
+
+ load_p_ptrs = (
+ Q
+ + (cur_batch_start_index + offs_m[:, None]) * stride_qbs
+ + cur_head * stride_qh
+ + offs_d[None, :] * stride_qd
)
+ q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+ k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
+ v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
+ t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
- qk += tl.dot(q, k)
- qk *= sm_scale
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
if alibi_ptr is not None:
- alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
- qk -= alibi_loc * alibi_m
-
- qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
-
- m_ij = tl.max(qk, 1)
- p = tl.exp(qk - m_ij[:, None])
- l_ij = tl.sum(p, 1)
- # -- update m_i and l_i
- m_i_new = tl.maximum(m_i, m_ij)
- alpha = tl.exp(m_i - m_i_new)
- beta = tl.exp(m_ij - m_i_new)
- l_i_new = alpha * l_i + beta * l_ij
- # -- update output accumulator --
- # scale p
- p_scale = beta / l_i_new
- p = p * p_scale[:, None]
- # scale acc
- acc_scale = l_i / l_i_new * alpha
- tl.store(t_ptrs, acc_scale)
- acc_scale = tl.load(t_ptrs)
- acc = acc * acc_scale[:, None]
- # update acc
- v = tl.load(
- v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
- mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
- other=0.0,
+ alibi_m = tl.load(alibi_ptr + cur_head)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+ for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ k = tl.load(
+ k_ptrs + (cur_batch_start_index + start_n) * stride_kbs,
+ mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+ other=0.0,
+ )
+
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+ qk *= sm_scale
+
+ if alibi_ptr is not None:
+ alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
+ qk -= alibi_loc * alibi_m
+
+ qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ tl.store(t_ptrs, acc_scale)
+ acc_scale = tl.load(t_ptrs)
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(
+ v_ptrs + (cur_batch_start_index + start_n) * stride_vbs,
+ mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+ other=0.0,
+ )
+
+ p = p.to(v.dtype)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+
+ off_o = (
+ (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
)
+ out_ptrs = Out + off_o
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+ return
+ else:
+ # this function is modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L11
+ @triton.jit
+ def _context_flash_attention_kernel_2(
+ Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen,
+ Out,
+ kv_group_num,
+ stride_qbs, stride_qh, stride_qd,
+ stride_kbs, stride_kh, stride_kd,
+ stride_vbs, stride_vh, stride_vd,
+ stride_obs, stride_oh, stride_od,
+ BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ ):
+ cur_batch = tl.program_id(0)
+ cur_head = tl.program_id(1)
+ start_m = tl.program_id(2)
+
+ if kv_group_num is not None:
+ cur_kv_head = cur_head // kv_group_num
- p = p.to(v.dtype)
- acc += tl.dot(p, v)
- # update m_i and l_i
- l_i = l_i_new
- m_i = m_i_new
-
- off_o = (
- (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
- )
- out_ptrs = Out + off_o
- tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
- return
+ cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+ block_start_loc = BLOCK_M * start_m
+
+ # initialize offsets
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_DMODEL)
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd
+ if kv_group_num is None or kv_group_num == 1:
+ off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd
+ off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd
+ else:
+ off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd
+ off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd
+
+ q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
+
+ k_ptrs = K + off_k
+ v_ptrs = V + off_v
+
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+ if Alibi is not None:
+ alibi_m = tl.load(Alibi + cur_head)
+
+ block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
+
+ for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+ # -- compute qk ----
+ k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
+ mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)
+
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+ qk *= sm_scale
+
+ if Alibi is not None:
+ alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :])
+ qk -= alibi_loc * alibi_m
+
+ qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+ m_ij = tl.max(qk, 1)
+ p = tl.exp(qk - m_ij[:, None])
+ l_ij = tl.sum(p, 1)
+ # -- update m_i and l_i
+ m_i_new = tl.maximum(m_i, m_ij)
+ alpha = tl.exp(m_i - m_i_new)
+ beta = tl.exp(m_ij - m_i_new)
+ l_i_new = alpha * l_i + beta * l_ij
+ # -- update output accumulator --
+ # scale p
+ p_scale = beta / l_i_new
+ p = p * p_scale[:, None]
+ # scale acc
+ acc_scale = l_i / l_i_new * alpha
+ acc = acc * acc_scale[:, None]
+ # update acc
+ v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
+ mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
+
+ p = p.to(v.dtype)
+ acc += tl.dot(p, v)
+ # update m_i and l_i
+ l_i = l_i_new
+ m_i = m_i_new
+ # initialize pointers to output
+ off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od
+ out_ptrs = Out + off_o
+ tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
+ return
@torch.no_grad()
def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None):
@@ -153,41 +249,65 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al
grid = (batch, head, triton.cdiv(max_input_len, BLOCK))
num_warps = 4 if Lk <= 64 else 8
-
- tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
-
- _context_flash_attention_kernel[grid](
- q,
- k,
- v,
- sm_scale,
- b_start_loc,
- b_seq_len,
- tmp,
- alibi,
- o,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- o.stride(0),
- o.stride(1),
- o.stride(2),
- tmp.stride(0),
- tmp.stride(1),
- tmp.stride(2),
- # manually setting this blcok num, we can use tuning config to futher speed-up
- BLOCK_M=BLOCK,
- BLOCK_DMODEL=Lk,
- BLOCK_N=BLOCK,
- num_warps=num_warps,
- num_stages=1,
- )
+
+ if triton.__version__ < "2.1.0":
+ tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
+ _context_flash_attention_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ b_start_loc,
+ b_seq_len,
+ tmp,
+ alibi,
+ o,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ tmp.stride(0),
+ tmp.stride(1),
+ tmp.stride(2),
+            # manually setting this block num; a tuning config could speed this up further
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ else:
+ _context_flash_attention_kernel_2[grid](
+ q, k, v, sm_scale, alibi, b_start_loc, b_seq_len,
+ o,
+ None,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+
return
@torch.no_grad()
@@ -207,36 +327,67 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32)
num_warps = 4 if Lk <= 64 else 8
# num_warps = 4
- _context_flash_attention_kernel[grid](
- q,
- k,
- v,
- sm_scale,
- b_start_loc,
- b_seq_len,
- tmp,
- None,
- o,
- q.stride(0),
- q.stride(1),
- q.stride(2),
- k.stride(0),
- k.stride(1),
- k.stride(2),
- v.stride(0),
- v.stride(1),
- v.stride(2),
- o.stride(0),
- o.stride(1),
- o.stride(2),
- tmp.stride(0),
- tmp.stride(1),
- tmp.stride(2),
- BLOCK_M=BLOCK,
- BLOCK_DMODEL=Lk,
- BLOCK_N=BLOCK,
- num_warps=num_warps,
- num_stages=1,
- )
+ if triton.__version__ < "2.1.0":
+ _context_flash_attention_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ b_start_loc,
+ b_seq_len,
+ tmp,
+ None,
+ o,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ tmp.stride(0),
+ tmp.stride(1),
+ tmp.stride(2),
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+ num_stages=1,
+ )
+ else:
+ kv_group_num = q.shape[1] // k.shape[1]
+ _context_flash_attention_kernel_2[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ None,
+ b_start_loc,
+ b_seq_len,
+ o,
+ kv_group_num,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ BLOCK_M=BLOCK,
+ BLOCK_DMODEL=Lk,
+ BLOCK_N=BLOCK,
+ num_warps=num_warps,
+            num_stages=1,
+        )
+
return
\ No newline at end of file
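The new kernel for Triton >= 2.1.0 adds `kv_group_num` so grouped-query models, which have fewer KV heads than query heads, reuse one KV head per group of query heads. The mapping in plain Python:

    num_q_heads, num_kv_heads = 32, 8
    kv_group_num = num_q_heads // num_kv_heads  # 4 query heads share each KV head
    assert [h // kv_group_num for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]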
diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py
index 0ce6b09e54dc..b8e6ab1d05ad 100644
--- a/colossalai/kernel/triton/copy_kv_cache_dest.py
+++ b/colossalai/kernel/triton/copy_kv_cache_dest.py
@@ -10,7 +10,6 @@
print("please install triton from https://github.com/openai/triton")
if HAS_TRITON:
-
# adapted from https://github.com/ModelTC/lightllm/blob/5c559dd7981ed67679a08a1e09a88fb4c1550b3a/lightllm/common/triton_kernel/destindex_copy_kv.py
@triton.jit
def _fwd_copy_kv_cache_dest(
@@ -53,7 +52,6 @@ def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out):
assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out"
num_warps = 2
-
_fwd_copy_kv_cache_dest[(seq_len,)](
k_ptr,
dest_index_ptr,
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
new file mode 100644
index 000000000000..9b7b27fa1f49
--- /dev/null
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -0,0 +1,50 @@
+# adapted from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8/lightllm/models/llama/triton_kernel/flash_decoding.py
+import torch
+try:
+ from lightllm.models.llama.triton_kernel.flash_decoding_stage1 import flash_decode_stage1
+ from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+ HAS_LIGHTLLM_KERNEL = True
+except Exception:
+    print("lightllm is not installed. Please install it from https://github.com/ModelTC/lightllm/blob/ece7b43f8a6dfa74027adc77c2c176cff28c76c8")
+ HAS_LIGHTLLM_KERNEL = False
+
+
+if HAS_LIGHTLLM_KERNEL:
+ def token_flash_decoding(q, o_tensor, infer_state, q_head_num, head_dim, cache_k, cache_v):
+ BLOCK_SEQ = 256
+ batch_size = infer_state.batch_size
+ max_len_in_batch = infer_state.max_len_in_batch
+
+
+ calcu_shape1 = (batch_size, q_head_num, head_dim)
+
+ if getattr(infer_state, 'mid_o', None) is None:
+ infer_state.mid_o = torch.empty([batch_size,
+ q_head_num,
+ max_len_in_batch // BLOCK_SEQ + 1,
+ head_dim],
+ dtype=torch.float32,
+ device="cuda")
+ infer_state.mid_o_logexpsum = torch.empty([batch_size,
+ q_head_num,
+ max_len_in_batch // BLOCK_SEQ + 1],
+ dtype=torch.float32,
+ device="cuda")
+
+ mid_o = infer_state.mid_o
+ mid_o_logexpsum = infer_state.mid_o_logexpsum
+
+ flash_decode_stage1(q.view(calcu_shape1),
+ cache_k,
+ cache_v,
+ infer_state.block_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ mid_o,
+ mid_o_logexpsum,
+ BLOCK_SEQ)
+ flash_decode_stage2(mid_o,
+ mid_o_logexpsum,
+ infer_state.seq_len,
+ o_tensor.view(calcu_shape1),
+ BLOCK_SEQ)
diff --git a/colossalai/kernel/triton/llama_act_combine_kernel.py b/colossalai/kernel/triton/llama_act_combine_kernel.py
new file mode 100644
index 000000000000..45996c0dca53
--- /dev/null
+++ b/colossalai/kernel/triton/llama_act_combine_kernel.py
@@ -0,0 +1,185 @@
+from functools import reduce
+from typing import Any, Tuple
+
+import torch
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+
+try:
+ import triton
+ import triton.language as tl
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+if HAS_TRITON:
+ PRECISION_MAP = {
+ "fp32": (0, torch.float32),
+ "fp16": (1, torch.float16),
+ "bf16": (2, torch.bfloat16),
+ }
+
+ @triton.jit
+ def _llama_act_combine_forward(
+ X_GATE1,
+ X_GATE2,
+ X_UP,
+ Y,
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ X_GATE1 += row * stride
+ X_GATE2 += row * stride
+ X_UP += row * stride
+ Y += row * stride
+
+ # do activation and combine, and store in y
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+ x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+ x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+ x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+ y = x_gate1 * x_gate2 * x_gate2_sigmoid * x_up
+ # Write output
+ tl.store(Y + cols, y, mask=mask)
+
+ @triton.jit
+ def _llama_act_combine_backward(
+ X_GATE1,
+ X_GATE2,
+ X_UP,
+ X_GATE1_GRAD,
+ X_GATE2_GRAD,
+ X_UP_GRAD,
+ Y_GRAD,
+ stride, # how much to increase the pointer when moving by 1 row
+ N, # number of columns in X
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ # Map the program id to the row of X and Y it should compute.
+ row = tl.program_id(0)
+ X_GATE1 += row * stride
+ X_GATE2 += row * stride
+ X_UP += row * stride
+ X_GATE1_GRAD += row * stride
+ X_GATE2_GRAD += row * stride
+ X_UP_GRAD += row * stride
+ Y_GRAD += row * stride
+
+ # do activation and combine, and store in y
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ x_gate1 = tl.load(X_GATE1 + cols, mask=mask, other=0.)
+ x_gate2 = tl.load(X_GATE2 + cols, mask=mask, other=0.)
+ x_up = tl.load(X_UP + cols, mask=mask, other=0.)
+ y_grad = tl.load(Y_GRAD + cols, mask=mask, other=0.)
+
+ # forward: y = x_gate1 * x_gate2 * tl.sigmoid(x_gate2) * x_up
+ x_gate2_sigmoid = tl.sigmoid(x_gate2.to(tl.float32)).to(x_gate2.dtype)
+ x_gate2_act = y_grad * x_gate2 * x_gate2_sigmoid
+ x_up_grad = x_gate2_act * x_gate1
+ x_gate1_grad = x_gate2_act * x_up
+ # grad(x*sigmoid(x)) = sigmoid(x) + x * sigmoid(x) * [1 − sigmoid(x)]
+            # = sigmoid(x) * {1 + x * [1 − sigmoid(x)]}
+ x_gate2_grad = (y_grad * x_gate1 * x_up) * x_gate2_sigmoid * (1 + x_gate2 * (1 - x_gate2_sigmoid))
+
+ # Write output
+ tl.store(X_GATE1_GRAD + cols, x_gate1_grad, mask=mask)
+ tl.store(X_GATE2_GRAD + cols, x_gate2_grad, mask=mask)
+ tl.store(X_UP_GRAD + cols, x_up_grad, mask=mask)
+
+ class LlamaActCombine(torch.autograd.Function):
+ """
+ act(x_gate) * x_up
+
+ Args:
+ x_gate (torch.Tensor): (b, l, 2d) x_gate
+ x_up (torch.Tensor): (b, l, d) x_up
+            activation (str): only swiglu is supported
+ precision (str): fp32, fp16, bf16
+ """
+
+ @staticmethod
+ @custom_fwd
+ def forward(ctx: Any, x_gate: torch.Tensor, x_up: torch.Tensor, activation: str = "swiglu") -> torch.Tensor:
+ """
+ act(x_gate) * x_up
+
+ Args:
+ x_gate (torch.Tensor): (b, l, 2d) x gate
+ x_up (torch.Tensor): (b, l, d) x up
+                activation (str): only swiglu is supported
+ """
+ assert activation == "swiglu", "Only swiglu is supported"
+
+ # split x gate
+ assert x_gate.shape[-1] % 2 == 0, "axis size must be divisible by 2"
+ x_gate1, x_gate2 = torch.split(x_gate, x_gate.shape[-1] // 2, -1)
+ x_gate1 = x_gate1.contiguous()
+ x_gate2 = x_gate2.contiguous()
+ if not x_up.is_contiguous():
+ x_up = x_up.contiguous()
+ # assert shape
+ assert x_gate1.shape == x_gate2.shape == x_up.shape
+
+ # add ctx for backward
+ if x_gate.requires_grad:
+ ctx.save_for_backward(x_gate1, x_gate2, x_up)
+
+ # allocate output
+ y = torch.empty_like(x_up)
+ M, N = reduce(lambda x, y: x * y, x_up.shape[:-1]), x_up.shape[-1]
+
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x_gate.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ if N > BLOCK_SIZE:
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ # restore setting
+ ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps = M, N, BLOCK_SIZE, num_warps
+ # enqueue kernel
+ _llama_act_combine_forward[(M,)](x_gate1,
+ x_gate2,
+ x_up,
+ y,
+ x_up.stride(-2),
+ N,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps)
+ return y
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, Tensor, None, None]:
+ # restore from ctx
+ (x_gate1, x_gate2, x_up) = ctx.saved_tensors
+ M, N, BLOCK_SIZE, num_warps = ctx.M, ctx.N, ctx.BLOCK_SIZE, ctx.num_warps
+
+ # init grad
+ y_grad = grad_outputs[0]
+ x_gate1_grad, x_gate2_grad, x_up_grad = torch.empty_like(x_gate1), torch.empty_like(
+ x_gate2), torch.empty_like(x_up)
+
+ # enqueue kernel
+ _llama_act_combine_backward[(M,)](x_gate1,
+ x_gate2,
+ x_up,
+ x_gate1_grad,
+ x_gate2_grad,
+ x_up_grad,
+ y_grad,
+ x_up.stride(-2),
+ N,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps)
+ x_gate_grad = torch.cat([x_gate1_grad, x_gate2_grad], dim=-1)
+ return x_gate_grad, x_up_grad, None, None
diff --git a/colossalai/kernel/triton/token_attention_kernel.py b/colossalai/kernel/triton/token_attention_kernel.py
index 8dc919bad125..de2003748e65 100644
--- a/colossalai/kernel/triton/token_attention_kernel.py
+++ b/colossalai/kernel/triton/token_attention_kernel.py
@@ -13,17 +13,7 @@
print("please install triton from https://github.com/openai/triton")
try:
- from lightllm.models.llama2.triton_kernel.token_attention_nopad_att1 import (
- token_att_fwd as lightllm_llama2_token_att_fwd,
- )
- from lightllm.models.llama2.triton_kernel.token_attention_nopad_reduceV import (
- token_att_fwd2 as lightllm_llama2_token_att_fwd2,
- )
- from lightllm.models.llama2.triton_kernel.token_attention_nopad_softmax import (
- token_softmax_fwd as lightllm_llama2_token_softmax_fwd,
- )
-
- from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fw2
+ from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 as lightllm_llama_token_att_fwd2
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_llama_token_att_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd as lightllm_llama_token_softmax_fwd
from lightllm.models.bloom.triton_kernel.token_attention_nopad_att1 import token_att_fwd as lightllm_bloom_token_att_fwd
@@ -72,7 +62,7 @@ def token_attention_fwd(
lightllm_llama_token_softmax_fwd(att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch)
att_m_tensor = None
- lightllm_llama_token_att_fw2(
+ lightllm_llama_token_att_fwd2(
prob, v, attn_out.view(calcu_shape1), kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_len_in_batch
)
prob = None
@@ -203,7 +193,7 @@ def token_attn(
calcu_shape1 = (batch_size, head_num, head_dim)
att_m_tensor = torch.empty((head_num, total_token_num), dtype=q.dtype, device="cuda")
- lightllm_llama2_token_att_fwd(
+ lightllm_llama_token_att_fwd(
q,
k,
att_m_tensor,
@@ -215,12 +205,12 @@ def token_attn(
if triton.__version__ == "2.0.0":
prob = torch.empty_like(att_m_tensor)
- lightllm_llama2_token_softmax_fwd(
+ lightllm_llama_token_softmax_fwd(
att_m_tensor, kv_cache_start_loc, kv_cache_seq_len, prob, max_len_in_batch
)
att_m_tensor = None
- lightllm_llama2_token_att_fwd2(
+ lightllm_llama_token_att_fwd2(
prob,
v,
attn_out.view(calcu_shape1),
diff --git a/colossalai/legacy/amp/torch_amp/torch_amp.py b/colossalai/legacy/amp/torch_amp/torch_amp.py
index ced5cc3e6647..0a8d09be21ea 100644
--- a/colossalai/legacy/amp/torch_amp/torch_amp.py
+++ b/colossalai/legacy/amp/torch_amp/torch_amp.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-import torch.cuda.amp as torch_amp
+from colossalai.utils.device import autocast
+
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
@@ -70,7 +71,7 @@ def __init__(self, model: nn.Module) -> None:
super().__init__()
self.model = model
- @torch_amp.autocast()
+ @autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
@@ -89,7 +90,7 @@ def __init__(self, loss: _Loss):
super().__init__()
self.loss = loss
- @torch_amp.autocast()
+ @autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
diff --git a/colossalai/legacy/engine/gradient_handler/__init__.py b/colossalai/legacy/engine/gradient_handler/__init__.py
index 78928b138842..713df5a64783 100644
--- a/colossalai/legacy/engine/gradient_handler/__init__.py
+++ b/colossalai/legacy/engine/gradient_handler/__init__.py
@@ -1,6 +1,5 @@
from ._base_gradient_handler import BaseGradientHandler
from ._data_parallel_gradient_handler import DataParallelGradientHandler
-from ._moe_gradient_handler import MoeGradientHandler
from ._pipeline_parallel_gradient_handler import PipelineSharedModuleGradientHandler
from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
from ._zero_gradient_handler import ZeROGradientHandler
@@ -10,6 +9,5 @@
"DataParallelGradientHandler",
"ZeROGradientHandler",
"PipelineSharedModuleGradientHandler",
- "MoeGradientHandler",
"SequenceParallelGradientHandler",
]
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule.py b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
index 4fc5040f6983..5fd5602e790c 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule.py
@@ -12,7 +12,7 @@
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
from colossalai.logging import get_dist_logger
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ._base_schedule import BaseSchedule
diff --git a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
index 867c3dfa819b..4cd7e47c37f1 100644
--- a/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/legacy/engine/schedule/_pipeline_schedule_v2.py
@@ -9,7 +9,7 @@
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ._pipeline_schedule import PipelineSchedule
diff --git a/colossalai/legacy/inference/README.md b/colossalai/legacy/inference/README.md
new file mode 100644
index 000000000000..f466f46c1629
--- /dev/null
+++ b/colossalai/legacy/inference/README.md
@@ -0,0 +1,143 @@
+# 🚀 Colossal-Inference
+
+## Table of contents
+
+## Introduction
+
+`Colossal Inference` is the inference framework designed by Colossal-AI, featuring high performance, stability, and ease of use. It incorporates the strengths of the latest open-source inference systems, including LightLLM, TGI, vLLM, FasterTransformer, and FlashAttention, while building on the design of Colossal-AI, especially Shardformer, to reduce the learning curve for users.
+
+## Design
+
+Colossal Inference is composed of three main components:
+
+1. High-performance kernels and ops: inspired by existing libraries and modified accordingly.
+2. Efficient memory management mechanism: includes the key-value cache manager, allowing for zero memory waste during inference.
+   1. `cache manager`: serves as a memory manager for the key-value cache, integrating functions such as memory allocation, indexing, and release.
+   2. `batch_infer_info`: holds all essential elements of a batch inference and is updated every batch.
+3. High-level inference engine combined with `Shardformer`: allows our inference framework to easily invoke and utilize various parallel methods (a usage sketch follows this list).
+   1. `engine.TPInferEngine`: a high-level interface that integrates with Shardformer, especially for multi-card (tensor parallel) inference.
+   2. `modeling.llama.LlamaInferenceForwards`: contains the `forward` methods for LLaMA inference.
+   3. `policies.llama.LlamaModelInferPolicy`: contains the policies for LLaMA models, used by `shardformer` to shard the model forward pass for tensor parallelism.
+
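+As a quick illustration, below is a minimal usage sketch (the checkpoint path is a placeholder, distributed initialization is omitted, and the `generate` entry point is assumed; the constructor arguments follow `ray_dist_init.py` in this PR):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.shardformer import ShardConfig
+
+# placeholder checkpoint path
+model = AutoModelForCausalLM.from_pretrained("/path/to/llama", torch_dtype=torch.float16)
+tokenizer = AutoTokenizer.from_pretrained("/path/to/llama")
+
+# inference-only sharding, as used by the dynamic-batching workers
+shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
+engine = TPInferEngine(model, shard_config, 4, 128, 32)  # max_batch_size, max_input_len, max_output_len
+
+inputs = tokenizer("Introduce Beijing in one sentence.", return_tensors="pt")
+outputs = engine.generate(inputs, do_sample=False, max_new_tokens=32)  # assumed generation API
+print(tokenizer.decode(outputs[0]))
+```
+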
+## Pipeline of inference
+
+In this section we discuss how Colossal Inference works and integrates with `Shardformer`. The details can be found in the code.
+
+## Roadmap of our implementation
+
+- [x] Design cache manager and batch infer state
+- [x] Design TPInferEngine to integrate with `Shardformer`
+- [x] Register corresponding high-performance `kernel` and `ops`
+- [x] Design policies and forwards (e.g. `Llama` and `Bloom`)
+ - [x] policy
+ - [x] context forward
+ - [x] token forward
+ - [x] support flash-decoding
+- [ ] Replace the kernels with `faster-transformer` in the token-forward stage
+- [ ] Support all models
+ - [x] Llama
+ - [x] Llama-2
+ - [x] Bloom
+ - [x] Chatglm2
+- [ ] Benchmarking for all models
+
+## Get started
+
+### Installation
+
+```bash
+pip install -e .
+```
+
+### Requirements
+
+Dependencies:
+
+```bash
+pytorch= 1.13.1 (gpu)
+cuda>= 11.6
+transformers= 4.30.2
+triton
+# to install flash-attention
+flash-attention
+
+# install lightllm since we depend on lightllm triton kernels
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .
+
+# also, install xformers from source:
+pip install ninja
+# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
+pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+
+```
+
+### Docker
+
+You can use `docker run` to set up the environment in a Docker container:
+
+```bash
+# env: python==3.8, cuda 11.6, pytorch == 1.13.1 triton==2.0.0.dev20221202, vllm kernels support, flash-attention-2 kernels support
+docker pull hpcaitech/colossalai-inference:v2
+docker run -it --gpus all --name ANY_NAME -v $PWD:/workspace -w /workspace hpcaitech/colossalai-inference:v2 /bin/bash
+
+# enter into docker container
+cd /path/to/ColossalAI
+pip install -e .
+
+# install lightllm
+git clone https://github.com/ModelTC/lightllm
+cd lightllm
+git checkout 28c1267cfca536b7b4f28e921e03de735b003039
+pip3 install -e .
+
+# install xformers from source
+pip install ninja
+# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types
+pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+```
+
+### Dive into fast-inference!
+
+Example files are in:
+
+```bash
+cd colossalai.examples
+python xx
+```
+
+## Performance
+
+### Environment
+
+We conducted multiple benchmark tests to evaluate performance, comparing the inference `latency` and `throughput` of `colossal-inference` against the original `hugging-face torch fp16`.
+
+For various models, experiments were conducted with multiple batch sizes under a consistent model configuration of `7 billion (7b)` parameters, `1024` input length, and `128` output length. The results are as follows (due to time constraints, the evaluation has so far been performed only on a single `A100` GPU; multi-GPU performance will be addressed in the future):
+
+### Single GPU Performance
+
+The stats below were measured on an A100 (single GPU). Token latency is computed from the average of the context forward and the decoding forward, i.e., both stages are combined when calculating token generation time. We are actively developing new features and methods to further optimize the performance of LLMs; please stay tuned.
+
+#### Llama
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 199.12 | 246.56 | 278.4 |
+| colossal-inference | 326.4 | 582.72 | 816.64 |
+
+
+
+#### Bloom
+
+| batch_size | 8 | 16 | 32 |
+| :---------------------: | :----: | :----: | :----: |
+| hugging-face torch fp16 | 189.68 | 226.66 | 249.61 |
+| colossal-inference | 323.28 | 538.52 | 611.64 |
+
+
+
+The results of more models are coming soon!
diff --git a/colossalai/legacy/inference/__init__.py b/colossalai/legacy/inference/__init__.py
new file mode 100644
index 000000000000..d5a988cfc6f0
--- /dev/null
+++ b/colossalai/legacy/inference/__init__.py
@@ -0,0 +1,4 @@
+from .hybridengine import CaiInferEngine
+from .hybridengine.polices import LlamaModelInferPolicy
+
+__all__ = ["CaiInferEngine", "LlamaModelInferPolicy"]
diff --git a/colossalai/legacy/inference/async_engine.py b/colossalai/legacy/inference/async_engine.py
new file mode 100644
index 000000000000..d0890ba3e9fc
--- /dev/null
+++ b/colossalai/legacy/inference/async_engine.py
@@ -0,0 +1,133 @@
+import asyncio
+
+from colossalai.inference.dynamic_batching.ray_dist_init import Driver
+
+from .dynamic_batching.io_struct import RequestOutput
+from .dynamic_batching.sampling_params import SamplingParams
+
+
+class RequestTracker:
+ """
+    A class that tracks all requests; an abstraction for the async engine
+ """
+
+ def __init__(self) -> None:
+ self._requests: asyncio.Queue[str] = asyncio.Queue()
+ self._finished_requests: asyncio.Queue[RequestOutput] = asyncio.Queue()
+ self.new_requests_event = None
+
+ def __contains__(self, item):
+ return item in self._requests
+
+ def init_event(self):
+ self.new_requests_event = asyncio.Event()
+
+ def add_request(self, request_id: str):
+ """Add a request to be sent to the engine on the next background
+ loop iteration."""
+ self._requests.put_nowait(request_id)
+ self.new_requests_event.set() # NOTE: we may find a better way to clear this event
+
+ def add_stop(self):
+ """
+ Add a StopIteration flag to stop async generator.
+ """
+ self._finished_requests.put_nowait(StopIteration)
+ self.new_requests_event.clear()
+
+ def process_request_output(self, request_output: RequestOutput) -> None:
+ """Process a request output from the engine."""
+ self._finished_requests.put_nowait(request_output)
+
+ async def wait_for_new_requests(self):
+ await self.new_requests_event.wait()
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self) -> RequestOutput:
+ result = await self._finished_requests.get()
+ # print("result of ", result)
+ if result is StopIteration:
+ raise StopAsyncIteration
+ return result
+
+
+class Async_Engine:
+
+ """
+    Use an engine to launch a Ray Driver --> Ray Worker --> Async_Manager
+    Background loop: runs inference on requests in the waiting list (listens for new ones)
+    Request tracker: manages incoming requests and stores finished ones
+    Generate: the exposed function for adding new input and returning finished results
+ """
+
+ def __init__(
+ self,
+ router_config,
+ engine_config,
+ start_engine_loop: bool = True,
+ ) -> None:
+ self.driver = Driver(router_config=router_config, engine_config=engine_config)
+ self.background_loop = None
+ self.start_engine_loop = start_engine_loop
+ self._request_tracker = RequestTracker()
+
+ def _step(self):
+ """
+ Logic for handling requests
+ """
+ request_outputs = self.driver.step()
+ if request_outputs is not None:
+ for request_output in request_outputs:
+ self._request_tracker.process_request_output(request_output)
+ self._request_tracker.add_stop()
+
+ def abort_request(self, request_id: str):
+ self.driver.abort(request_id)
+
+ def _has_requests_in_progress(self):
+ return self.driver.is_running()
+
+    async def run_loop_fwd(self):
+        while True:
+            # re-check on every iteration; a one-time snapshot taken before the
+            # loop would never notice that the driver has gone idle
+            if not self._has_requests_in_progress():
+                await self._request_tracker.wait_for_new_requests()
+            self._step()
+            await asyncio.sleep(0)
+
+ @property
+ def is_running(self):
+ return self.background_loop is not None and not self.background_loop.done()
+
+ def start_background_loop(self):
+ if self.is_running:
+ raise RuntimeError("Background loop is already running.")
+
+ self._request_tracker.init_event()
+
+ self.background_loop_unshielded = asyncio.get_event_loop().create_task(self.run_loop_fwd())
+ self.background_loop = asyncio.shield(self.background_loop_unshielded)
+
+ async def add_request(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+ self.driver.add_input(request_id, prompt, sampling_params)
+ self._request_tracker.add_request(request_id)
+
+ async def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+ """
+        The only exposed function: adds a new request and returns an async generator that yields results as they finish.
+ """
+ try:
+ if not self.is_running:
+ self.start_background_loop()
+
+ await self.add_request(request_id, prompt, sampling_params)
+
+ async for request_output in self._request_tracker:
+ yield request_output
+
+ except (Exception, asyncio.CancelledError) as e:
+ # If there is an exception or coroutine is cancelled, abort the request.
+ self.abort_request(request_id)
+ raise e
diff --git a/colossalai/legacy/inference/async_manager.py b/colossalai/legacy/inference/async_manager.py
new file mode 100644
index 000000000000..60440a792f1c
--- /dev/null
+++ b/colossalai/legacy/inference/async_manager.py
@@ -0,0 +1,151 @@
+from typing import List
+
+from .dynamic_batching.io_struct import Batch, Req, RequestOutput
+from .manager import DynamicBatchManager
+from .tensor_parallel import TPInferEngine
+
+
+class Async_DynamicBatchManager(DynamicBatchManager):
+ def __init__(
+ self,
+ tp_engine: TPInferEngine,
+ max_total_token_num: int,
+ batch_max_tokens: int,
+ model: str,
+ tokenizer=None,
+ eos_id=None,
+ log_stats=True,
+ log_stats_interval=10,
+ running_batch: Batch = None,
+ waiting_req_list: List = [],
+ ):
+ """
+        Args: tp_engine : the TP engine held by the dynamic batch manager, created before it
+ max_total_token_num : max_total_token_num for memory manager, default to: max batch size * (max input len + max output len)
+ batch_max_tokens : max tokens of one batch, default to (max input + output len) * num_requests
+ running_max_req_size : max request size of running batch, equals to MAX_BATCH_SIZE of tp engine
+ eos_id : The end token of a seq
+ model: the model weight dir path, the app will load config, weights and tokenizer from this dir
+ log_stats : whether to log stats
+ log_stats_interval : log stats interval
+ running_batch : running batch
+ waiting_req_list : list of waiting requests, initialized before dynamic batch manager
+ """
+ super().__init__(
+ tp_engine,
+ max_total_token_num,
+ batch_max_tokens,
+ model,
+ tokenizer,
+ eos_id,
+ log_stats,
+ log_stats_interval,
+ running_batch,
+ waiting_req_list,
+ )
+
+ def _step(self):
+ """
+ Logic for handling requests
+ """
+ has_new_finished = False
+ if self.running_batch is None:
+ new_batch = self.req_queue.generate_new_batch(self.running_batch)
+ if new_batch is not None:
+ self.stats_tool.count_prompt_tokens(new_batch)
+ self.running_batch = new_batch
+ has_new_finished, outputs = self._prefill_batch(self.running_batch)
+ self._filter_runing_batch()
+ self.has_wait_tokens = 0
+
+ else:
+ if self.has_wait_tokens < self.max_wait_tokens:
+ self.stats_tool.count_output_tokens(self.running_batch)
+ has_new_finished, outputs = self._decode_batch(self.running_batch)
+ self._filter_runing_batch()
+ self.has_wait_tokens += 1
+
+ else:
+ new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
+ if new_mini_batch is not None:
+ self.stats_tool.count_prompt_tokens(new_mini_batch)
+ has_new_finished, outputs = self._prefill_batch(new_mini_batch)
+ if not new_mini_batch.is_clear():
+ self._merge_batch(self.running_batch, new_mini_batch)
+ self.running_batch.merge(new_mini_batch)
+ self.has_wait_tokens = 0
+
+ else:
+ self.stats_tool.count_output_tokens(self.running_batch)
+ has_new_finished, outputs = self._decode_batch(self.running_batch)
+ self._filter_runing_batch()
+ self.has_wait_tokens += 1
+
+ if has_new_finished:
+ return outputs
+ return None
+
+ def _prefill_batch(self, batch):
+ """
+        For every batch, whether it is a new batch or a mini-batch, prefill must be done first.
+ """
+ self._init_batch(batch)
+
+ # TODO: figure out if cache and batch id is needed
+ ans = self.engine._prefill_batch(batch.batch_id)
+ req_to_out_token_id = ans
+ self._add_token_id_to_req(batch, req_to_out_token_id)
+ has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
+ outputs = self._handle_finish_req(batch, has_new_finished_req)
+ return has_new_finished_req, outputs
+
+ def _decode_batch(self, batch: Batch):
+ """
+ Decoding process
+ """
+ ans = self.engine._decode_batch(batch.batch_id)
+ req_to_out_token_id = ans
+ self._add_token_id_to_req(batch, req_to_out_token_id)
+ has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
+ outputs = self._handle_finish_req(batch, has_new_finished_req)
+ return has_new_finished_req, outputs
+
+ def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+ if has_new_finished_req:
+ finished_reqs = batch.filter_finished()
+ if batch.is_clear():
+ self._remove_batch(batch)
+ else:
+ self._filter_batch(batch)
+ return self._output_process(finished_reqs)
+ return None
+
+ def _output_process(self, finished_reqs: List[Req]):
+ """
+ Process the output of a batch.
+ """
+ outputs = []
+ for req in finished_reqs:
+ output = self.tokenizer.decode(req.output_ids)
+ outputs.append(RequestOutput(req.request_id, req.prompts, req.prompt_ids, output))
+ return outputs
+
+
+def start_dynamic_batching(args, tp_engine, waiting_req_list):
+ try:
+ batch_manager = Async_DynamicBatchManager(
+ tp_engine=tp_engine,
+ max_total_token_num=args.max_total_token_num,
+ batch_max_tokens=args.batch_max_tokens,
+ eos_id=args.eos_id,
+ model=args.model,
+ log_stats=not args.disable_log_stats,
+ log_stats_interval=args.log_stats_interval,
+ waiting_req_list=waiting_req_list,
+ )
+
+ except Exception:
+        raise
+
+ return batch_manager
diff --git a/colossalai/legacy/inference/dynamic_batching/__init__.py b/colossalai/legacy/inference/dynamic_batching/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py b/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py
new file mode 100644
index 000000000000..94aa3f24393f
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/get_tokenizer.py
@@ -0,0 +1,40 @@
+"""
+Motivated by vLLM (https://github.com/vllm-project/vllm), this module resolves tokenizer loading issues.
+
+license: MIT, see LICENSE for more details.
+"""
+
+from transformers import AutoTokenizer
+
+_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
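+# Example (a sketch): reuse the fast tokenizer above directly
+#   tokenizer = get_tokenizer(tokenizer_name=_FAST_LLAMA_TOKENIZER)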
+
+
+def get_tokenizer(
+ tokenizer=None,
+ tokenizer_name: str = "",
+ trust_remote_code: bool = False,
+ use_fast: bool = True,
+):
+    if tokenizer is None:
+        if "llama" in tokenizer_name.lower() and use_fast:
+ print(
+ "For some LLaMA-based models, initializing the fast tokenizer may "
+ "take a long time. To eliminate the initialization time, consider "
+ f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
+ "tokenizer. This is done automatically in Colossalai."
+ )
+
+ tokenizer_name = _FAST_LLAMA_TOKENIZER
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+ )
+ except TypeError:
+ use_fast = False
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name, use_fast=use_fast, trust_remote_code=trust_remote_code
+ )
+ return tokenizer
diff --git a/colossalai/legacy/inference/dynamic_batching/infer_batch.py b/colossalai/legacy/inference/dynamic_batching/infer_batch.py
new file mode 100644
index 000000000000..112784c15f84
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/infer_batch.py
@@ -0,0 +1,346 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+import collections
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+
+from colossalai.inference.tensor_parallel import MemoryManager
+
+
+# make batch infer state an attr of InferBatch
+class InferSamplingParams:
+ def __init__(
+ self,
+ do_sample: bool = False,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ top_k: int = -1,
+ vocab_size: int = -1,
+ ) -> None:
+ self.do_sample = do_sample
+ self.presence_penalty = presence_penalty
+ self.frequency_penalty = frequency_penalty
+ self.temperature = temperature
+ self.top_p = top_p
+ self.top_k = top_k
+ if self.top_k == -1:
+ self.top_k = vocab_size
+ return
+
+
+@dataclass
+class InferBatch:
+ batch_id: int
+ requests: List
+ requests_idx_mapping: Dict[int, int]
+
+ input_ids: torch.Tensor
+
+ all_input_ids: List[List[int]]
+ input_lengths: List[int]
+
+ out_token_id_counts: List
+ sampling_param_list: List[InferSamplingParams]
+
+ nopad_total_token_num: int
+ nopad_max_len_in_batch: int
+ nopad_b_loc: torch.Tensor
+ nopad_b_start_loc: torch.Tensor
+ nopad_b_seq_len: torch.Tensor
+ cache_manager: MemoryManager
+ max_total_len: int
+
+ @classmethod
+ @torch.no_grad()
+ def init_batch(
+ cls,
+ batch_id,
+ requests,
+ dtype: torch.dtype,
+ device: torch.device,
+ cache_manager: MemoryManager,
+ vocab_size: int,
+ max_total_len: int,
+ ) -> "InferBatch":
+ input_lengths = []
+ all_input_ids = []
+ requests_idx_mapping = {}
+
+ out_token_id_counts = []
+ sampling_param_list = []
+
+ nopad_total_token_num = 0
+ nopad_max_len_in_batch = 0
+ nopad_b_loc = torch.empty((len(requests), max_total_len + 12), dtype=torch.long, device="cuda")
+        # to avoid memory leaks, we pre-allocate 12 extra slots for each batch.
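+        # nopad_b_loc stores each request's cache-slot indices right-aligned to
+        # the last column, which is why free_self()/filter() slice backwards
+        # from (nopad_max_len_in_batch - 1).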
+ nopad_b_start_loc = torch.zeros(len(requests), dtype=torch.int32, device="cuda")
+ for i, r in enumerate(requests):
+ # request id -> idx in list mapping
+ requests_idx_mapping[r["request_id"]] = i
+
+ tokenized_input = r["input_id"]
+
+ input_length = len(tokenized_input)
+ input_lengths.append(input_length)
+ all_input_ids.append(tokenized_input)
+ out_token_id_counts.append(collections.defaultdict(int))
+
+ # postprocessor
+ sampling_param = r["sampling_param"]
+ sampling_param["vocab_size"] = vocab_size
+ sampling_param_list.append(InferSamplingParams(**sampling_param))
+
+ nopad_total_token_num += input_length
+ nopad_max_len_in_batch = max(nopad_max_len_in_batch, input_length)
+
+ nopad_b_seq_len = torch.tensor(input_lengths, dtype=torch.int32, device="cuda")
+ nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
+
+ if len(requests) > 1:
+ input_ids = np.concatenate(all_input_ids, dtype=np.int64)
+ else:
+ input_ids = all_input_ids[0]
+
+ # Create tensors on device
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
+
+ return cls(
+ batch_id=batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ input_lengths=input_lengths,
+ all_input_ids=all_input_ids,
+ nopad_total_token_num=nopad_total_token_num,
+ nopad_max_len_in_batch=nopad_max_len_in_batch,
+ nopad_b_loc=nopad_b_loc,
+ nopad_b_start_loc=nopad_b_start_loc,
+ nopad_b_seq_len=nopad_b_seq_len,
+ out_token_id_counts=out_token_id_counts,
+ sampling_param_list=sampling_param_list,
+ cache_manager=cache_manager,
+ max_total_len=max_total_len,
+ )
+
+ @torch.no_grad()
+ def free_self(self) -> None:
+ """
+ Free the memory of the InferBatch itself
+ """
+ remove_index = []
+ for idx in range(len(self)):
+ remove_index.append(
+ self.nopad_b_loc[
+ idx,
+ (self.nopad_max_len_in_batch - 1)
+ - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
+ ]
+ )
+ remove_index = torch.cat(remove_index, dim=-1)
+ self.cache_manager.free(remove_index)
+
+ @torch.no_grad()
+ def filter(self, request_ids: List[int]) -> "InferBatch":
+ """
+        Filter out finished requests and return a new InferBatch containing the remaining ones.
+ """
+ if len(request_ids) == 0:
+ raise ValueError("Batch must have at least one request")
+ if len(request_ids) == len(self):
+ return self
+ requests_idx_mapping = {}
+ indices = []
+ requests = []
+ all_input_ids = []
+ input_lengths = []
+ nopad_total_token_num = 0
+ nopad_max_len_in_batch = 0
+ nopad_b_loc = torch.empty((len(request_ids), self.max_total_len + 12), dtype=torch.long, device="cuda")
+ nopad_b_start_loc = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
+ nopad_b_seq_len = torch.zeros(len(request_ids), dtype=torch.int32, device="cuda")
+
+ left_idx = []
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ left_idx.append(idx)
+
+ left_idx_set = set(left_idx)
+ remove_index = []
+ for idx in range(len(self)):
+ if idx not in left_idx_set:
+ remove_index.append(
+ self.nopad_b_loc[
+ idx,
+ (self.nopad_max_len_in_batch - 1)
+ - (self.nopad_b_seq_len[idx] - 1) : (self.nopad_max_len_in_batch - 1),
+ ]
+ )
+ remove_index = torch.cat(remove_index, dim=-1)
+ self.cache_manager.free(remove_index)
+
+ nopad_max_len_in_batch = 0
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ indices.append(idx)
+
+ nopad_b_seq_len[:] = self.nopad_b_seq_len[indices]
+ nopad_max_len_in_batch = torch.max(nopad_b_seq_len).item()
+ nopad_b_start_loc[1:] = torch.cumsum(nopad_b_seq_len, dim=0, dtype=torch.int32)[0:-1]
+ nopad_total_token_num = torch.sum(nopad_b_seq_len).item()
+
+ nopad_b_loc[:, 0 : (nopad_max_len_in_batch - 1)] = self.nopad_b_loc[
+ indices,
+ (self.nopad_max_len_in_batch - 1) - (nopad_max_len_in_batch - 1) : (self.nopad_max_len_in_batch - 1),
+ ]
+ for i, request_id in enumerate(request_ids):
+ idx = self.requests_idx_mapping[request_id]
+ requests_idx_mapping[request_id] = i
+ requests.append(self.requests[idx])
+ all_input_ids.append(self.all_input_ids[idx])
+ input_lengths.append(self.input_lengths[idx])
+
+ input_ids = self.input_ids[indices]
+
+ return InferBatch(
+ batch_id=self.batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ input_lengths=input_lengths,
+ all_input_ids=all_input_ids,
+ nopad_total_token_num=nopad_total_token_num,
+ nopad_max_len_in_batch=nopad_max_len_in_batch,
+ nopad_b_loc=nopad_b_loc,
+ nopad_b_start_loc=nopad_b_start_loc,
+ nopad_b_seq_len=nopad_b_seq_len,
+ out_token_id_counts=[self.out_token_id_counts[_i] for _i in indices],
+ sampling_param_list=[self.sampling_param_list[_i] for _i in indices],
+ cache_manager=self.cache_manager,
+ max_total_len=self.max_total_len,
+ )
+
+ @classmethod
+ @torch.no_grad()
+ def merge(cls, batch1, batch2) -> "InferBatch":
+ """
+        Return a merged new InferBatch.
+ """
+ requests = batch1.requests + batch2.requests
+ requests_idx_mapping = {}
+ new_batch_size = len(batch1) + len(batch2)
+
+ input_ids = batch1.input_ids.new_empty(new_batch_size)
+ all_input_ids = []
+ input_lengths = []
+ out_token_id_counts = []
+ sampling_param_list = []
+
+ cumulative_batch_size = 0
+ nopad_total_token_num = batch1.nopad_total_token_num + batch2.nopad_total_token_num
+ nopad_max_len_in_batch = max(batch1.nopad_max_len_in_batch, batch2.nopad_max_len_in_batch)
+ max_total_len = max(batch1.max_total_len, batch2.max_total_len)
+ nopad_b_loc = torch.empty((new_batch_size, batch1.max_total_len + 12), dtype=torch.long, device="cuda")
+ nopad_b_start_loc = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
+ nopad_b_seq_len = torch.zeros(new_batch_size, dtype=torch.int32, device="cuda")
+ nopad_start_loc_len_temp = 0
+ batches = [batch1, batch2]
+ for i, batch in enumerate(batches):
+ if i == 0:
+ requests_idx_mapping = batch.requests_idx_mapping
+ else:
+ for k, v in batch.requests_idx_mapping.items():
+ requests_idx_mapping[k] = v + cumulative_batch_size
+ start_index = cumulative_batch_size
+ end_index = cumulative_batch_size + len(batch)
+ input_ids[start_index:end_index] = batch.input_ids
+ nopad_b_seq_len[start_index:end_index] = batch.nopad_b_seq_len
+ nopad_b_start_loc[start_index:end_index] = batch.nopad_b_start_loc + nopad_start_loc_len_temp
+ nopad_start_loc_len_temp = nopad_b_start_loc[end_index - 1] + nopad_b_seq_len[end_index - 1]
+ nopad_b_loc[
+ start_index:end_index,
+ nopad_max_len_in_batch - batch.nopad_max_len_in_batch : nopad_max_len_in_batch - 1,
+ ] = batch.nopad_b_loc[:, : batch.nopad_max_len_in_batch - 1]
+
+ all_input_ids.extend(batch.all_input_ids)
+
+ input_lengths.extend(batch.input_lengths)
+ out_token_id_counts.extend(batch.out_token_id_counts)
+ sampling_param_list.extend(batch.sampling_param_list)
+ # Update
+ cumulative_batch_size += len(batch)
+
+ nopad_b_loc[:, nopad_max_len_in_batch - 1] = (
+ nopad_total_token_num - new_batch_size + torch.arange(0, new_batch_size, dtype=torch.int32, device="cuda")
+ )
+ return InferBatch(
+ batch_id=batches[0].batch_id,
+ requests=requests,
+ requests_idx_mapping=requests_idx_mapping,
+ input_ids=input_ids,
+ input_lengths=input_lengths,
+ all_input_ids=all_input_ids,
+ nopad_total_token_num=nopad_total_token_num,
+ nopad_max_len_in_batch=nopad_max_len_in_batch,
+ nopad_b_loc=nopad_b_loc,
+ nopad_b_start_loc=nopad_b_start_loc,
+ nopad_b_seq_len=nopad_b_seq_len,
+ out_token_id_counts=out_token_id_counts,
+ sampling_param_list=sampling_param_list,
+ cache_manager=batches[0].cache_manager,
+ max_total_len=max_total_len,
+ )
+
+ def __len__(self):
+ return len(self.requests)
+
+ def get_post_sample_tensors(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ presence_penalties: List[float] = []
+ frequency_penalties: List[float] = []
+ temperatures: List[float] = []
+ top_ps: List[float] = []
+ top_ks: List[int] = []
+ p_token_ids: List[int] = []
+ p_token_counts: List[int] = []
+ p_seq_len: List[int] = [
+ 0,
+ ]
+ p_max_len_in_batch: int = 0
+ for i, id_to_count in enumerate(self.out_token_id_counts):
+ sample_param = self.sampling_param_list[i]
+ presence_penalties.append(sample_param.presence_penalty)
+ frequency_penalties.append(sample_param.frequency_penalty)
+ temperatures.append(sample_param.temperature)
+ top_ps.append(sample_param.top_p)
+ top_ks.append(sample_param.top_k)
+
+ for token_id, count in id_to_count.items():
+ p_token_ids.append(token_id)
+ p_token_counts.append(count)
+ p_seq_len.append(len(id_to_count))
+ p_max_len_in_batch = max(p_max_len_in_batch, len(id_to_count))
+
+ presence_penalties = torch.tensor(presence_penalties, dtype=torch.float, device="cuda")
+ frequency_penalties = torch.tensor(frequency_penalties, dtype=torch.float, device="cuda")
+ temperatures = torch.tensor(temperatures, dtype=torch.float, device="cuda")
+ top_ps = torch.tensor(top_ps, dtype=torch.float, device="cuda")
+ top_ks = torch.tensor(top_ks, dtype=torch.int32, device="cuda")
+ p_token_ids = torch.tensor(p_token_ids, dtype=torch.int32, device="cuda")
+ p_token_counts = torch.tensor(p_token_counts, dtype=torch.int32, device="cuda")
+ p_seq_len = torch.tensor(p_seq_len, dtype=torch.int32, device="cuda")
+ p_cumsum_seq_len = torch.cumsum(p_seq_len, dim=0, dtype=torch.int32)
+ return (
+ presence_penalties,
+ frequency_penalties,
+ temperatures,
+ top_ps,
+ top_ks,
+ p_token_ids,
+ p_token_counts,
+ p_cumsum_seq_len,
+ p_max_len_in_batch,
+ )
diff --git a/colossalai/legacy/inference/dynamic_batching/io_struct.py b/colossalai/legacy/inference/dynamic_batching/io_struct.py
new file mode 100644
index 000000000000..fc5ecfe5796b
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/io_struct.py
@@ -0,0 +1,166 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+from typing import Dict, List, Tuple
+
+from .sampling_params import SamplingParams
+
+
+class Req:
+ def __init__(self, request_id, prompt_ids, sample_params: SamplingParams, prompts: str = ""):
+ self.request_id = request_id
+ self.prompt_ids = prompt_ids
+ self.input_len = len(prompt_ids)
+ self.max_output_len = sample_params.max_new_tokens
+ self.sample_params = sample_params
+ self.output_ids = []
+ self.output_metadata_list = []
+ self.has_generate_finished = False
+ self.aborted = False
+ self.prompts = prompts
+
+ def to_rpc_obj(self):
+ return {
+ "request_id": self.request_id,
+ "input_id": self.prompt_ids,
+ "output_len": self.max_output_len,
+ "sampling_param": self.sample_params.to_dict(),
+ }
+
+ def stop_sequences_matched(self):
+        # should we add stop sequences to the sample params?
+ if self.sample_params.stop_sequences is not None:
+ for stop_token_ids in self.sample_params.stop_sequences:
+ stop_len = len(stop_token_ids)
+ if (
+ stop_len > 0
+ and len(self.output_ids) >= stop_len
+ and all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len))
+ ):
+ return True
+ return False
+
+ def __repr__(self):
+ return f"request_id(n={self.request_id}, " f"prompt_ids={self.prompt_ids}, "
+
+
+class Batch:
+ def __init__(self, batch_id, reqs: List[Req]):
+ self.batch_id = batch_id
+ self.reqs = reqs
+ self.id_to_reqs = {req.request_id: req for req in reqs}
+
+ def input_tokens(self):
+ batch_input_tokens = 0
+ for req in self.reqs:
+ batch_input_tokens += req.input_len
+ return batch_input_tokens
+
+ def calcu_max_tokens(self):
+ tokens = 0
+ for req in self.reqs:
+ tokens += req.input_len + req.max_output_len
+ return tokens
+
+ def calcu_used_tokens(self):
+ tokens = 0
+ for req in self.reqs:
+ tokens += req.input_len + len(req.output_ids)
+ return tokens
+
+ def mark_finished_req(self, eos_id, engine_max_output_len):
+ has_new_finish = False
+ for req in self.reqs:
+ if req.stop_sequences_matched():
+ req.has_generate_finished = True
+ has_new_finish = True
+ if len(req.output_ids) >= engine_max_output_len:
+ req.has_generate_finished = True
+ has_new_finish = True
+            if req.output_ids[-1] == eos_id and not req.sample_params.ignore_eos:
+ req.has_generate_finished = True
+ has_new_finish = True
+ if len(req.output_ids) >= req.max_output_len or req.aborted:
+ req.has_generate_finished = True
+ has_new_finish = True
+ return has_new_finish
+
+ def filter_finished(self) -> List[Req]:
+ """
+        Filter finished requests from the batch; the finished ones are removed from 'reqs'.
+ """
+ # TODO: the logic of return should be defined here.
+ unfinished_req = []
+ finished_req = []
+ for req in self.reqs:
+ if not req.has_generate_finished:
+ unfinished_req.append(req)
+ else:
+ finished_req.append(req)
+ self.reqs = unfinished_req
+ self.id_to_reqs = {req.request_id: req for req in self.reqs}
+ return finished_req
+
+ def is_clear(self):
+ return len(self.reqs) == 0
+
+ def merge(self, mini_batch):
+ for _req in mini_batch.reqs:
+ self.reqs.append(_req)
+ self.id_to_reqs = {req.request_id: req for req in self.reqs}
+ return
+
+ def __repr__(self):
+ return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "
+
+ def __len__(self):
+ return len(self.reqs)
+
+
+class BatchTokenIdOut:
+ def __init__(self):
+ self.reqs_infs: List[
+ Tuple[str, int, Dict, bool, bool]
+ ] = [] # [req_id, new_token_id, gen_metadata, finished_state, abort_state]
+
+
+class BatchStrOut:
+ def __init__(self):
+ self.reqs_infs: List[
+ Tuple[str, str, Dict, bool, bool]
+ ] = [] # [req_id, token_str, gen_metadata, finished_state, abort_state]
+
+
+class AbortReq:
+ def __init__(self, req_id):
+ self.req_id = req_id
+
+
+class RequestOutput:
+ """The output data of a request to the LLM.
+
+ Args:
+ request_id: The unique ID of the request.
+ prompt: The prompt string of the request.
+ prompt_token_ids: The token IDs of the prompt.
+ outputs: The output sequences of the request.
+ """
+
+ def __init__(
+ self,
+ request_id: str,
+ prompt: str,
+ prompt_token_ids: List[int],
+ outputs,
+ ) -> None:
+ self.request_id = request_id
+ self.prompt = prompt
+ self.prompt_token_ids = prompt_token_ids
+ self.outputs = outputs
+
+ def __repr__(self) -> str:
+ return (
+ f"RequestOutput(request_id={self.request_id}, "
+ f"prompt={self.prompt!r}, "
+ f"prompt_token_ids={self.prompt_token_ids}, "
+ f"outputs={self.outputs}, "
+ )
diff --git a/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py
new file mode 100644
index 000000000000..3e40bb0eeb9d
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/ray_dist_init.py
@@ -0,0 +1,154 @@
+import logging
+import os
+from typing import List
+
+import ray
+import ray.util.collective as collective
+import torch
+from transformers import AutoModelForCausalLM
+
+import colossalai
+from colossalai.inference.async_manager import start_dynamic_batching
+from colossalai.inference.dynamic_batching.get_tokenizer import get_tokenizer
+from colossalai.inference.dynamic_batching.io_struct import RequestOutput
+from colossalai.inference.dynamic_batching.ray_init_config import EngineArgsClass, RooterArgsClass
+from colossalai.inference.dynamic_batching.sampling_params import SamplingParams
+from colossalai.inference.tensor_parallel.engine import TPInferEngine
+from colossalai.shardformer import ShardConfig
+from colossalai.testing import free_port
+
+ray_serve_logger = logging.getLogger("ray.serve")
+
+
+def log_cuda_info(scope_name: str):
+ ray_serve_logger.info(f" {scope_name}: ray.get_gpu_ids(): {ray.get_gpu_ids()}")
+ ray_serve_logger.info(
+ f" {scope_name}: CUDA_VISIBLE_DEVICES: {os.getenv('CUDA_VISIBLE_DEVICES', 'NO DEVICES FOUND!')}"
+ )
+ if torch.cuda.is_available():
+ ray_serve_logger.info(
+ f" {scope_name}: cuda current_device: {torch.cuda.current_device()}, cuda device count: {torch.cuda.device_count()}"
+ )
+ else:
+ ray_serve_logger.info(f" {scope_name}: cuda is not available!")
+
+
+@ray.remote(num_gpus=1)
+class Worker:
+ def __init__(
+ self,
+ model_path: str,
+ tensor_parallel_size: int,
+ max_batch_size: int,
+ max_input_len: int,
+ max_output_len: int,
+ router_config: RooterArgsClass,
+ ):
+ log_cuda_info("Worker.init")
+ self.tensor_parallel_size = tensor_parallel_size
+ self.model_path = model_path
+ self.max_batch_size = max_batch_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.router_config = router_config
+
+ def setup(self, world_size, rank, port):
+ # initialize a ray collective group, otherwise colossalai distributed env won't be built successfully
+ collective.init_collective_group(world_size, rank, "nccl", "default")
+ # initialize and set distributed environment
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ ray_serve_logger.info(f"Worker with rank {rank} (world size {world_size}) setting up..")
+ log_cuda_info("Worker.setup")
+
+ # Load model
+ self.tokenizer = get_tokenizer(tokenizer_name=self.model_path)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+ self.model = AutoModelForCausalLM.from_pretrained(
+ self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
+ )
+ shard_config = ShardConfig(
+ enable_tensor_parallelism=True if world_size > 1 else False, extra_kwargs={"inference_only": True}
+ )
+ self.infer_engine = TPInferEngine(
+ self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
+ )
+ self.start_dynamic_batching = start_dynamic_batching(self.router_config, self.infer_engine, [])
+
+ return True
+
+ # def generate(self, request_id: str, prompt: str, sampling_params: SamplingParams) -> List[str]:
+ # ray_serve_logger.info(f"text: {prompt}")
+
+ # final_outputs = self.start_dynamic_batching.generate(prompt, sampling_params, request_id)
+
+ # return final_outputs
+
+ def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+ self.start_dynamic_batching.add_input(request_id, prompt, sampling_params)
+
+ def abort(self, request_id: str):
+ self.start_dynamic_batching.abort(request_id)
+
+ def step(self) -> List[RequestOutput]:
+ return self.start_dynamic_batching._step()
+
+ def add_req(self, prompt_ids: List[int], sampling_params: SamplingParams, request_id: str, prompt: str):
+ self.start_dynamic_batching.add_req(prompt_ids, sampling_params, request_id, prompt)
+
+ def is_running(self):
+ return self.start_dynamic_batching.is_running()
+
+
+class Driver:
+ def __init__(self, router_config: RooterArgsClass, engine_config: EngineArgsClass):
+ log_cuda_info("Driver:init")
+ model_path = engine_config.model
+ tensor_parallel_size = engine_config.tensor_parallel_size
+
+ self.num_workers = tensor_parallel_size
+ self.workers = []
+ init_rets = []
+
+ # Just grab a free port on localhost
+ # NOTE workers in this communication group listen to the same port
+ available_port = free_port()
+
+ for i in range(self.num_workers):
+ worker_name = "worker_idx_{}".format(i)
+ w = Worker.options(name=worker_name).remote(
+ model_path,
+ self.num_workers,
+ engine_config.max_batch_size,
+ engine_config.max_input_len,
+ engine_config.max_output_len,
+ router_config,
+ )
+ self.workers.append(w)
+ init_rets.append(w.setup.remote(self.num_workers, i, available_port))
+ _options = {
+ "group_name": "default_driver",
+ "world_size": self.num_workers,
+ "ranks": [i for i in range(self.num_workers)],
+ "backend": "nccl",
+ }
+ collective.create_collective_group(self.workers, **_options)
+ _ = ray.get(init_rets)
+
+ def add_input(self, request_id: str, prompt: str, sampling_params: SamplingParams):
+ ray.get([w.add_input.remote(request_id, prompt, sampling_params) for w in self.workers])
+
+ def abort(self, request_id: str):
+ ray.get([w.abort.remote(request_id) for w in self.workers])
+
+ def step(self):
+ results = ray.get([w.step.remote() for w in self.workers])
+ outputs = results[0] # get any one of the copies
+ return outputs
+
+ def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompt: str):
+ ray.get([w.add_req.remote(prompt_ids, sampling_params, request_id, prompt) for w in self.workers])
+
+ def is_running(self):
+ results = ray.get([w.is_running.remote() for w in self.workers])
+ return any(results)
diff --git a/colossalai/legacy/inference/dynamic_batching/ray_init_config.py b/colossalai/legacy/inference/dynamic_batching/ray_init_config.py
new file mode 100644
index 000000000000..471f07330aec
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/ray_init_config.py
@@ -0,0 +1,58 @@
+import logging
+
+import yaml
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class EngineArgsClass(BaseModel):
+ """Config for Engine"""
+
+ model: str
+ tensor_parallel_size: int = 2
+ max_batch_size: int = 4
+ max_input_len: int = 128
+ max_output_len: int = 32
+
+
+class RooterArgsClass(BaseModel):
+ """Config for Rooter"""
+
+ max_total_token_num: int = 42
+ batch_max_tokens: int = 42
+ eos_id: int = 0
+ disable_log_stats: bool = False
+ log_stats_interval: int = 10
+ model: str
+
+
+class RayInitConfig(BaseModel):
+ """All-together configs without app router config"""
+
+ engine_config_data: EngineArgsClass
+ router_config_data: RooterArgsClass
+
+ @classmethod
+ def from_yaml_path(cls, path: str):
+ try:
+ with open(path, "r") as yaml_file:
+ try:
+ config = yaml.safe_load(yaml_file)
+ # serve deployment config
+ engine_config = config.get("engine_config", {})
+ router_config = config.get("router_config", {})
+
+ return cls(
+ engine_config_data=engine_config,
+ router_config_data=router_config,
+ )
+ except yaml.YAMLError as e:
+ logger.error(f"An Error occurred when parsing yaml: {e}")
+ raise
+ except FileNotFoundError:
+ logger.error(f"The file '{path}' does not exist!")
+ raise
+ except OSError as e:
+ logger.error(f"An Error occurred: {e}")
+ raise
diff --git a/colossalai/legacy/inference/dynamic_batching/req_queue.py b/colossalai/legacy/inference/dynamic_batching/req_queue.py
new file mode 100644
index 000000000000..0de43bd1a21f
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/req_queue.py
@@ -0,0 +1,73 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+import uuid
+from typing import List
+
+import numpy as np
+
+from .io_struct import Batch, Req
+
+
+class ReqQueue:
+ def __init__(self, max_total_tokens, batch_max_tokens, running_max_req_size, waiting_req_list=[]) -> None:
+ self.max_total_tokens = max_total_tokens
+ assert batch_max_tokens is not None
+ self.batch_max_tokens = batch_max_tokens
+ self.running_max_req_size = running_max_req_size
+ self.waiting_req_list: List[Req] = waiting_req_list
+
+ def append(self, req):
+ self.waiting_req_list.append(req)
+ return
+
+ def _init_cache_list(self, current_batch: Batch):
+ if current_batch is not None:
+ self.cache_len_list = [
+ (req.input_len + len(req.output_ids), req.max_output_len - len(req.output_ids) - 1)
+ for req in current_batch.reqs
+ ]
+ else:
+ self.cache_len_list = []
+
+ # @calculate_time(show=True, min_cost_ms=0.1)
+ def _can_add_new_req(self, req):
+ self.cache_len_list.append((req.input_len + 1, req.max_output_len - 1)) # hard to analysis
+ self.cache_len_list.sort(key=lambda x: -x[1])
+
+ left_out_len_array = np.array([e[1] for e in self.cache_len_list])
+ # assert left_out_len_array.min() >= 0
+ has_run_len_array = np.array([e[0] for e in self.cache_len_list])
+ cum_run_len_array = np.cumsum(has_run_len_array)
+ size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
+
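+        # Peak-memory estimate: with requests sorted by remaining output length
+        # (descending), when request i finishes, the i + 1 longest-remaining
+        # requests have each generated left_out_len[i] extra tokens on top of
+        # the cum_run_len[i] tokens already cached; the max over i is the
+        # worst-case token demand.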
+ need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+ # NOTE: change here < to <=
+ return need_max_token_num <= self.max_total_tokens and len(self.cache_len_list) <= self.running_max_req_size
+
+ def generate_new_batch(self, current_batch: Batch = None):
+ if current_batch is not None and len(current_batch.reqs) >= self.running_max_req_size:
+ return None
+ self._init_cache_list(current_batch)
+ can_run_list = []
+ new_batch_total_tokens = 0
+ aborted_count = 0
+ for req in self.waiting_req_list:
+ flag = self._can_add_new_req(req)
+ if req.aborted:
+ aborted_count += 1
+ continue
+ if flag and new_batch_total_tokens + req.input_len <= self.batch_max_tokens:
+ can_run_list.append(req)
+ new_batch_total_tokens += req.input_len
+ else:
+ break
+
+ if len(can_run_list) != 0:
+ new_batch = Batch(uuid.uuid4().hex, can_run_list)
+ self.waiting_req_list = self.waiting_req_list[len(can_run_list) + aborted_count :]
+ return new_batch
+ else:
+ return None
+
+ def __len__(self):
+ return self.waiting_req_list.__len__()
diff --git a/colossalai/legacy/inference/dynamic_batching/sampling_params.py b/colossalai/legacy/inference/dynamic_batching/sampling_params.py
new file mode 100644
index 000000000000..a37a83390021
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/sampling_params.py
@@ -0,0 +1,83 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+"""Sampling parameters for text generation."""
+from typing import List, Optional, Union
+
+_SAMPLING_EPS = 1e-5
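+# Example (a sketch): greedy decoding capped at 64 new tokens
+#   params = SamplingParams(do_sample=False, max_new_tokens=64, stop_sequences=["\n\n"])
+#   params.verify()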
+
+
+class SamplingParams:
+ def __init__(
+ self,
+ do_sample: bool = False,
+ presence_penalty: float = 0.0,
+ frequency_penalty: float = 0.0,
+ temperature: float = 1.0,
+ top_p: float = 1.0,
+ top_k: int = -1, # -1 is for all
+ ignore_eos: bool = False,
+ max_new_tokens: int = 256,
+ stop_sequences: Optional[Union[str, List[str]]] = None, # conditions to stop generation
+ ) -> None:
+ self.do_sample = do_sample
+ self.presence_penalty = presence_penalty
+ self.frequency_penalty = frequency_penalty
+ self.temperature = temperature
+ self.top_p = top_p
+ self.top_k = top_k
+ self.ignore_eos = ignore_eos
+ self.max_new_tokens = max_new_tokens
+ self.stop_sequences = stop_sequences
+        if not self.do_sample:
+ self.temperature = 1.0
+ self.top_p = 1.0
+ self.top_k = 1
+ if (
+ self.temperature >= 0.0 and self.temperature < _SAMPLING_EPS
+        ):  # temperature is too low, fall back to greedy search
+ self.temperature = 1.0
+ self.top_k = 1
+ return
+
+ def verify(self):
+        if self.presence_penalty < 0.0:
+            raise ValueError(f"presence_penalty must be >= 0.0, got {self.presence_penalty}")
+        if self.frequency_penalty < 0.0:
+            raise ValueError(f"frequency_penalty must be >= 0.0, got {self.frequency_penalty}")
+        if self.temperature <= 0.0:
+            raise ValueError(f"temperature must be > 0.0, got {self.temperature}")
+        if self.top_p <= 0.0 or self.top_p > 1.0:
+            raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
+        if self.top_k < -1 or self.top_k == 0:
+            raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
+        if self.max_new_tokens < 1:
+            raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
+ return
+
+ def stop_sentences_to_token_ids(self, tokenizer):
+ if self.stop_sequences is None:
+ self.stop_sequences = []
+ else:
+ if isinstance(self.stop_sequences, str):
+ self.stop_sequences = [self.stop_sequences]
+ new_stop_sequences = []
+ for stop_str in self.stop_sequences:
+ stop_str_ids = tokenizer.encode(stop_str)
+ if stop_str_ids is not None and len(stop_str_ids) >= 1: # remove bos_token_id
+ stop_str_ids = stop_str_ids[1:]
+ if len(stop_str_ids) > 0:
+ new_stop_sequences.append(stop_str_ids)
+ self.stop_sequences = new_stop_sequences
+ return
+
+ def to_dict(self):
+ ret = {}
+ ret["do_sample"] = self.do_sample
+ ret["presence_penalty"] = self.presence_penalty
+ ret["frequency_penalty"] = self.frequency_penalty
+ ret["temperature"] = self.temperature
+ ret["top_p"] = self.top_p
+ ret["top_k"] = self.top_k
+ # if self.ignore_eos is not None:
+ # ret["ignore_eos"] = self.ignore_eos
+ return ret
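A brief usage sketch for `SamplingParams`; the import path is the one introduced by this patch, and the parameter values are illustrative:

```python
from colossalai.legacy.inference.dynamic_batching.sampling_params import SamplingParams

# sampled decoding: all knobs active
params = SamplingParams(do_sample=True, temperature=0.7, top_p=0.9, top_k=40, max_new_tokens=128)
params.verify()          # raises ValueError on out-of-range values
print(params.to_dict())  # dict form attached to each request

# greedy decoding: the constructor forces temperature=1.0, top_p=1.0, top_k=1
greedy = SamplingParams(do_sample=False)
```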
diff --git a/colossalai/legacy/inference/dynamic_batching/stats.py b/colossalai/legacy/inference/dynamic_batching/stats.py
new file mode 100644
index 000000000000..524072861a3f
--- /dev/null
+++ b/colossalai/legacy/inference/dynamic_batching/stats.py
@@ -0,0 +1,45 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+import time
+
+
+class Stats:
+ def __init__(self, log_status, log_stats_interval) -> None:
+ self.log_stats = log_status
+ self.log_stats_interval = log_stats_interval
+ self.last_log_time = time.time()
+ self.all_tokens = 0
+ self.output_tokens = 0
+ self.prompt_tokens = 0
+ return
+
+ def count_prompt_tokens(self, run_batch):
+ if self.log_stats:
+ tokens = run_batch.input_tokens()
+ self.prompt_tokens += tokens
+ self.all_tokens += tokens
+ return
+
+ def count_output_tokens(self, run_batch):
+ if self.log_stats:
+ tokens = len(run_batch.reqs)
+ self.output_tokens += tokens
+ self.all_tokens += tokens
+ return
+
+ def print_stats(self):
+ if not self.log_stats:
+ return
+
+ now = time.time()
+ if now - self.last_log_time > self.log_stats_interval:
+ print(
+ f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
+ f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
+ f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
+ )
+ self.all_tokens = 0
+ self.output_tokens = 0
+ self.prompt_tokens = 0
+ self.last_log_time = now
+ return
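A minimal sketch of how a batch manager drives `Stats`. The stub batch below is hypothetical and only mimics the two things `Stats` reads from a batch, the `reqs` list and `input_tokens()`:

```python
import time

from colossalai.legacy.inference.dynamic_batching.stats import Stats


class _StubBatch:
    """Stand-in for io_struct.Batch, exposing only what Stats touches."""

    def __init__(self, n_reqs, n_prompt_tokens):
        self.reqs = [None] * n_reqs
        self._prompt_tokens = n_prompt_tokens

    def input_tokens(self):
        return self._prompt_tokens


stats = Stats(log_status=True, log_stats_interval=1)
stats.count_prompt_tokens(_StubBatch(4, 512))    # prefill: counts all prompt tokens
for _ in range(10):
    stats.count_output_tokens(_StubBatch(4, 0))  # decode: one token per running request
time.sleep(1.1)
stats.print_stats()  # prints averaged throughput and resets the counters
```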
diff --git a/colossalai/legacy/inference/hybridengine/__init__.py b/colossalai/legacy/inference/hybridengine/__init__.py
new file mode 100644
index 000000000000..6377ef817301
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/__init__.py
@@ -0,0 +1,3 @@
+from .engine import CaiInferEngine
+
+__all__ = ["CaiInferEngine"]
diff --git a/colossalai/legacy/inference/hybridengine/engine.py b/colossalai/legacy/inference/hybridengine/engine.py
new file mode 100644
index 000000000000..bb0b4c77a2a7
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/engine.py
@@ -0,0 +1,170 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from transformers.tokenization_utils_base import BatchEncoding
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.schedule.generate import GenerateSchedule
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy
+
+from ..pipeline.microbatch_manager import MicroBatchManager
+from ..tensor_parallel.kvcache_manager import MemoryManager
+
+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+ "LlamaForCausalLM",
+]
+
+
+class CaiInferEngine:
+ """
+    CaiInferEngine is a class that handles pipeline-parallel inference.
+
+ Args:
+ tp_size (int): the size of tensor parallelism.
+ pp_size (int): the size of pipeline parallelism.
+        model (`nn.Module`): the original (non-pipelined) model, which will be sharded by `ShardFormer`.
+        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by ShardFormer to shard the model.
+        micro_batch_size (int): the micro batch size.
+        micro_batch_buffer_size (int): the buffer size for micro batches. Normally, it should equal the number of pipeline stages.
+ max_batch_size (int): the maximum batch size.
+ max_input_len (int): the maximum input length.
+ max_output_len (int): the maximum output length.
+
+ Example:
+
+ ```python
+    from colossalai.legacy.inference.hybridengine import CaiInferEngine
+    from colossalai.legacy.inference.hybridengine.polices import LlamaModelInferPolicy
+    import colossalai
+    from transformers import LlamaForCausalLM, LlamaTokenizer
+
+    colossalai.launch_from_torch(config={})
+
+    model = LlamaForCausalLM.from_pretrained("your_path_to_model")
+    tokenizer = LlamaTokenizer.from_pretrained("your_path_to_model")
+    # assume the model is inferred with 2 pipeline stages
+    inferengine = CaiInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy())
+
+    input = ["Introduce a landmark in China", "Introduce a landmark in China"]
+    data = tokenizer(input, return_tensors='pt')
+    output = inferengine.inference(data.to('cuda'))
+
+ ```
+
+ """
+
+ def __init__(
+ self,
+ tp_size: int = 1,
+ pp_size: int = 1,
+ dtype: str = "fp16",
+ model: nn.Module = None,
+ model_policy: Policy = None,
+ micro_batch_size: int = 1,
+ micro_batch_buffer_size: int = None,
+ max_batch_size: int = 4,
+ max_input_len: int = 32,
+ max_output_len: int = 32,
+ verbose: bool = False,
+        # TODO: implement early_stopping and various generation options
+ early_stopping: bool = False,
+ do_sample: bool = False,
+ num_beams: int = 1,
+ ) -> None:
+ assert model.__class__.__name__ in _supported_models, f"Model {model.__class__.__name__} is not supported."
+ assert (
+ tp_size * pp_size == dist.get_world_size()
+ ), f"TP size({tp_size}) * PP size({pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+ assert model and model_policy, "Model with model_policy should be provided."
+ assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+
+ assert max_batch_size <= 64, "Max batch size exceeds the constraint"
+ assert max_input_len + max_output_len <= 4096, "Max length exceeds the constraint"
+
+ # TODO: support only tensor parallel inference
+        assert pp_size > 1, "Tensor-parallel-only inference is not supported yet; pp_size must be greater than 1."
+ self.pp_size = pp_size
+ self.tp_size = tp_size
+
+ if dtype == "fp16":
+ self.dtype = torch.float16
+ model.half()
+ elif dtype == "bf16":
+ self.dtype = torch.bfloat16
+ model.to(torch.bfloat16)
+ else:
+ self.dtype = torch.float32
+
+ # Init pg mesh
+ pg_mesh = ProcessGroupMesh(pp_size, tp_size)
+
+ stage_manager = None
+ if pp_size > 1:
+ stage_manager = PipelineStageManager(pg_mesh, PP_AXIS, True)
+ self.cache_manager_list = [
+ self._init_manager(model, max_batch_size, max_input_len, max_output_len)
+ for _ in range(micro_batch_buffer_size or pp_size)
+ ]
+ self.mb_manager = MicroBatchManager(
+ stage_manager.stage,
+ micro_batch_size,
+ micro_batch_buffer_size or pp_size,
+ max_input_len,
+ max_output_len,
+ self.cache_manager_list,
+ )
+ self.verbose = verbose
+ self.schedule = GenerateSchedule(stage_manager, self.mb_manager, verbose)
+
+ self.model = self._shardformer(model, model_policy, stage_manager, pg_mesh.get_group_along_axis(TP_AXIS))
+
+ def inference(self, input_list):
+ """
+ Args:
+ input_list (list): a list of input data, each element is a `BatchEncoding` or `dict`.
+
+ Returns:
+ out (list): a list of output data, each element is a list of token.
+ timestamp (float): the time cost of the inference, only return when verbose is `True`.
+ """
+ assert isinstance(
+ input_list, (BatchEncoding, dict)
+ ), f"Only accept BatchEncoding or dict as input, but get {input_list.__class__.__name__}."
+ if isinstance(input_list, BatchEncoding):
+ input_list = input_list.data
+ out, timestamp = self.schedule.generate_step(self.model, iter([input_list]))
+ if self.verbose:
+ return out, timestamp
+ else:
+ return out
+
+ def _shardformer(self, model, model_policy, stage_manager, tp_group):
+ shardconfig = ShardConfig(
+ tensor_parallel_process_group=tp_group,
+ pipeline_stage_manager=stage_manager,
+ enable_tensor_parallelism=False,
+ enable_fused_normalization=False,
+ enable_all_optimization=False,
+ enable_flash_attention=False,
+ enable_jit_fused=False,
+ enable_sequence_parallelism=False,
+ )
+ shardformer = ShardFormer(shard_config=shardconfig)
+ shard_model, _ = shardformer.optimize(model, model_policy)
+ return shard_model.cuda()
+
+    def _init_manager(self, model, max_batch_size: int, max_input_len: int, max_output_len: int) -> MemoryManager:
+ max_total_token_num = max_batch_size * (max_input_len + max_output_len)
+ head_dim = model.config.hidden_size // model.config.num_attention_heads
+ head_num = model.config.num_attention_heads
+ num_hidden_layers = (
+ model.config.num_hidden_layers if hasattr(model.config, "num_hidden_layers") else model.config.num_layers
+ )
+ layer_num = num_hidden_layers // self.pp_size
+
+ cache_manager = MemoryManager(max_total_token_num, self.dtype, head_num, head_dim, layer_num)
+ return cache_manager
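`_init_manager` sizes the KV cache from the batch and length limits and splits the layers evenly across pipeline stages. A back-of-the-envelope sketch with llama-7b-like numbers (assumed values, not taken from the patch) shows the resulting memory footprint per stage:

```python
# engine defaults: max_batch_size=4, max_input_len=32, max_output_len=32
max_batch_size, max_input_len, max_output_len = 4, 32, 32
hidden_size, num_heads, num_layers, pp_size = 4096, 32, 32, 2

max_total_token_num = max_batch_size * (max_input_len + max_output_len)  # 256
head_dim = hidden_size // num_heads                                      # 128
layers_per_stage = num_layers // pp_size                                 # 16

# key + value buffers, fp16 = 2 bytes per element
kv_bytes = 2 * layers_per_stage * max_total_token_num * num_heads * head_dim * 2
print(f"{kv_bytes / 1024**2:.0f} MiB per stage")  # 64 MiB
```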
diff --git a/colossalai/legacy/inference/hybridengine/modeling/__init__.py b/colossalai/legacy/inference/hybridengine/modeling/__init__.py
new file mode 100644
index 000000000000..239bdebd7efd
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/modeling/__init__.py
@@ -0,0 +1,3 @@
+from .llama import LlamaInferenceForwards
+
+__all__ = ["LlamaInferenceForwards"]
diff --git a/colossalai/legacy/inference/hybridengine/modeling/_utils.py b/colossalai/legacy/inference/hybridengine/modeling/_utils.py
new file mode 100644
index 000000000000..068b64b4f829
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/modeling/_utils.py
@@ -0,0 +1,67 @@
+"""
+Utils for model inference
+"""
+import os
+
+import torch
+
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+
+
+def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ """
+ This function copies the key and value cache to the memory cache
+ Args:
+ layer_id : id of current layer
+ key_buffer : key cache
+ value_buffer : value cache
+ context_mem_index : index of memory cache in kv cache manager
+ mem_manager : cache manager
+ """
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+
+
+def init_to_get_rotary(self, base=10000, use_elem=False):
+ """
+ This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer
+ Args:
+ self : Model that holds the rotary positional embedding
+ base : calculation arg
+ use_elem : activated when using chatglm-based models
+ """
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
+
+ if ntk_alpha is not None:
+ ntk_alpha = float(ntk_alpha)
+ assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
+ if ntk_alpha > 1:
+ print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2)))  # base change formula
+
+ n_elem = self.config.head_dim_
+ if use_elem:
+ n_elem //= 2
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
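The NTK base change rescales the rotary base so that the lowest frequency slows down by a factor of `alpha` while the highest frequencies stay nearly intact, which is what extends the usable context. A standalone illustration, using a toy `head_dim` and `alpha` (assumptions, not values from the patch):

```python
import torch

head_dim, base, ntk_alpha = 128, 10000.0, 2.0

# base change formula, as in init_to_get_rotary
ntk_base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))

exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
inv_freq = 1.0 / (base**exponents)
inv_freq_ntk = 1.0 / (ntk_base**exponents)

print(inv_freq[0] / inv_freq_ntk[0])    # 1.0: highest frequency unchanged
print(inv_freq[-1] / inv_freq_ntk[-1])  # = alpha: lowest frequency slowed by alpha
```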
diff --git a/colossalai/legacy/inference/hybridengine/modeling/llama.py b/colossalai/legacy/inference/hybridengine/modeling/llama.py
new file mode 100644
index 000000000000..34474d115c8f
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/modeling/llama.py
@@ -0,0 +1,489 @@
+# This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
+import math
+from typing import List, Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
+from transformers.utils import logging
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
+from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from ._utils import copy_kv_to_mem_cache
+
+try:
+ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_llama2_context_attention_fwd,
+ )
+ from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_context_attention_fwd,
+ )
+ from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
+
+ HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+    print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
+    HAS_LIGHTLLM_KERNEL = False
+
+try:
+ from flash_attn import flash_attn_with_kvcache
+
+ HAS_FLASH_KERNEL = True
+except ImportError:
+    HAS_FLASH_KERNEL = False
+    print("please install flash attention from https://github.com/Dao-AILab/flash-attention")
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def llama_triton_context_attention(
+ query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
+):
+ if num_key_value_groups == 1:
+ if HAS_LIGHTLLM_KERNEL is False:
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ # infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ lightllm_context_attention_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ # infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
+ )
+ else:
+        assert HAS_LIGHTLLM_KERNEL, "You have to install lightllm kernels to run the llama2 model"
+ lightllm_llama2_context_attention_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ # infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
+ )
+
+
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
+    assert HAS_LIGHTLLM_KERNEL, "You have to install lightllm kernels to run token attention for llama models"
+ if num_key_value_groups == 1:
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ # infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ Llama2TokenAttentionForwards.token_attn(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ # infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+
+class LlamaInferenceForwards:
+ """
+    This class holds the forward functions for llama inference.
+    They replace the forward methods of LlamaModel, LlamaDecoderLayer, and LlamaAttention used by LlamaForCausalLM.
+ """
+
+ @staticmethod
+ def llama_causal_lm_forward(
+ self: LlamaForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+        # If this is the first stage and warmup has finished, go through lm_head first
+ if stage_manager.is_first_stage() and hidden_states is not None:
+ lm_logits = self.lm_head(hidden_states)
+ return {"logits": lm_logits}
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = LlamaInferenceForwards.llama_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ infer_state=infer_state,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ )
+
+ return outputs
+
+ @staticmethod
+ def llama_model_forward(
+ self: LlamaModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ infer_state: BatchInferState = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ # retrieve input_ids and inputs_embeds
+ if stage_manager is None or stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ assert stage_manager is not None
+ assert hidden_states is not None, f"hidden_state should not be none in stage {stage_manager.stage}"
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
+
+        # NOTE: the prefill and decode stages must be distinguished here;
+        # block_loc requires a different value-assigning method for the two stages
+ if use_cache and seq_length != 1:
+ # NOTE assume prefill stage
+ # allocate memory block
+ infer_state.is_context_stage = True # set prefill stage, notify attention layer
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
+ infer_state.is_context_stage = False
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+ else:
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.repeat(batch_size, 1)
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if infer_state.is_context_stage:
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+
+ else:
+ seq_len = infer_state.seq_len
+ infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=hidden_states.device
+ )
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
+
+ # decoder layers
+ infer_state.decode_layer_id = 0
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ if past_key_values is None:
+ past_key_values = tuple([None] * (end_idx - start_idx + 1))
+
+ for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
+ decoder_layer = self.layers[idx]
+ # NOTE: modify here for passing args to decoder layer
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+ infer_state.decode_layer_id += 1
+ hidden_states = layer_outputs[0]
+
+ if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+ hidden_states = self.norm(hidden_states)
+
+ # update indices
+ # infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+
+ # if not return_dict:
+ # return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ # return BaseModelOutputWithPast(
+ # last_hidden_state=hidden_states,
+ # past_key_values=next_cache,
+ # hidden_states=all_hidden_states,
+ # attentions=all_self_attns,
+ # )
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def llama_decoder_layer_forward(
+ self: LlamaDecoderLayer,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ infer_state=infer_state,
+ )
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+ @staticmethod
+ def llama_flash_attn_kvcache_forward(
+ self: LlamaAttention,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert use_cache is True, "use_cache must be True when using this llama attention forward"
+
+ bsz, q_len, _ = hidden_states.size()
+
+        # NOTE: we may want a better way to handle transposed k and v
+        # key_states            [bs, seq_len, num_heads, head_dim/embed_size_per_head]
+        # key_states_transposed [bs, num_heads, seq_len, head_dim/embed_size_per_head]
+
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+
+        # NOTE: this may need revising; we need some way to record the length of
+        # the past key-value cache, since we do not return past_key_value_cache right now
+
+ cos, sin = infer_state.position_cos, infer_state.position_sin
+
+ llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
+ llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+ value_states = value_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+ # copy key and value calculated in current step to memory manager
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+ attn_output = torch.empty_like(query_states)
+
+ llama_triton_context_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state,
+ num_key_value_groups=self.num_key_value_groups,
+ )
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+ # if decode is not contiguous, use triton kernel to copy key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim/embed_size_per_head]
+ copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+ if HAS_LIGHTLLM_KERNEL:
+ attn_output = torch.empty_like(query_states)
+ llama_triton_token_attention(
+ query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
+ )
+ else:
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
+
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+ copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+ copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+
+ attn_output = flash_attn_with_kvcache(
+ q=query_states,
+ k_cache=copy_cache_k,
+ v_cache=copy_cache_v,
+ softmax_scale=1 / math.sqrt(self.head_dim),
+ causal=True,
+ )
+
+ attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ # return past_key_value as None
+ return attn_output, None, None
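Since rotary embedding is a pure rotation of coordinate pairs, a quick property test of `rotate_half`/`apply_rotary_pos_emb` is that vector norms are preserved. A standalone check, building the cos/sin cache the way HF llama does (an assumption made for illustration; the import path is the one created by this patch):

```python
import torch

from colossalai.legacy.inference.hybridengine.modeling.llama import apply_rotary_pos_emb

bs, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(bs, heads, seq, dim)
k = torch.randn(bs, heads, seq, dim)

inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
freqs = torch.outer(torch.arange(seq).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # angle repeated for both halves
cos = emb.cos()[None, None, :, :]        # [1, 1, seq, dim]
sin = emb.sin()[None, None, :, :]
position_ids = torch.arange(seq)[None, :]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-5)
```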
diff --git a/colossalai/legacy/inference/hybridengine/polices/__init__.py b/colossalai/legacy/inference/hybridengine/polices/__init__.py
new file mode 100644
index 000000000000..7271812c5366
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/polices/__init__.py
@@ -0,0 +1,3 @@
+from .llama import LlamaModelInferPolicy
+
+__all__ = ["LlamaModelInferPolicy"]
diff --git a/colossalai/legacy/inference/hybridengine/polices/llama.py b/colossalai/legacy/inference/hybridengine/polices/llama.py
new file mode 100644
index 000000000000..992299714bd1
--- /dev/null
+++ b/colossalai/legacy/inference/hybridengine/polices/llama.py
@@ -0,0 +1,142 @@
+from functools import partial
+from typing import List
+
+import torch
+from torch.nn import Module
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaRMSNorm,
+)
+
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
+
+# import colossalai
+from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
+
+from ..modeling._utils import init_to_get_rotary
+from ..modeling.llama import LlamaInferenceForwards
+
+try:
+ from colossalai.kernel.triton import rmsnorm_forward
+
+ HAS_TRITON_RMSNORM = True
+except ImportError:
+    print("please install triton from https://github.com/openai/triton")
+    HAS_TRITON_RMSNORM = False
+
+
+def get_triton_rmsnorm_forward():
+ if HAS_TRITON_RMSNORM:
+
+ def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
+ return rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
+
+ return _triton_rmsnorm_forward
+ else:
+ return None
+
+
+class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.shard_config.inference_gptq:
+ from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
+
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+ policy[LlamaDecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=RowCaiQuantLinear,
+ kwargs={"split_num": 1},
+ ),
+ ],
+ )
+
+ self.shard_config._infer()
+
+ infer_forward = LlamaInferenceForwards.llama_model_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel)
+
+ infer_forward = LlamaInferenceForwards.llama_decoder_layer_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
+ )
+
+ infer_forward = LlamaInferenceForwards.llama_flash_attn_kvcache_forward
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaAttention
+ )
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=LlamaForCausalLM, new_forward=LlamaInferenceForwards.llama_causal_lm_forward, policy=policy
+ )
+ infer_forward = None
+ if HAS_TRITON_RMSNORM:
+ infer_forward = get_triton_rmsnorm_forward()
+
+ if infer_forward is not None:
+ method_replacement = {"forward": partial(infer_forward)}
+ self.append_or_create_method_replacement(
+ description=method_replacement, policy=policy, target_key=LlamaRMSNorm
+ )
+
+ return policy
+
+ def postprocess(self):
+ init_to_get_rotary(self.model.model)
+ return self.model
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_first_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
diff --git a/colossalai/legacy/inference/manager.py b/colossalai/legacy/inference/manager.py
new file mode 100644
index 000000000000..9672a50141a0
--- /dev/null
+++ b/colossalai/legacy/inference/manager.py
@@ -0,0 +1,296 @@
+# Adapted from https://github.com/ModelTC/lightllm
+
+import time
+from typing import List
+
+from .dynamic_batching.get_tokenizer import get_tokenizer
+from .dynamic_batching.infer_batch import InferBatch
+from .dynamic_batching.io_struct import Batch, Req
+from .dynamic_batching.req_queue import ReqQueue
+from .dynamic_batching.sampling_params import SamplingParams
+from .dynamic_batching.stats import Stats
+from .tensor_parallel import TPInferEngine
+
+
+class DynamicBatchManager:
+ def __init__(
+ self,
+ tp_engine: TPInferEngine,
+ max_total_token_num,
+ batch_max_tokens,
+ model,
+ tokenizer=None,
+ eos_id=None,
+ log_stats=True,
+ log_stats_interval=10,
+ running_batch: Batch = None,
+        waiting_req_list: List = None,  # use None instead of [] to avoid a shared mutable default
+ ):
+ """
+        Args:
+            tp_engine : the tp engine held by the dynamic batch manager, created beforehand
+            max_total_token_num : max_total_token_num for the memory manager, defaults to max batch size * (max input len + max output len)
+            batch_max_tokens : max tokens of one batch, defaults to (max input + output len) * num_requests
+            running_max_req_size : max request size of the running batch, equal to MAX_BATCH_SIZE of the tp engine
+            eos_id : the end token of a sequence
+            model : the model weight dir path; the app will load config, weights and tokenizer from this dir
+            log_stats : whether to log stats
+            log_stats_interval : log stats interval
+            running_batch : running batch
+            waiting_req_list : list of waiting requests, initialized before the dynamic batch manager
+ """
+ self.engine = tp_engine
+ self.max_total_token_num = max_total_token_num
+ running_max_req_size = self.engine.max_batch_size if self.engine is not None else 2
+        self.req_queue = ReqQueue(max_total_token_num, batch_max_tokens, running_max_req_size, waiting_req_list or [])
+ # all the inputs should be put into req_queue: waiting req list
+        assert max_total_token_num >= self.engine.max_batch_size * (
+            self.engine.max_input_len + self.engine.max_output_len
+        ), "max_total_token_num should be at least max_batch_size * (max_input_len + max_output_len)"
+        assert (
+            batch_max_tokens >= self.engine.max_input_len + self.engine.max_output_len
+        ), "batch_max_tokens should be at least (max_input_len + max_output_len)"
+ self.running_batch: Batch = running_batch
+ self.eos_id = eos_id
+ self.has_wait_tokens = 0
+ self.max_wait_tokens = 10
+ self.model = model
+
+ self.stats_tool = Stats(log_stats, log_stats_interval)
+ self.mem_usage_interval = log_stats_interval * 2
+ self.tokenizer = get_tokenizer(tokenizer_name=self.model) if tokenizer is None else tokenizer
+        if self.eos_id is None:
+ self.eos_id = self.tokenizer.eos_token_id
+
+ def add_req(self, request_id: str, prompt_ids: List[int], sampling_params: SamplingParams, prompts: str = ""):
+ """
+        Add a new request to the request queue; during initialization all requests are held in the waiting list.
+ """
+ sampling_params.max_new_tokens = (
+ self.engine.max_output_len
+ if sampling_params.max_new_tokens > self.engine.max_output_len
+ else sampling_params.max_new_tokens
+ )
+ req = Req(request_id, prompt_ids, sampling_params, prompts)
+ self.req_queue.append(req)
+ return
+
+ def add_input(self, request_id, prompts, sampling_params):
+ """
+        Encode and add a new input to the request queue. Only single-sequence input is supported for now.
+ """
+ prompt_ids = self.tokenizer.encode(prompts)
+ prompt_len = len(prompt_ids)
+ if prompt_len > self.engine.max_input_len:
+            raise ValueError(f"the input prompt length {prompt_len} exceeds the maximum input length {self.engine.max_input_len}")
+ sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+ self.add_req(request_id, prompt_ids, sampling_params, prompts)
+ return
+
+ def abort(self, request_id):
+ if self.running_batch is not None:
+ for req in self.running_batch.reqs:
+ if req.request_id == request_id:
+ req.has_generate_finished = True
+ req.aborted = True
+ for req in self.req_queue.waiting_req_list:
+ if req.request_id == request_id:
+ req.has_generate_finished = True
+ req.aborted = True
+ return
+
+ def loop_for_fwd(self):
+ """
+ The main loop for a dynamic batching process.
+ """
+ counter_count = 0
+ while self.running_batch is not None or self.req_queue.waiting_req_list:
+ yield from self._step()
+ counter_count += 1
+ if self.running_batch is not None:
+ if counter_count % self.mem_usage_interval == 0:
+ print(
+ "current batch size:",
+ len(self.running_batch.reqs),
+ "token used ratio:",
+ self.running_batch.calcu_used_tokens() / self.max_total_token_num,
+ )
+ self.stats_tool.print_stats()
+
+ if self.running_batch is None:
+                time.sleep(0.1)  # 100ms
+
+ def _step(self):
+ """
+ Logic for handling requests
+ """
+
+ if self.running_batch is None:
+ new_batch = self.req_queue.generate_new_batch(self.running_batch)
+ if new_batch is not None:
+ self.stats_tool.count_prompt_tokens(new_batch)
+ self.running_batch = new_batch
+ yield from self._prefill_batch(self.running_batch)
+                self._filter_running_batch()
+ self.has_wait_tokens = 0
+ return
+
+ if self.has_wait_tokens < self.max_wait_tokens:
+ self.stats_tool.count_output_tokens(self.running_batch)
+ yield from self._decode_batch(self.running_batch)
+            self._filter_running_batch()
+ self.has_wait_tokens += 1
+ return
+ else:
+ new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
+ if new_mini_batch is not None:
+ self.stats_tool.count_prompt_tokens(new_mini_batch)
+ yield from self._prefill_batch(new_mini_batch)
+ if not new_mini_batch.is_clear():
+ self._merge_batch(self.running_batch, new_mini_batch)
+ self.running_batch.merge(new_mini_batch)
+ self.has_wait_tokens = 0
+
+ else:
+ self.stats_tool.count_output_tokens(self.running_batch)
+ yield from self._decode_batch(self.running_batch)
+                self._filter_running_batch()
+ self.has_wait_tokens += 1
+
+ return
+
+ def _init_batch(self, batch: Batch, dtype="fp16"):
+ reqs = [r.to_rpc_obj() for r in batch.reqs]
+ batch_id = batch.batch_id
+
+ import torch
+
+ if dtype == "fp16":
+ dtype = torch.float16
+ else:
+ assert False, "error dtype"
+
+ batch_data = InferBatch.init_batch(
+ batch_id,
+ reqs,
+ dtype,
+ torch.cuda.current_device(),
+ self.engine.cache_manager,
+ self.engine.model.config.vocab_size,
+ self.engine.max_input_len + self.engine.max_output_len,
+ )
+ self.engine.cache[batch_id] = batch_data
+
+ def _prefill_batch(self, batch):
+ """
+        For every batch, whether it is a new batch or a mini batch, prefill must run first.
+ """
+ self._init_batch(batch)
+
+        # TODO: figure out if cache and batch id are needed
+ ans = self.engine._prefill_batch(batch.batch_id)
+ req_to_out_token_id = ans
+ self._add_token_id_to_req(batch, req_to_out_token_id)
+ has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
+ yield from self._handle_finish_req(batch, has_new_finished_req)
+
+ # delete finished reqs
+
+ def _decode_batch(self, batch: Batch):
+ """
+ Decoding process
+ """
+ ans = self.engine._decode_batch(batch.batch_id)
+ req_to_out_token_id = ans
+ self._add_token_id_to_req(batch, req_to_out_token_id)
+ has_new_finished_req = batch.mark_finished_req(self.eos_id, self.engine.max_output_len)
+ yield from self._handle_finish_req(batch, has_new_finished_req)
+
+ def _filter_batch(self, batch: Batch):
+ batch_id = batch.batch_id
+ req_id_list = [r.request_id for r in batch.reqs]
+ batch = self.engine.cache.pop(batch_id)
+ filter_batch = batch.filter(req_id_list)
+ del batch
+ self.engine.cache[batch_id] = filter_batch
+
+ def _merge_batch(self, batch1, batch2):
+ """
+ Merge new mini batch into running batch.
+ """
+ batch1 = self.engine.cache.pop(batch1.batch_id)
+ batch2 = self.engine.cache.pop(batch2.batch_id)
+
+ m_batch = InferBatch.merge(batch1, batch2)
+ self.engine.cache[batch1.batch_id] = m_batch
+ del batch1
+ del batch2
+
+ def _remove_batch(self, batch):
+ """
+ Remove finished batch.
+ """
+ batch = self.engine.cache.pop(batch.batch_id)
+ batch.free_self()
+ del batch
+
+ def _handle_finish_req(self, batch: Batch, has_new_finished_req):
+ if has_new_finished_req:
+ finished_reqs = batch.filter_finished()
+ if batch.is_clear():
+ self._remove_batch(batch)
+ else:
+ self._filter_batch(batch)
+ yield from self._output_process(finished_reqs)
+
+    def _filter_running_batch(self):
+ if self.running_batch is not None and self.running_batch.is_clear():
+ self.running_batch = None
+
+ def _add_token_id_to_req(self, batch: Batch, req_ans):
+ for req_id, (new_token_id, new_gen_metadata) in req_ans.items():
+ req = batch.id_to_reqs[req_id]
+ req.output_ids.append(new_token_id)
+ req.output_metadata_list.append(new_gen_metadata)
+ return
+
+ def _output_process(self, finished_reqs: List[Req]):
+ """
+ Process the output of a batch.
+ """
+ for req in finished_reqs:
+ output = self.tokenizer.decode(req.output_ids)
+ yield req.prompts + output
+
+ def clean_up(self):
+ # this logic should be implemented in the future.
+ pass
+
+ def generate(self, request_id, prompts, sampling_params):
+ """
+ Generate the output of a request.
+ """
+ self.add_input(request_id, prompts, sampling_params)
+ return self.loop_for_fwd()
+
+ def is_running(self):
+ return self.running_batch is not None or self.req_queue.waiting_req_list
+
+
+def start_dynamic_batching(args, tp_engine, waiting_req_list):
+    batch_manager = DynamicBatchManager(
+        tp_engine=tp_engine,
+        max_total_token_num=args.max_total_token_num,
+        batch_max_tokens=args.batch_max_tokens,
+        eos_id=args.eos_id,
+        model=args.model,
+        log_stats=not args.disable_log_stats,
+        log_stats_interval=args.log_stats_interval,
+        waiting_req_list=waiting_req_list,
+    )
+
+ return batch_manager
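A hypothetical end-to-end driver for `DynamicBatchManager`, assuming a `TPInferEngine` named `engine` has already been built and that its batch and length limits satisfy the constructor's assertions (the token budgets below are illustrative):

```python
from colossalai.legacy.inference.dynamic_batching.sampling_params import SamplingParams

manager = DynamicBatchManager(
    tp_engine=engine,        # an existing TPInferEngine (assumption)
    max_total_token_num=4096,
    batch_max_tokens=1024,
    model="/path/to/llama",  # dir holding config, weights and tokenizer
    log_stats=False,
)

params = SamplingParams(do_sample=False, max_new_tokens=64)
# generate() enqueues the request and returns the forward-loop generator,
# which yields the decoded prompt+output once the request finishes
for text in manager.generate("req-0", "Introduce a landmark in London", params):
    print(text)
```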
diff --git a/colossalai/inference/pipeline/README.md b/colossalai/legacy/inference/pipeline/README.md
similarity index 62%
rename from colossalai/inference/pipeline/README.md
rename to colossalai/legacy/inference/pipeline/README.md
index a90d5d6da182..f9bb35cc4d4c 100644
--- a/colossalai/inference/pipeline/README.md
+++ b/colossalai/legacy/inference/pipeline/README.md
@@ -17,7 +17,7 @@
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
- - Initialize the pipeline inference environment with `PipelineStageManager` and mdoel with `ShardFormer`.
+ - Initialize the pipeline inference environment with `PipelineStageManager` and model with `ShardFormer`.
- Run the pipeline inference model.
2. `MicroBatchManager` is a structure to manage the micro-batch information. It is responsible for the following tasks:
@@ -31,54 +31,53 @@ Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManag
### Example
```python
-from colossalai.pipeline import PPInferEngine
-# Suppose the pipeline size is 2, and use fp16 to do infenrence. Use Llama as an example.
-model = LlamaForCausalLM.from_pretrained('/path/to/model')
-inputs = tokenizer("Hello, my dog is cute", "What a good day", return_tensors="pt")
-engine = PPInferEngine(
- pp_size=2,
- dtype='fp16',
- micro_batch_size=1,
- new_length=10,
- model=model,
- model_policy=LlamaForCausalLMPipelinePolicy())
-
-output = engine.inference([inputs])
+from colossalai.inference import PPInferEngine
+from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
+import colossalai
+from transformers import LlamaForCausalLM, LlamaTokenizer
-```
+colossalai.launch_from_torch(config={})
+
+model = LlamaForCausalLM.from_pretrained("/path/to/model")
+tokenizer = LlamaTokenizer.from_pretrained("/path/to/model")
-### Quick start
-```shell
-cd benchmark
-sh run.sh
+# assume the model is inferred with 2 pipeline stages
+inferengine = PPInferEngine(pp_size=2, model=model, model_policy=LlamaModelInferPolicy(), new_length=32)
+
+input = ["Introduce a landmark in London","Introduce a landmark in Singapore"]
+data = tokenizer(input, return_tensors='pt')
+output = inferengine.inference(data.to('cuda'))
+print(tokenizer.batch_decode(output))
```
## Performance
-We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughputs` between `Pipeline Inference` and `hugging face` pipeline. The test environment is 2*A10, 20G.
+We conducted multiple benchmark tests to evaluate the performance. We compared the inference `latency` and `throughput` between `Pipeline Inference` and the `hugging face` pipeline. The test environment is 2 * A10, 20G / 2 * A800, 80G.
-### Llama Throughput(tokens/s)
+### Llama Throughput (tokens/s) | input length=1024, output length=128
-#### 7b, fp16
+#### A10 7b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(8) | 32(8) | 32(16)|
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
-| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
-| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
-| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
-| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
+| Pipeline Inference | 40.35 | 77.1 | 139.03 | 232.7 | 257.81 | OOM |
+| Hugging Face | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
-#### 7b, fp32
+#### A10 13b, fp16
| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
| :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
-| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
-| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
-| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
+| Pipeline Inference | 25.39 | 47.09 | 83.7 | 89.46 |
+| Hugging Face | 23.48 | 37.59 | 53.44 | OOM |
-#### 13b, fp16
-| batch_size(micro_batch size)| 2(1) | 4(2) | 8(4) | 16(4) |
-| :---: | :---: | :---: | :---: | :---: |
-| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
-| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
-| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
-| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |
+
+#### A800 7b, fp16
+| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference| 57.97 | 110.13 | 213.33 | 389.86 | 670.12 |
+| Hugging Face | 42.44 | 76.5 | 151.97 | 212.88 | 256.13 |
+
+
+#### A800 13b, fp16
+| batch_size(micro_batch size) | 2(1) | 4(2) | 8(4) | 16(8) | 32(16) |
+| :---: | :---: | :---: | :---: | :---: | :---: |
+| Pipeline Inference | 41.78 | 94.18 | 172.67| 310.75| 470.15 |
+| Hugging Face | 36.57 | 68.4 | 105.81 | 139.51 | 166.34 |
diff --git a/colossalai/legacy/inference/pipeline/__init__.py b/colossalai/legacy/inference/pipeline/__init__.py
new file mode 100644
index 000000000000..f43e4a847448
--- /dev/null
+++ b/colossalai/legacy/inference/pipeline/__init__.py
@@ -0,0 +1,3 @@
+from .microbatch_manager import MicroBatchManager
+
+__all__ = ["MicroBatchManager"]
diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py
similarity index 96%
rename from colossalai/inference/pipeline/benchmark/benchmark.py
rename to colossalai/legacy/inference/pipeline/benchmark/benchmark.py
index 9c47909f70f0..8392d0a1e579 100644
--- a/colossalai/inference/pipeline/benchmark/benchmark.py
+++ b/colossalai/legacy/inference/pipeline/benchmark/benchmark.py
@@ -7,7 +7,7 @@
import colossalai
from colossalai.inference import PPInferEngine
-from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
+from colossalai.inference.pipeline.policies import LlamaModelInferPolicy
GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024
@@ -117,8 +117,11 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
- model_policy=LlamaForCausalLMPipelinePolicy(),
+ model_policy=LlamaModelInferPolicy(),
verbose=True,
+ max_batch_size=args.mb_size,
+ max_input_len=args.seq_len,
+ max_output_len=args.seq_len + args.new_length + 256,
)
data = data_gen(args.batch_size, args.seq_len)
diff --git a/colossalai/inference/pipeline/benchmark/run.sh b/colossalai/legacy/inference/pipeline/benchmark/run.sh
similarity index 94%
rename from colossalai/inference/pipeline/benchmark/run.sh
rename to colossalai/legacy/inference/pipeline/benchmark/run.sh
index 7d8da858692f..e3c33bb88db1 100644
--- a/colossalai/inference/pipeline/benchmark/run.sh
+++ b/colossalai/legacy/inference/pipeline/benchmark/run.sh
@@ -1,7 +1,7 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"
-# 7b, fp32, 2 gpu, 1024, 128
+# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -13,7 +13,7 @@ for BATCH_SIZE in 2 4 8 16; do
--pp_size=2
done
-# 7b, fp32, 2 gpu, 512, 512
+# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
@@ -25,7 +25,7 @@ for BATCH_SIZE in 2 4 8 16 32; do
--pp_size=2
done
-# 7b, fp32, 2 gpu, 1024, 128
+# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
diff --git a/colossalai/legacy/inference/pipeline/microbatch_manager.py b/colossalai/legacy/inference/pipeline/microbatch_manager.py
new file mode 100644
index 000000000000..441cf603985c
--- /dev/null
+++ b/colossalai/legacy/inference/pipeline/microbatch_manager.py
@@ -0,0 +1,249 @@
+from enum import Enum
+from typing import Dict
+
+import torch
+
+from ..tensor_parallel.batch_infer_state import BatchInferState
+from ..tensor_parallel.kvcache_manager import MemoryManager
+
+__all__ = ["MicroBatchManager"]
+
+
+class Status(Enum):
+ PREFILL = 1
+ GENERATE = 2
+ DONE = 3
+ COOLDOWN = 4
+
+
+class MicroBatchDescription:
+ """
+    This is the class to record the information of each microbatch, and also to do some update operations.
+    It is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
+    details, please refer to the docs of these two classes below.
+
+ Args:
+        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
+        max_input_len (int): the maximum input length.
+        max_output_len (int): the maximum output length.
+        cache_manager (MemoryManager): the cache manager holding the KV cache for this micro batch.
+ """
+
+ def __init__(
+ self,
+ inputs_dict: Dict[str, torch.Tensor],
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
+ ) -> None:
+ self.mb_length = inputs_dict["input_ids"].shape[-1]
+ self.target_length = self.mb_length + max_output_len
+ self.infer_state = BatchInferState.init_from_batch(
+ batch=inputs_dict, max_input_len=max_input_len, max_output_len=max_output_len, cache_manager=cache_manager
+ )
+ # print(f"[init] {inputs_dict}, {max_input_len}, {max_output_len}, {cache_manager}, {self.infer_state}")
+
+ def update(self, *args, **kwargs):
+ pass
+
+ @property
+ def state(self):
+ """
+        Return the state of the current micro batch: DONE when the current length reaches
+        the target length, COOLDOWN one step before that, otherwise GENERATE.
+
+ """
+ # TODO: add the condition for early stopping
+ if self.cur_length == self.target_length:
+ return Status.DONE
+ elif self.cur_length == self.target_length - 1:
+ return Status.COOLDOWN
+ else:
+ return Status.GENERATE
+
+ @property
+ def cur_length(self):
+ """
+        Return the current sequence length of the micro batch. Subclasses must override this.
+        """
+        raise NotImplementedError
+
+
+class HeadMicroBatchDescription(MicroBatchDescription):
+ """
+    This class is used to record the information of the first stage of the pipeline. The first stage should have attributes `input_ids`, `attention_mask`
+    and `new_tokens`, where `new_tokens` holds the tokens generated by the first stage. Due to the schedule of the pipeline, the operations to update the
+    information and the condition used to determine the state differ from those of the other stages.
+
+ Args:
+        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
+        max_input_len (int): the maximum input length.
+        max_output_len (int): the maximum output length.
+        cache_manager (MemoryManager): the cache manager holding the KV cache for this micro batch.
+
+ """
+
+ def __init__(
+ self,
+ inputs_dict: Dict[str, torch.Tensor],
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
+ ) -> None:
+ super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
+ assert inputs_dict is not None
+ assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
+ self.input_ids = inputs_dict["input_ids"]
+ self.attn_mask = inputs_dict["attention_mask"]
+ self.new_tokens = None
+
+ def update(self, new_token: torch.Tensor = None):
+ if new_token is not None:
+ self._update_newtokens(new_token)
+ if self.state is not Status.DONE and new_token is not None:
+ self._update_attnmask()
+
+ def _update_newtokens(self, new_token: torch.Tensor):
+ if self.new_tokens is None:
+ self.new_tokens = new_token
+ else:
+ self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
+
+ def _update_attnmask(self):
+ self.attn_mask = torch.cat(
+ (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
+ )
+
+ @property
+ def cur_length(self):
+ """
+ When there is no new_token, the length is mb_length, otherwise the sequence length is `mb_length` plus the length of new_token
+
+ """
+ if self.new_tokens is None:
+ return self.mb_length
+ else:
+ return self.mb_length + len(self.new_tokens[0])
+
+
+class BodyMicroBatchDescription(MicroBatchDescription):
+ """
+ This class is used to record the infomation of the stages except the first stage of pipeline, the stages should have attributes `hidden_states` and `past_key_values`,
+
+ Args:
+ inputs_dict (Dict[str, torch.Tensor]): will always be `None`. Other stages only receive hiddenstates from previous stage.
+ """
+
+ def __init__(
+ self,
+ inputs_dict: Dict[str, torch.Tensor],
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
+ ) -> None:
+ super().__init__(inputs_dict, max_input_len, max_output_len, cache_manager)
+
+ @property
+ def cur_length(self):
+ """
+ When there is no kv_cache, the length is mb_length, otherwise the sequence length is `kv_cache[0][0].shape[-2]` plus 1
+
+ """
+ return self.infer_state.seq_len.max().item()
+
+
+class MicroBatchManager:
+ """
+ MicroBatchManager is a class that manages the micro batch.
+
+ Args:
+ stage (int): stage id of current stage.
+ micro_batch_size (int): the micro batch size.
+ micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
+
+ """
+
+ def __init__(
+ self,
+ stage: int,
+ micro_batch_size: int,
+ micro_batch_buffer_size: int,
+ max_input_len: int,
+ max_output_len: int,
+        cache_manager_list: List[MemoryManager],
+ ):
+ self.stage = stage
+ self.micro_batch_size = micro_batch_size
+ self.buffer_size = micro_batch_buffer_size
+ self.max_input_len = max_input_len
+ self.max_output_len = max_output_len
+ self.cache_manager_list = cache_manager_list
+ self.mb_descrption_buffer = {}
+ self.new_tokens_buffer = {}
+ self.idx = 0
+
+ def add_descrption(self, inputs_dict: Dict[str, torch.Tensor]):
+ if self.stage == 0:
+ self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(
+ inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
+ )
+ else:
+ self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(
+ inputs_dict, self.max_input_len, self.max_output_len, self.cache_manager_list[self.idx]
+ )
+
+ def step(self, new_token: torch.Tensor = None):
+ """
+ Update the state if microbatch manager, 2 conditions.
+ 1. For first stage in PREFILL, receive inputs and outputs, `_add_descrption` will save its inputs.
+ 2. For other conditon, only receive the output of previous stage, and update the descrption.
+
+ Args:
+ inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
+ output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
+ new_token (torch.Tensor): the new token generated by current stage.
+ """
+ # Add descrption first if the descrption is None
+ self.cur_descrption.update(new_token)
+ return self.cur_state
+
+ def export_new_tokens(self):
+ new_tokens_list = []
+ for i in self.mb_descrption_buffer.values():
+ new_tokens_list.extend(i.new_tokens.tolist())
+ return new_tokens_list
+
+ def is_micro_batch_done(self):
+ if len(self.mb_descrption_buffer) == 0:
+ return False
+ for mb in self.mb_descrption_buffer.values():
+ if mb.state != Status.DONE:
+ return False
+ return True
+
+ def clear(self):
+ self.mb_descrption_buffer.clear()
+ for cache in self.cache_manager_list:
+ cache.free_all()
+
+ def next(self):
+ self.idx = (self.idx + 1) % self.buffer_size
+
+ def _remove_descrption(self):
+ self.mb_descrption_buffer.pop(self.idx)
+
+ @property
+ def cur_descrption(self) -> MicroBatchDescription:
+ return self.mb_descrption_buffer.get(self.idx)
+
+ @property
+ def cur_infer_state(self):
+ if self.cur_descrption is None:
+ return None
+ return self.cur_descrption.infer_state
+
+ @property
+ def cur_state(self):
+ """
+ Return the state of current micro batch, when current descrption is None, the state is PREFILL
+
+ """
+ if self.cur_descrption is None:
+ return Status.PREFILL
+ return self.cur_descrption.state
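+
+
+# A minimal usage sketch (illustrative only; the tensors and the MemoryManager arguments below
+# are hypothetical placeholders chosen for the example, not values required by this module):
+#
+#     cache_managers = [MemoryManager(2048, torch.float16, 32, 128, 32) for _ in range(2)]
+#     mb_manager = MicroBatchManager(
+#         stage=0, micro_batch_size=1, micro_batch_buffer_size=2,
+#         max_input_len=1024, max_output_len=128, cache_manager_list=cache_managers,
+#     )
+#     assert mb_manager.cur_state is Status.PREFILL     # no description registered yet
+#     mb_manager.add_descrption({"input_ids": input_ids, "attention_mask": attention_mask})
+#     state = mb_manager.step(new_token)                # GENERATE until the target length is reached
+#     mb_manager.next()                                 # move on to the next buffered micro batch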
diff --git a/colossalai/legacy/inference/quant/gptq/__init__.py b/colossalai/legacy/inference/quant/gptq/__init__.py
new file mode 100644
index 000000000000..c035f397923a
--- /dev/null
+++ b/colossalai/legacy/inference/quant/gptq/__init__.py
@@ -0,0 +1,4 @@
+from .cai_gptq import HAS_AUTO_GPTQ
+
+if HAS_AUTO_GPTQ:
+ from .cai_gptq import CaiGPTQLinearOp, CaiQuantLinear
diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py
new file mode 100644
index 000000000000..4ed76293bd81
--- /dev/null
+++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/__init__.py
@@ -0,0 +1,14 @@
+import warnings
+
+HAS_AUTO_GPTQ = False
+try:
+ import auto_gptq
+
+ HAS_AUTO_GPTQ = True
+except ImportError:
+ warnings.warn("please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ")
+ HAS_AUTO_GPTQ = False
+
+if HAS_AUTO_GPTQ:
+ from .cai_quant_linear import CaiQuantLinear, ColCaiQuantLinear, RowCaiQuantLinear
+ from .gptq_op import CaiGPTQLinearOp
diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py
new file mode 100644
index 000000000000..36339ac88486
--- /dev/null
+++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,354 @@
+# Adapted from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ
+
+import math
+import warnings
+from typing import List, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.lazy import LazyInitContext
+from colossalai.shardformer.layer import ParallelModule
+
+from .gptq_op import CaiGPTQLinearOp
+
+HAS_GPTQ_CUDA = False
+try:
+ from colossalai.kernel.op_builder.gptq import GPTQBuilder
+
+ gptq_cuda = GPTQBuilder().load()
+ HAS_GPTQ_CUDA = True
+except ImportError:
+ warnings.warn("CUDA gptq is not installed")
+ HAS_GPTQ_CUDA = False
+
+
+class CaiQuantLinear(nn.Module):
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+ super().__init__()
+ if bits not in [2, 4, 8]:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ self.infeatures = infeatures
+ self.outfeatures = outfeatures
+ self.bits = bits
+ self.maxq = 2**self.bits - 1
+ self.groupsize = groupsize if groupsize != -1 else infeatures
+
+ self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+ self.register_buffer(
+ "qzeros",
+ torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
+ )
+ self.register_buffer(
+ "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+ )
+ if row_split:
+ self.register_buffer(
+ "g_idx",
+ torch.tensor(
+ [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
+ ),
+ )
+ else:
+ self.register_buffer(
+ "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+ )
+
+ if bias:
+ self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
+ else:
+ self.bias = None
+
+ self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
+
+ self.q4 = None
+ self.empty_tensor = torch.empty((1, 1), device="meta")
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+ self.row_split = row_split
+
+ def pack(self, linear, scales, zeros, g_idx=None):
+ g_idx = (
+ g_idx.clone()
+ if g_idx is not None
+ else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+ )
+
+ scales = scales.t().contiguous()
+ zeros = zeros.t().contiguous()
+ scale_zeros = zeros * scales
+ half_scales = scales.clone().half()
+ # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
+ self.scales = scales.clone().half()
+ if linear.bias is not None:
+ self.bias = linear.bias.clone().half()
+
+ pbits = 32
+ ptype = torch.int32
+ unsign_type = np.uint32
+ sign_type = np.int32
+
+ intweight = []
+ for idx in range(self.infeatures):
+ intweight.append(
+ torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
+ :, None
+ ]
+ )
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.numpy().astype(unsign_type)
+ qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)
+
+ i = 0
+ row = 0
+
+ while row < qweight.shape[0]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ row += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+ qweight = qweight.astype(sign_type)
+ qweight1 = torch.from_numpy(qweight)
+ qweight1 = qweight1.contiguous() # .to("cuda")
+ self.qweight.data.copy_(qweight1)
+
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
+ zeros -= 1
+ zeros = zeros.numpy().astype(unsign_type)
+ i = 0
+ col = 0
+ while col < qzeros.shape[1]:
+ if self.bits in [2, 4, 8]:
+ for j in range(i, i + (pbits // self.bits)):
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
+ i += pbits // self.bits
+ col += 1
+ else:
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
+        qzeros = qzeros.astype(sign_type)
+        qzeros = torch.from_numpy(qzeros)
+        self.qzeros.data.copy_(qzeros)
+
+ if torch.equal(self.g_idx.to(g_idx.device), g_idx):
+ self.g_idx = None
+ else:
+ self.g_idx = g_idx
+
+ def init_q4(self):
+ assert self.qweight.device.type == "cuda"
+ self.q4_width = self.qweight.shape[1]
+ if self.g_idx is not None:
+ if self.row_split and torch.equal(
+ self.g_idx,
+ torch.tensor(
+ [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
+ dtype=torch.int32,
+ device=self.g_idx.device,
+ ),
+ ):
+ self.g_idx = None
+ elif torch.equal(
+ self.g_idx,
+ torch.tensor(
+ [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
+ ),
+ ):
+ self.g_idx = None
+
+ if self.g_idx is not None:
+ g_idx = self.g_idx.to("cpu")
+ else:
+ g_idx = self.empty_tensor
+
+ self.q4 = gptq_cuda.make_q4(self.qweight, self.qzeros, self.scales, g_idx, torch.cuda.current_device())
+ torch.cuda.synchronize()
+
+ def forward(self, x):
+ outshape = x.shape[:-1] + (self.outfeatures,)
+
+ if HAS_GPTQ_CUDA and self.bits == 4:
+ if self.q4 is None:
+ self.init_q4()
+
+ x = x.view(-1, x.shape[-1])
+ output = torch.empty((x.shape[0], self.outfeatures), dtype=torch.float16, device=x.device)
+ gptq_cuda.q4_matmul(x.half(), self.q4, output)
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ output.add_(self.bias)
+ else:
+ if self.bias is not None and (not self.row_split or self.tp_size == 1):
+ bias = self.bias
+ else:
+ bias = None
+ output = self.gptq_linear(
+ x,
+ self.qweight,
+ self.scales,
+ self.qzeros,
+ g_idx=self.g_idx,
+ bias=bias,
+ )
+ return output.view(outshape)
+
+
+def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
+ qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
+ scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
+ g_idx = gptq_linear.g_idx
+ if gptq_linear.bias is not None:
+ bias = gptq_linear.bias.split(gptq_linear.out_features // split_num, dim=-1)
+
+ cai_split_out_features = cai_linear.outfeatures // split_num
+ zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num
+
+ for i in range(split_num):
+ cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
+ :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+ ]
+ cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
+ :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
+ ]
+ cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
+ :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+ ]
+ if cai_linear.bias is not None:
+ cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
+ tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+ ]
+
+ cai_linear.g_idx.copy_(g_idx)
+
+
+def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
+ qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
+ qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
+ scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
+ g_idxs = gptq_linear.g_idx.split(gptq_linear.in_features // split_num, dim=0)
+
+ cai_split_in_features = cai_linear.infeatures // (32 // cai_linear.bits) // split_num
+ zero_split_block = cai_linear.infeatures // cai_linear.groupsize // split_num
+ idx_split_features = cai_linear.infeatures // split_num
+
+ for i in range(split_num):
+ cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
+ tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
+ ]
+ cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
+ tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+ ]
+ cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
+ tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+ ]
+ cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
+ tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
+ ]
+ if cai_linear.bias is not None:
+ cai_linear.bias.copy_(gptq_linear.bias)
+
+
+class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+ super().__init__(
+ bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+ )
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features ({in_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
+            )
+ linear_1d = RowCaiQuantLinear(
+ module.bits,
+ module.group_size,
+ module.in_features // tp_size,
+ module.out_features,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ row_split=True,
+ )
+ linear_1d.process_group = process_group
+
+ split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
+
+ def forward(self, x):
+ output = super().forward(x)
+ if self.tp_size > 1:
+ dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
+ if self.bias is not None:
+ output.add_(self.bias)
+ return output
+
+
+class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
+ super().__init__(
+ bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+ )
+ self.process_group = None
+
+ @staticmethod
+ def from_native_module(
+ module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+ ) -> ParallelModule:
+ LazyInitContext.materialize(module)
+ # get the attributes
+ in_features = module.in_features
+
+ # ensure only one process group is passed
+ if isinstance(process_group, (list, tuple)):
+ assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
+ process_group = process_group[0]
+
+ tp_size = dist.get_world_size(process_group)
+ tp_rank = dist.get_rank(process_group)
+
+ if in_features < tp_size:
+ return module
+
+ if in_features % tp_size != 0:
+            raise ValueError(
+                f"The size of in_features ({in_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
+            )
+ linear_1d = ColCaiQuantLinear(
+ module.bits,
+ module.group_size,
+ module.in_features,
+ module.out_features // tp_size,
+ module.bias is not None,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ )
+ linear_1d.process_group = process_group
+
+ split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
+ return linear_1d
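+
+
+# A minimal usage sketch (illustrative only): sharding an AutoGPTQ-style quantized linear across
+# a tensor-parallel group. `gptq_linear` is a hypothetical module exposing the attributes that
+# `from_native_module` reads (`bits`, `group_size`, `in_features`, `out_features`, `bias`,
+# `qweight`, `qzeros`, `scales`, `g_idx`):
+#
+#     import torch.distributed as dist
+#     tp_group = dist.new_group(list(range(dist.get_world_size())))
+#     col_linear = ColCaiQuantLinear.from_native_module(gptq_linear, tp_group)  # splits out_features
+#     row_linear = RowCaiQuantLinear.from_native_module(gptq_linear, tp_group)  # splits in_features
+#     y = row_linear(col_linear(x.half().cuda()))  # RowCaiQuantLinear all-reduces across the group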
diff --git a/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py b/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py
new file mode 100644
index 000000000000..a8902eb35cd0
--- /dev/null
+++ b/colossalai/legacy/inference/quant/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,58 @@
+import torch
+
+from colossalai.kernel.triton import gptq_fused_linear_triton
+
+
+class CaiGPTQLinearOp(torch.nn.Module):
+ def __init__(self, gptq_group_size, gptq_quant_bits):
+ super(CaiGPTQLinearOp, self).__init__()
+ self.group_size = gptq_group_size
+ self.bits = gptq_quant_bits
+ self.maxq = 2**self.bits - 1
+ self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scales: torch.Tensor,
+ weight_zeros: torch.Tensor,
+ g_idx: torch.Tensor = None,
+ act_type=0,
+ bias: torch.Tensor = None,
+ residual: torch.Tensor = None,
+ qkv_fused=False,
+ ):
+ add_bias = True
+ if bias is None:
+ bias = self.empty_tensor
+ add_bias = False
+
+ add_residual = True
+ if residual is None:
+ residual = self.empty_tensor
+ add_residual = False
+ x = input.view(-1, input.shape[-1])
+
+ out = gptq_fused_linear_triton(
+ x,
+ weight,
+ weight_scales,
+ weight_zeros,
+ bias,
+ residual,
+ self.bits,
+ self.maxq,
+ self.group_size,
+ qkv_fused,
+ add_bias,
+ add_residual,
+ act_type=act_type,
+ g_idx=g_idx,
+ )
+ if qkv_fused:
+ out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
+ else:
+ out = out.view(input.shape[0], input.shape[1], weight.shape[-1])
+
+ return out
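+
+
+# A minimal usage sketch (illustrative only; `qweight`, `scales`, `qzeros` and `g_idx` are assumed
+# to be GPTQ-packed tensors in the layout produced by `CaiQuantLinear.pack`):
+#
+#     op = CaiGPTQLinearOp(gptq_group_size=128, gptq_quant_bits=4)
+#     y = op(x, qweight, scales, qzeros, g_idx=g_idx, bias=bias)  # x: [bsz, seq_len, in_features]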
diff --git a/colossalai/legacy/inference/quant/smoothquant/__init__.py b/colossalai/legacy/inference/quant/smoothquant/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/legacy/inference/quant/smoothquant/models/__init__.py b/colossalai/legacy/inference/quant/smoothquant/models/__init__.py
new file mode 100644
index 000000000000..77541d8610c5
--- /dev/null
+++ b/colossalai/legacy/inference/quant/smoothquant/models/__init__.py
@@ -0,0 +1,12 @@
+try:
+    import torch_int
+
+    HAS_TORCH_INT = True
+except ImportError:
+    HAS_TORCH_INT = False
+    raise ImportError(
+        "torch_int is not installed. Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int"
+    )
+
+if HAS_TORCH_INT:
+ from .llama import LLamaSmoothquantAttention, LlamaSmoothquantMLP
diff --git a/colossalai/legacy/inference/quant/smoothquant/models/base_model.py b/colossalai/legacy/inference/quant/smoothquant/models/base_model.py
new file mode 100644
index 000000000000..9554be9ea96b
--- /dev/null
+++ b/colossalai/legacy/inference/quant/smoothquant/models/base_model.py
@@ -0,0 +1,487 @@
+# Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ
+# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
+# Adapted from smoothquant: https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
+
+import os
+import warnings
+from abc import abstractmethod
+from functools import partial
+from os.path import isdir, isfile, join
+from typing import Dict, List, Optional, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn as nn
+import transformers
+from safetensors.torch import save_file as safe_save
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from transformers.modeling_utils import no_init_weights
+from transformers.utils.generic import ContextManagers
+from transformers.utils.hub import PushToHubMixin, cached_file
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.inference.tensor_parallel.kvcache_manager import MemoryManager
+
+SUPPORTED_MODELS = ["llama"]
+
+
+class BaseSmoothForCausalLM(nn.Module, PushToHubMixin):
+ layer_type: str = None
+
+ def __init__(self, model: PreTrainedModel, quantized: bool = False):
+ super().__init__()
+
+ self.model = model
+ self.model_type = self.model.config.model_type
+ self._quantized = quantized
+ self.config = self.model.config
+ self.cache_manager = None
+ self.max_total_token_num = 0
+
+ @property
+ def quantized(self):
+ return self._quantized
+
+ def init_cache_manager(self, max_total_token_num=2048):
+ if self.config.model_type == "llama":
+ head_num = self.config.num_key_value_heads
+ layer_num = self.config.num_hidden_layers
+ head_dim = self.config.hidden_size // head_num
+
+ self.cache_manager = MemoryManager(max_total_token_num, torch.int8, head_num, head_dim, layer_num)
+ self.max_total_token_num = max_total_token_num
+
+ def init_batch_state(self, max_output_len=256, **kwargs):
+ input_ids = kwargs["input_ids"]
+ batch_size = len(input_ids)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ start_index = 0
+ max_len_in_batch = -1
+
+ for i in range(batch_size):
+ seq_len = len(input_ids[i])
+ seq_lengths[i] = seq_len
+ seq_start_indexes[i] = start_index
+ start_index += seq_len
+ max_len_in_batch = seq_len if seq_len > max_len_in_batch else max_len_in_batch
+
+ if "max_total_token_num" in kwargs.keys():
+ max_total_token_num = kwargs["max_total_token_num"]
+ self.init_cache_manager(max_total_token_num)
+
+ if "max_new_tokens" in kwargs.keys():
+ max_output_len = kwargs["max_new_tokens"]
+
+ if batch_size * (max_len_in_batch + max_output_len) > self.max_total_token_num:
+ max_total_token_num = batch_size * (max_len_in_batch + max_output_len)
+ warnings.warn(f"reset max tokens to {max_total_token_num}")
+ self.init_cache_manager(max_total_token_num)
+
+ block_loc = torch.empty((batch_size, max_len_in_batch + max_output_len), dtype=torch.long, device="cuda")
+ batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
+ batch_infer_state.seq_len = seq_lengths.to("cuda")
+ batch_infer_state.start_loc = seq_start_indexes.to("cuda")
+ batch_infer_state.block_loc = block_loc
+ batch_infer_state.decode_layer_id = 0
+ batch_infer_state.is_context_stage = True
+ batch_infer_state.set_cache_manager(self.cache_manager)
+ batch_infer_state.cache_manager.free_all()
+ return batch_infer_state
+
+ @abstractmethod
+ @torch.inference_mode()
+ def quantize(
+ self,
+ examples: List[Dict[str, Union[List[int], torch.LongTensor]]],
+ ):
+ if self.quantized:
+            raise EnvironmentError("can't execute quantize because the model is already quantized.")
+
+ def forward(self, *args, **kwargs):
+ return self.model(*args, **kwargs)
+
+ def generate(self, **kwargs):
+ """shortcut for model.generate"""
+
+ batch_infer_state = self.init_batch_state(**kwargs)
+ if self.config.model_type == "llama":
+ setattr(self.model.model, "infer_state", batch_infer_state)
+
+ with torch.inference_mode():
+ return self.model.generate(**kwargs)
+
+ def prepare_inputs_for_generation(self, *args, **kwargs):
+ """shortcut for model.prepare_inputs_for_generation"""
+ return self.model.prepare_inputs_for_generation(*args, **kwargs)
+
+ def collect_act_scales(self, model, tokenizer, dataset, device, num_samples=512, seq_len=512):
+ for text in tqdm(dataset):
+ input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+ model(input_ids)
+
+ def collect_act_dict(self, model, tokenizer, dataset, act_dict, device, num_samples=512, seq_len=512):
+ pbar = tqdm(dataset)
+ for text in pbar:
+ input_ids = tokenizer(text, return_tensors="pt", max_length=seq_len, truncation=True).input_ids.to(device)
+ model(input_ids)
+ mean_scale = np.mean([v["input"] for v in act_dict.values()])
+ pbar.set_description(f"Mean input scale: {mean_scale:.2f}")
+
+    # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
+ def get_act_scales(self, model, tokenizer, dataset, num_samples=512, seq_len=512):
+ model.eval()
+ device = next(model.parameters()).device
+ act_scales = {}
+
+        def stat_tensor(name, tensor):
+            hidden_dim = tensor.shape[-1]
+            tensor = tensor.view(-1, hidden_dim).abs().detach()
+            coming_max = torch.max(tensor, dim=0)[0].float().cpu()
+            if name in act_scales:
+                act_scales[name] = torch.max(act_scales[name], coming_max)
+            else:
+                act_scales[name] = coming_max
+
+ def stat_input_hook(m, x, y, name):
+ if isinstance(x, tuple):
+ x = x[0]
+ stat_tensor(name, x)
+
+ hooks = []
+ for name, m in model.named_modules():
+ if isinstance(m, nn.Linear):
+ hooks.append(m.register_forward_hook(partial(stat_input_hook, name=name)))
+
+ self.collect_act_scales(model, tokenizer, dataset, device, num_samples, seq_len)
+
+ for h in hooks:
+ h.remove()
+
+ return act_scales
+
+ # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/smooth.py
+ @torch.no_grad()
+ def smooth_ln_fcs(self, ln, fcs, act_scales, alpha=0.5):
+ if not isinstance(fcs, list):
+ fcs = [fcs]
+ for fc in fcs:
+ assert isinstance(fc, nn.Linear)
+ assert ln.weight.numel() == fc.in_features == act_scales.numel()
+
+ device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
+ act_scales = act_scales.to(device=device, dtype=dtype)
+ weight_scales = torch.cat([fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
+ weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
+
+ scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
+
+ ln.weight.div_(scales)
+ if hasattr(ln, "bias"):
+ ln.bias.div_(scales)
+
+ for fc in fcs:
+ fc.weight.mul_(scales.view(1, -1))
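+
+        # Why this preserves the model output: for each channel j,
+        #     (ln_out[..., j] / scales[j]) * (W[:, j] * scales[j]) == ln_out[..., j] * W[:, j],
+        # so dividing the LayerNorm weight (and bias) by `scales` while multiplying the fc input
+        # channels by `scales` leaves fc(ln(x)) unchanged while shrinking activation outliers.
+        # `alpha` controls how much of the quantization difficulty migrates to the weights.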
+
+    @classmethod
+    def create_quantized_model(cls, model):
+        raise NotImplementedError("create_quantized_model is not implemented for this model type.")
+
+ # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
+ def save_quantized(
+ self,
+ save_dir: str,
+ model_basename: str,
+ use_safetensors: bool = False,
+ safetensors_metadata: Optional[Dict[str, str]] = None,
+ ):
+ """save quantized model and configs to local disk"""
+ os.makedirs(save_dir, exist_ok=True)
+
+        if not self.quantized:
+            raise EnvironmentError("can only save a quantized model; please call .quantize() first.")
+
+ self.model.to("cpu")
+
+ model_base_name = model_basename # or f"smooth-"
+ if use_safetensors:
+ model_save_name = model_base_name + ".safetensors"
+ state_dict = self.model.state_dict()
+ state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
+ if safetensors_metadata is None:
+ safetensors_metadata = {}
+ elif not isinstance(safetensors_metadata, dict):
+ raise TypeError("safetensors_metadata must be a dictionary.")
+ else:
+ print(f"Received safetensors_metadata: {safetensors_metadata}")
+ new_safetensors_metadata = {}
+ converted_keys = False
+ for key, value in safetensors_metadata.items():
+ if not isinstance(key, str) or not isinstance(value, str):
+ converted_keys = True
+ try:
+ new_key = str(key)
+ new_value = str(value)
+ except Exception as e:
+                        raise TypeError(
+                            f"safetensors_metadata: both keys and values must be strings and an error occurred when trying to convert them: {e}"
+                        )
+ if new_key in new_safetensors_metadata:
+ print(
+ f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
+ )
+ new_safetensors_metadata[new_key] = new_value
+ safetensors_metadata = new_safetensors_metadata
+ if converted_keys:
+ print(
+ f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
+ )
+
+ # Format is required to enable Accelerate to load the metadata
+ # otherwise it raises an OSError
+ safetensors_metadata["format"] = "pt"
+
+ safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+ else:
+ model_save_name = model_base_name + ".bin"
+ torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+ self.model.config.save_pretrained(save_dir)
+
+ # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
+ def save_pretrained(
+ self,
+ save_dir: str,
+ use_safetensors: bool = False,
+ safetensors_metadata: Optional[Dict[str, str]] = None,
+ **kwargs,
+ ):
+ """alias of save_quantized"""
+ warnings.warn("you are using save_pretrained, which will re-direct to save_quantized.")
+ self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+
+ # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ max_memory: Optional[dict] = None,
+ trust_remote_code: bool = False,
+ torch_dtype: torch.dtype = torch.float16,
+ **model_init_kwargs,
+ ):
+ if not torch.cuda.is_available():
+ raise EnvironmentError("Load pretrained model to do quantization requires CUDA available.")
+
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+
+ # Parameters related to loading from Hugging Face Hub
+ cache_dir = model_init_kwargs.pop("cache_dir", None)
+ force_download = model_init_kwargs.pop("force_download", False)
+ resume_download = model_init_kwargs.pop("resume_download", False)
+ proxies = model_init_kwargs.pop("proxies", None)
+ local_files_only = model_init_kwargs.pop("local_files_only", False)
+ use_auth_token = model_init_kwargs.pop("use_auth_token", None)
+ revision = model_init_kwargs.pop("revision", None)
+ subfolder = model_init_kwargs.pop("subfolder", "")
+ model_init_kwargs.pop("_commit_hash", None)
+
+ cached_file_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "resume_download": resume_download,
+ "local_files_only": local_files_only,
+ "use_auth_token": use_auth_token,
+ "revision": revision,
+ "subfolder": subfolder,
+ }
+
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, **cached_file_kwargs)
+ if config.model_type not in SUPPORTED_MODELS:
+ raise TypeError(f"{config.model_type} isn't supported yet.")
+
+        # enforce some values regardless of what the user specified
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ model_init_kwargs["trust_remote_code"] = trust_remote_code
+ if max_memory:
+ if "disk" in max_memory:
+ raise NotImplementedError("disk offload not support yet.")
+ with accelerate.init_empty_weights():
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+ model.tie_weights()
+
+ max_memory = accelerate.utils.get_balanced_memory(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=[cls.layer_type],
+ dtype=model_init_kwargs["torch_dtype"],
+ low_zero=False,
+ )
+ model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
+ model,
+ max_memory=max_memory,
+ no_split_module_classes=[cls.layer_type],
+ dtype=model_init_kwargs["torch_dtype"],
+ )
+ model_init_kwargs["low_cpu_mem_usage"] = True
+
+ del model
+ else:
+ model_init_kwargs["device_map"] = None
+ model_init_kwargs["low_cpu_mem_usage"] = False
+
+ torch.cuda.empty_cache()
+
+ merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
+ model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **merged_kwargs)
+
+ model_config = model.config.to_dict()
+ seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
+ if any([k in model_config for k in seq_len_keys]):
+ for key in seq_len_keys:
+ if key in model_config:
+ model.seqlen = model_config[key]
+ break
+ else:
+ warnings.warn("can't get model's sequence length from model config, will set to 4096.")
+ model.seqlen = 4096
+ model.eval()
+
+ return cls(model, False)
+
+ # Adapted from AutoGPTQ: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/modeling/_base.py
+ @classmethod
+ def from_quantized(
+ cls,
+ model_name_or_path: Optional[str],
+ model_basename: Optional[str] = None,
+ device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
+ max_memory: Optional[dict] = None,
+ device: Optional[Union[str, int]] = None,
+ low_cpu_mem_usage: bool = False,
+ torch_dtype: Optional[torch.dtype] = None,
+ use_safetensors: bool = False,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ """load quantized model from local disk"""
+
+ # Parameters related to loading from Hugging Face Hub
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ use_auth_token = kwargs.pop("use_auth_token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", "")
+ commit_hash = kwargs.pop("_commit_hash", None)
+
+ cached_file_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "resume_download": resume_download,
+ "local_files_only": local_files_only,
+ "use_auth_token": use_auth_token,
+ "revision": revision,
+ "subfolder": subfolder,
+ "_raise_exceptions_for_missing_entries": False,
+ "_commit_hash": commit_hash,
+ }
+
+ # == step1: prepare configs and file names == #
+ config = AutoConfig.from_pretrained(
+ model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs
+ )
+
+ if config.model_type not in SUPPORTED_MODELS:
+ raise TypeError(f"{config.model_type} isn't supported yet.")
+
+ extensions = []
+ if use_safetensors:
+ extensions.append(".safetensors")
+ else:
+ extensions += [".bin", ".pt"]
+
+ model_name_or_path = str(model_name_or_path)
+ is_local = isdir(model_name_or_path)
+
+ resolved_archive_file = None
+ if is_local:
+ model_save_name = join(model_name_or_path, model_basename)
+ for ext in extensions:
+ if isfile(model_save_name + ext):
+ resolved_archive_file = model_save_name + ext
+ break
+ else: # remote
+ for ext in extensions:
+ resolved_archive_file = cached_file(model_name_or_path, model_basename + ext, **cached_file_kwargs)
+ if resolved_archive_file is not None:
+ break
+
+ if resolved_archive_file is None: # Could not find a model file to use
+ raise FileNotFoundError(f"Could not find model in {model_name_or_path}")
+
+ model_save_name = resolved_archive_file
+
+ # == step2: convert model to quantized-model (replace Linear) == #
+ def skip(*args, **kwargs):
+ pass
+
+ torch.nn.init.kaiming_uniform_ = skip
+ torch.nn.init.uniform_ = skip
+ torch.nn.init.normal_ = skip
+
+ transformers.modeling_utils._init_weights = False
+
+ init_contexts = [no_init_weights()]
+ if low_cpu_mem_usage:
+ init_contexts.append(accelerate.init_empty_weights(include_buffers=True))
+
+ with ContextManagers(init_contexts):
+ model = AutoModelForCausalLM.from_config(
+ config, trust_remote_code=trust_remote_code, torch_dtype=torch_dtype
+ )
+ cls.create_quantized_model(model)
+ model.tie_weights()
+
+ # == step3: load checkpoint to quantized-model == #
+ accelerate.utils.modeling.load_checkpoint_in_model(
+ model, checkpoint=model_save_name, offload_state_dict=True, offload_buffers=True
+ )
+
+ # == step4: set seqlen == #
+ model_config = model.config.to_dict()
+ seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
+ if any([k in model_config for k in seq_len_keys]):
+ for key in seq_len_keys:
+ if key in model_config:
+ model.seqlen = model_config[key]
+ break
+ else:
+ warnings.warn("can't get model's sequence length from model config, will set to 4096.")
+ model.seqlen = 4096
+
+ return cls(
+ model,
+ True,
+ )
+
+    def __getattr__(self, item):
+        try:
+            return super().__getattr__(item)
+        except AttributeError:
+            return getattr(self.model, item)
+
+
+__all__ = ["BaseSmoothForCausalLM"]
diff --git a/colossalai/legacy/inference/quant/smoothquant/models/linear.py b/colossalai/legacy/inference/quant/smoothquant/models/linear.py
new file mode 100644
index 000000000000..969c390a0849
--- /dev/null
+++ b/colossalai/legacy/inference/quant/smoothquant/models/linear.py
@@ -0,0 +1,179 @@
+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
+
+import torch
+from torch_int._CUDA import linear_a8_w8_b8_o8, linear_a8_w8_bfp32_ofp32
+from torch_int.functional.quantization import quantize_per_tensor_absmax
+
+try:
+ from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
+
+ smoothquant_cuda = SmoothquantBuilder().load()
+ HAS_SMOOTHQUANT_CUDA = True
+except ImportError:
+ HAS_SMOOTHQUANT_CUDA = False
+ raise ImportError("CUDA smoothquant linear is not installed")
+
+
+class W8A8BFP32O32LinearSiLU(torch.nn.Module):
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+
+ self.register_buffer(
+ "weight",
+ torch.randint(
+ -127,
+ 127,
+ (self.out_features, self.in_features),
+ dtype=torch.int8,
+ requires_grad=False,
+ ),
+ )
+ self.register_buffer(
+ "bias",
+ torch.zeros((1, self.out_features), dtype=torch.float, requires_grad=False),
+ )
+ self.register_buffer("a", torch.tensor(alpha))
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ self.weight = self.weight.to(*args, **kwargs)
+ self.bias = self.bias.to(*args, **kwargs)
+ return self
+
+ @torch.no_grad()
+ def forward(self, x):
+ x_shape = x.shape
+ x = x.view(-1, x_shape[-1])
+ y = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1.0)
+ y = y.view(*x_shape[:-1], -1)
+ return y
+
+ @staticmethod
+ def from_float(module: torch.nn.Linear, input_scale):
+ int8_module = W8A8BFP32O32LinearSiLU(module.in_features, module.out_features)
+ int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+ alpha = input_scale * weight_scale
+ int8_module.weight = int8_weight
+ if module.bias is not None:
+ int8_module.bias.data.copy_(module.bias.to(torch.float))
+ int8_module.a = alpha
+ return int8_module
+
+
+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
+class W8A8B8O8Linear(torch.nn.Module):
+ # For qkv_proj
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+
+ self.register_buffer(
+ "weight",
+ torch.randint(
+ -127,
+ 127,
+ (self.out_features, self.in_features),
+ dtype=torch.int8,
+ requires_grad=False,
+ ),
+ )
+ self.register_buffer(
+ "bias",
+ torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False),
+ )
+ self.register_buffer("a", torch.tensor(alpha))
+ self.register_buffer("b", torch.tensor(beta))
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ self.weight = self.weight.to(*args, **kwargs)
+ self.bias = self.bias.to(*args, **kwargs)
+ return self
+
+ @torch.no_grad()
+ def forward(self, x):
+ x_shape = x.shape
+ x = x.view(-1, x_shape[-1])
+ y = linear_a8_w8_b8_o8(x, self.weight, self.bias, self.a.item(), self.b.item())
+ y = y.view(*x_shape[:-1], -1)
+ return y
+
+ @staticmethod
+ def from_float(module: torch.nn.Linear, input_scale, output_scale):
+ int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
+ int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+ alpha = input_scale * weight_scale / output_scale
+ int8_module.weight = int8_weight
+ int8_module.a = alpha
+
+ if module.bias is not None:
+ int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
+ int8_module.bias = int8_bias
+ beta = bias_scale / output_scale
+ int8_module.b = beta
+
+ return int8_module
+
+
+# modified from torch-int: https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py
+class W8A8BFP32OFP32Linear(torch.nn.Module):
+ # For fc2 and out_proj
+ def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+
+ self.register_buffer(
+ "weight",
+ torch.randint(
+ -127,
+ 127,
+ (self.out_features, self.in_features),
+ dtype=torch.int8,
+ requires_grad=False,
+ ),
+ )
+ self.register_buffer(
+ "bias",
+ torch.zeros(self.out_features, dtype=torch.float32, requires_grad=False),
+ )
+ self.register_buffer("a", torch.tensor(alpha))
+
+ def _apply(self, fn):
+ # prevent the bias from being converted to half
+ super()._apply(fn)
+ self.bias = self.bias.to(torch.float32)
+ return self
+
+ def to(self, *args, **kwargs):
+ super().to(*args, **kwargs)
+ self.weight = self.weight.to(*args, **kwargs)
+ self.bias = self.bias.to(*args, **kwargs)
+ self.bias = self.bias.to(torch.float32)
+ return self
+
+ @torch.no_grad()
+ def forward(self, x):
+ x_shape = x.shape
+ x = x.view(-1, x_shape[-1])
+ y = linear_a8_w8_bfp32_ofp32(x, self.weight, self.bias, self.a.item(), 1)
+ y = y.view(*x_shape[:-1], -1)
+ return y
+
+ @staticmethod
+ def from_float(module: torch.nn.Linear, input_scale):
+ int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
+ int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
+ alpha = input_scale * weight_scale
+ int8_module.weight = int8_weight
+ int8_module.a = alpha
+ int8_module.input_scale = input_scale
+ int8_module.weight_scale = weight_scale
+
+ if module.bias is not None:
+ int8_module.bias = module.bias.to(torch.float32)
+
+ return int8_module
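+
+
+# A short note on the scale algebra above (illustrative, not part of the torch-int API). With
+# per-tensor absmax quantization, x_int8 ~= x / input_scale and w_int8 ~= w / weight_scale, so
+#
+#     y = x @ w.T ~= (x_int8 @ w_int8.T) * input_scale * weight_scale
+#
+# which is why `from_float` folds both scales into the single multiplier `a` (alpha). For
+# W8A8B8O8Linear the product is additionally divided by `output_scale` so that the int8 output
+# stays within range.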
diff --git a/colossalai/legacy/inference/quant/smoothquant/models/llama.py b/colossalai/legacy/inference/quant/smoothquant/models/llama.py
new file mode 100644
index 000000000000..30063857ac30
--- /dev/null
+++ b/colossalai/legacy/inference/quant/smoothquant/models/llama.py
@@ -0,0 +1,838 @@
+import math
+import os
+import types
+from collections import defaultdict
+from functools import partial
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_int.nn.bmm import BMM_S8T_S8N_F32T, BMM_S8T_S8N_S8T
+from transformers import PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LLAMA_INPUTS_DOCSTRING,
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaMLP,
+ LlamaRotaryEmbedding,
+ repeat_kv,
+ rotate_half,
+)
+from transformers.utils import add_start_docstrings_to_model_forward
+
+from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
+from colossalai.kernel.triton import (
+ copy_kv_cache_to_dest,
+ int8_rotary_embedding_fwd,
+ smooth_llama_context_attn_fwd,
+ smooth_token_attention_fwd,
+)
+
+from .base_model import BaseSmoothForCausalLM
+from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear
+
+
+class LLamaSmoothquantAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+
+ if (self.head_dim * num_heads) != self.hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {num_heads})."
+ )
+
+ self.qk_bmm = BMM_S8T_S8N_F32T(1.0)
+ self.pv_bmm = BMM_S8T_S8N_S8T(1.0)
+
+ self.k_proj = W8A8B8O8Linear(hidden_size, hidden_size)
+ self.v_proj = W8A8B8O8Linear(hidden_size, hidden_size)
+ self.q_proj = W8A8B8O8Linear(hidden_size, hidden_size)
+ self.o_proj = W8A8BFP32OFP32Linear(hidden_size, hidden_size)
+
+ self.register_buffer("q_output_scale", torch.tensor([1.0]))
+ self.register_buffer("k_output_scale", torch.tensor([1.0]))
+ self.register_buffer("v_output_scale", torch.tensor([1.0]))
+ self.register_buffer("q_rotary_output_scale", torch.tensor([1.0]))
+ self.register_buffer("k_rotary_output_scale", torch.tensor([1.0]))
+ self.register_buffer("out_input_scale", torch.tensor([1.0]))
+ self.register_buffer("attn_input_scale", torch.tensor([1.0]))
+
+ self._init_rope()
+ self.num_key_value_heads = num_heads
+
+ def _init_rope(self):
+ self.rotary_emb = LlamaRotaryEmbedding(
+ self.head_dim,
+ max_position_embeddings=2048,
+ base=10000.0,
+ )
+
+ @staticmethod
+ def pack(
+ module: LlamaAttention,
+ attn_input_scale: float,
+ q_output_scale: float,
+ k_output_scale: float,
+ v_output_scale: float,
+ q_rotary_output_scale: float,
+ k_rotary_output_scale: float,
+ out_input_scale: float,
+ ):
+ int8_module = LLamaSmoothquantAttention(module.hidden_size, module.num_heads)
+
+ int8_module.attn_input_scale = torch.tensor([attn_input_scale])
+
+ int8_module.q_output_scale = torch.tensor([q_output_scale])
+ int8_module.k_output_scale = torch.tensor([k_output_scale])
+ int8_module.v_output_scale = torch.tensor([v_output_scale])
+
+ int8_module.q_rotary_output_scale = torch.tensor([q_rotary_output_scale])
+ int8_module.k_rotary_output_scale = torch.tensor([k_rotary_output_scale])
+
+ int8_module.q_proj = W8A8B8O8Linear.from_float(module.q_proj, attn_input_scale, q_output_scale)
+ int8_module.k_proj = W8A8B8O8Linear.from_float(module.k_proj, attn_input_scale, k_output_scale)
+ int8_module.v_proj = W8A8B8O8Linear.from_float(module.v_proj, attn_input_scale, v_output_scale)
+ int8_module.o_proj = W8A8BFP32OFP32Linear.from_float(module.o_proj, out_input_scale)
+
+ int8_module.out_input_scale = torch.tensor([out_input_scale])
+
+ return int8_module
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ @torch.no_grad()
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_emb: Tuple[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ cos = rotary_emb[0]
+ sin = rotary_emb[1]
+
+ int8_rotary_embedding_fwd(
+ query_states.view(-1, self.num_heads, self.head_dim),
+ cos,
+ sin,
+ self.q_output_scale.item(),
+ self.q_rotary_output_scale.item(),
+ )
+ int8_rotary_embedding_fwd(
+ key_states.view(-1, self.num_heads, self.head_dim),
+ cos,
+ sin,
+ self.k_output_scale.item(),
+ self.k_rotary_output_scale.item(),
+ )
+
+ def _copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+ return
+
+ query_states = query_states.view(-1, self.num_heads, self.head_dim)
+ key_states = key_states.view(-1, self.num_heads, self.head_dim)
+ value_states = value_states.view(-1, self.num_heads, self.head_dim)
+
+ if infer_state.is_context_stage:
+ # first token generation
+
+ # copy key and value calculated in current step to memory manager
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.context_mem_index,
+ infer_state.cache_manager,
+ )
+
+ attn_output = torch.empty_like(query_states)
+
+ smooth_llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ self.q_rotary_output_scale.item(),
+ self.k_rotary_output_scale.item(),
+ self.v_output_scale.item(),
+ self.out_input_scale.item(),
+ infer_state.start_loc,
+ infer_state.seq_len,
+ q_len,
+ )
+
+ else:
+ if infer_state.decode_is_contiguous:
+ # if decode is contiguous, then we copy to key cache and value cache in cache manager directly
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id][
+ infer_state.decode_mem_start : infer_state.decode_mem_end, :, :
+ ]
+ cache_k.copy_(key_states)
+ cache_v.copy_(value_states)
+ else:
+                # if decode is not contiguous, use the triton kernel to copy the key and value cache
+                # k, v shape: [batch_size, num_heads, head_dim] (head_dim == embed_size_per_head)
+ _copy_kv_to_mem_cache(
+ infer_state.decode_layer_id,
+ key_states,
+ value_states,
+ infer_state.decode_mem_index,
+ infer_state.cache_manager,
+ )
+
+ # (batch_size, seqlen, nheads, headdim)
+ attn_output = torch.empty_like(query_states)
+
+ smooth_token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ self.q_rotary_output_scale.item(),
+ self.k_rotary_output_scale.item(),
+ self.v_output_scale.item(),
+ self.out_input_scale.item(),
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+ attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, None
+
+
+class LlamaLayerNormQ(torch.nn.Module):
+ def __init__(self, dim, eps=1e-5):
+ super().__init__()
+ self.input_scale = 1.0
+ self.variance_epsilon = eps
+ self.register_buffer("weight", torch.ones(dim, dtype=torch.float32))
+
+ def forward(self, x):
+ ln_output_fp = torch.nn.functional.layer_norm(x, x.shape[-1:], self.weight, None, self.variance_epsilon)
+ ln_output_int8 = ln_output_fp.round().clamp(-128, 127).to(torch.int8)
+ return ln_output_int8
+
+ @staticmethod
+ def from_float(module: torch.nn.LayerNorm, output_scale: float):
+ assert module.weight.shape[0] == module.weight.numel()
+ q_module = LlamaLayerNormQ(module.weight.shape[0], module.variance_epsilon)
+ q_module.weight = module.weight / output_scale
+ return q_module
+
+
+class LlamaSmoothquantMLP(nn.Module):
+ def __init__(self, intermediate_size, hidden_size):
+ super().__init__()
+ self.gate_proj = W8A8BFP32O32LinearSiLU(hidden_size, intermediate_size)
+ self.up_proj = W8A8BFP32OFP32Linear(hidden_size, intermediate_size)
+ self.down_proj = W8A8BFP32OFP32Linear(intermediate_size, hidden_size)
+ self.register_buffer("down_proj_input_scale", torch.tensor([1.0]))
+
+ @staticmethod
+ def pack(
+ mlp_module: LlamaMLP,
+ gate_proj_input_scale: float,
+ up_proj_input_scale: float,
+ down_proj_input_scale: float,
+ ):
+ int8_module = LlamaSmoothquantMLP(
+ mlp_module.intermediate_size,
+ mlp_module.hidden_size,
+ )
+
+ int8_module.gate_proj = W8A8BFP32O32LinearSiLU.from_float(mlp_module.gate_proj, gate_proj_input_scale)
+ int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.up_proj, up_proj_input_scale)
+ int8_module.down_proj = W8A8BFP32OFP32Linear.from_float(mlp_module.down_proj, down_proj_input_scale)
+ int8_module.down_proj_input_scale = torch.tensor([down_proj_input_scale])
+ return int8_module
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ ):
+ x_shape = hidden_states.shape
+ gate_out = self.gate_proj(hidden_states)
+ up_out = self.up_proj(hidden_states)
+ inter_out = gate_out * up_out
+ inter_out = inter_out.div_(self.down_proj_input_scale.item()).round().clamp(-128, 127).to(torch.int8)
+ down_out = self.down_proj(inter_out)
+ down_out = down_out.view(*x_shape[:-1], -1)
+ return down_out
+
+
+class LlamaSmoothquantDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.self_attn = LLamaSmoothquantAttention(config.hidden_size, config.num_attention_heads)
+
+ self.mlp = LlamaSmoothquantMLP(config.intermediate_size, config.hidden_size)
+ self.input_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.post_attention_layernorm = LlamaLayerNormQ(config.hidden_size, eps=config.rms_norm_eps)
+
+ @staticmethod
+ def pack(
+ module: LlamaDecoderLayer,
+ attn_input_scale: float,
+ q_output_scale: float,
+ k_output_scale: float,
+ v_output_scale: float,
+ q_rotary_output_scale: float,
+ k_rotary_output_scale: float,
+ out_input_scale: float,
+ gate_input_scale: float,
+ up_input_scale: float,
+ down_input_scale: float,
+ ):
+ config = module.self_attn.config
+ int8_decoder_layer = LlamaSmoothquantDecoderLayer(config)
+
+ int8_decoder_layer.input_layernorm = LlamaLayerNormQ.from_float(module.input_layernorm, attn_input_scale)
+ int8_decoder_layer.self_attn = LLamaSmoothquantAttention.pack(
+ module.self_attn,
+ attn_input_scale,
+ q_output_scale,
+ k_output_scale,
+ v_output_scale,
+ q_rotary_output_scale,
+ k_rotary_output_scale,
+ out_input_scale,
+ )
+
+ int8_decoder_layer.post_attention_layernorm = LlamaLayerNormQ.from_float(
+ module.post_attention_layernorm, gate_input_scale
+ )
+
+ int8_decoder_layer.mlp = LlamaSmoothquantMLP.pack(
+ module.mlp,
+ gate_input_scale,
+ up_input_scale,
+ down_input_scale,
+ )
+
+ return int8_decoder_layer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ rotary_emb: Tuple[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+ infer_state: Optional[BatchInferState] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ rotary_emb=rotary_emb,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
+ infer_state=infer_state,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ return hidden_states, None, None
+
+
+class LlamaApplyRotary(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, cos, sin, position_ids):
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ x_embed = (x * cos) + (rotate_half(x) * sin)
+
+ return x_embed
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
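+# NOTE: despite its name, this function implements LlamaAttention.forward (with hookable rotary ops);
+# get_act_dict below binds it to attention modules during calibration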
+def llama_decoder_layer_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ padding_mask: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.config.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states = self.q_apply_rotary(query_states, cos, sin, position_ids)
+ key_states = self.k_apply_rotary(key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if self.config.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def init_to_get_rotary(config, base=10000, use_elem=False):
+ """
+ This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer.
+ Args:
+ base : exponential base used to compute the rotary frequencies
+ use_elem : set for chatglm-based models; halves the number of rotary elements
+ """
+ config.head_dim_ = config.hidden_size // config.num_attention_heads
+ if not hasattr(config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = config.rope_scaling.factor if config.rope_scaling is not None else 1.0
+
+ if hasattr(config, "max_sequence_length"):
+ max_seq_len = config.max_sequence_length
+ elif hasattr(config, "max_position_embeddings"):
+ max_seq_len = config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ try:
+ ntk_alpha = float(os.environ.get("INFER_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
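+ # enlarging the base stretches all rotary wavelengths, extending the usable context window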
+ base = base * (ntk_alpha ** (config.head_dim_ / (config.head_dim_ - 2))) # Base change formula
+ except (ValueError, AssertionError):
+ # ignore a malformed INFER_NTK_ALPHA and fall back to the default
+ pass
+
+ n_elem = config.head_dim_
+ if use_elem:
+ n_elem //= 2
+
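+ # standard RoPE cache: inv_freq[i] = base^(-2i / n_elem); precompute cos/sin with extra slack past max_seq_len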
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ _cos_cached = torch.cos(freqs).to(torch.float)
+ _sin_cached = torch.sin(freqs).to(torch.float)
+ return _cos_cached, _sin_cached
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+def llama_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ infer_state = self.infer_state
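+ # prefill (context stage) starts with an empty cache; decode assumes max_len_in_batch - 1 tokens are already cached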
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
+
+ seq_length_with_past = seq_length + past_key_values_length
+
+ # NOTE: differentiate from the prefill stage
+ # block_loc requires a different value-assigning method for the two stages
+ if infer_state.is_context_stage:
+ infer_state.context_mem_index = infer_state.cache_manager.alloc(infer_state.total_token_num)
+ infer_state.init_block_loc(
+ infer_state.block_loc, infer_state.seq_len, seq_length, infer_state.context_mem_index
+ )
+ else:
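+ # decode stage: try a contiguous kv-cache allocation first (cheaper copy path); fall back to scattered slots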
+ alloc_mem = infer_state.cache_manager.alloc_contiguous(batch_size)
+ if alloc_mem is not None:
+ infer_state.decode_is_contiguous = True
+ infer_state.decode_mem_index = alloc_mem[0]
+ infer_state.decode_mem_start = alloc_mem[1]
+ infer_state.decode_mem_end = alloc_mem[2]
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ else:
+ print(f" *** Encountered allocation non-contiguous")
+ print(f" infer_state.cache_manager.max_len_in_batch: {infer_state.max_len_in_batch}")
+ infer_state.decode_is_contiguous = False
+ alloc_mem = infer_state.cache_manager.alloc(batch_size)
+ infer_state.decode_mem_index = alloc_mem
+ infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
+ padding_mask = None
+ else:
+ if 0 in attention_mask:
+ padding_mask = attention_mask
+ else:
+ padding_mask = None
+
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ raise NotImplementedError("gradient_checkpointing and training are not supported in this inference forward")
+
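+ # prefill gathers cos/sin for every token position; decode gathers one (the latest) position per sequence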
+ if past_key_values_length == 0:
+ position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
+ position_ids.view(-1).shape[0], -1
+ )
+ else:
+ position_cos = torch.index_select(self._cos_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+ position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(batch_size, -1)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+ infer_state.decode_layer_id = 0
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ rotary_emb=(position_cos, position_sin),
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ padding_mask=padding_mask,
+ infer_state=infer_state,
+ )
+
+ hidden_states = layer_outputs[0]
+ infer_state.decode_layer_id += 1
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
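+ # this forward is done: flip to the decode stage and advance per-sequence offsets and lengths by one token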
+ infer_state.is_context_stage = False
+ infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class SmoothLlamaForCausalLM(BaseSmoothForCausalLM):
+ layer_type = "LlamaDecoderLayer"
+
+ def __init__(self, model: PreTrainedModel, quantized: bool = False):
+ super().__init__(model, quantized)
+
+ # Adapted from https://github.com/mit-han-lab/smoothquant/blob/main/smoothquant/calibration.py
+ def get_act_dict(
+ self,
+ tokenizer,
+ dataset,
+ num_samples=512,
+ seq_len=512,
+ ):
+ llama_model = self.model
+
+ llama_model.eval()
+ device = next(llama_model.parameters()).device
+ # print("model:", llama_model)
+ act_dict = defaultdict(dict)
+
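+ # record the running max |activation| seen at each hooked module, for both inputs and outputs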
+ def stat_io_hook(m, x, y, name):
+ if isinstance(x, tuple):
+ x = x[0]
+ if name not in act_dict or "input" not in act_dict[name]:
+ act_dict[name]["input"] = x.detach().abs().max().item()
+ else:
+ act_dict[name]["input"] = max(act_dict[name]["input"], x.detach().abs().max().item())
+ if isinstance(y, tuple):
+ y = y[0]
+ if name not in act_dict or "output" not in act_dict[name]:
+ act_dict[name]["output"] = y.detach().abs().max().item()
+ else:
+ act_dict[name]["output"] = max(act_dict[name]["output"], y.detach().abs().max().item())
+
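+ # swap in hookable rotary modules and the instrumented attention forward for calibration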
+ for name, m in llama_model.named_modules():
+ if isinstance(m, LlamaAttention):
+ setattr(m, "q_apply_rotary", LlamaApplyRotary())
+ setattr(m, "k_apply_rotary", LlamaApplyRotary())
+ m.forward = types.MethodType(llama_decoder_layer_forward, m)
+
+ hooks = []
+ for name, m in llama_model.named_modules():
+ if isinstance(m, LlamaApplyRotary):
+ hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
+ if isinstance(m, torch.nn.Linear):
+ hooks.append(m.register_forward_hook(partial(stat_io_hook, name=name)))
+
+ self.collect_act_dict(llama_model, tokenizer, dataset, act_dict, device, num_samples, seq_len)
+
+ for hook in hooks:
+ hook.remove()
+ return act_dict
+
+ def smooth_fn(self, scales, alpha=0.5):
+ model = self.model
+ for name, module in model.named_modules():
+ if isinstance(module, LlamaDecoderLayer):
+ attn_ln = module.input_layernorm
+ qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
+ qkv_input_scales = scales[name + ".self_attn.q_proj"]
+ self.smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)
+
+ def create_quantized_model(model):
+ llama_config = model.config
+ for i, layer in enumerate(model.model.layers):
+ model.model.layers[i] = LlamaSmoothquantDecoderLayer(llama_config)
+
+ model.model.forward = types.MethodType(llama_model_forward, model.model)
+ cos, sin = init_to_get_rotary(llama_config)
+ model.model.register_buffer("_cos_cached", cos)
+ model.model.register_buffer("_sin_cached", sin)
+
+ def quantized(
+ self,
+ tokenizer,
+ dataset,
+ num_samples=512,
+ seq_len=512,
+ alpha=0.5,
+ ):
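+ # calibration pipeline: collect activation scales, smooth weights against
+ # activations, re-collect per-module activation maxima, derive per-layer
+ # int8 scales, then pack each decoder layer into its smoothquant counterpart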
+ llama_model = self.model
+ llama_config = llama_model.config
+
+ act_scales = self.get_act_scales(llama_model, tokenizer, dataset, num_samples, seq_len)
+
+ self.smooth_fn(act_scales, alpha)
+
+ act_dict = self.get_act_dict(tokenizer, dataset, num_samples, seq_len)
+ decoder_layer_scales = []
+
+ for idx in range(llama_config.num_hidden_layers):
+ scale_dict = {}
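+ # symmetric int8: divide the calibrated max |activation| by 127 to get each scale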
+ scale_dict["attn_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["input"] / 127
+ scale_dict["q_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.q_proj"]["output"] / 127
+ scale_dict["k_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.k_proj"]["output"] / 127
+ scale_dict["v_output_scale"] = act_dict[f"model.layers.{idx}.self_attn.v_proj"]["output"] / 127
+
+ scale_dict["q_rotary_output_scale"] = (
+ act_dict[f"model.layers.{idx}.self_attn.q_apply_rotary"]["output"] / 127
+ )
+ scale_dict["k_rotary_output_scale"] = (
+ act_dict[f"model.layers.{idx}.self_attn.k_apply_rotary"]["output"] / 127
+ )
+
+ scale_dict["out_input_scale"] = act_dict[f"model.layers.{idx}.self_attn.o_proj"]["input"] / 127
+
+ scale_dict["gate_input_scale"] = act_dict[f"model.layers.{idx}.mlp.gate_proj"]["input"] / 127
+ scale_dict["up_input_scale"] = act_dict[f"model.layers.{idx}.mlp.up_proj"]["input"] / 127
+ scale_dict["down_input_scale"] = act_dict[f"model.layers.{idx}.mlp.down_proj"]["input"] / 127
+
+ decoder_layer_scales.append(scale_dict)
+
+ for i, layer in enumerate(llama_model.model.layers):
+ orig_layer = layer
+ llama_model.model.layers[i] = LlamaSmoothquantDecoderLayer.pack(orig_layer, **decoder_layer_scales[i])
+
+ llama_model.model.forward = types.MethodType(llama_model_forward, llama_model.model)
+
+ cos, sin = init_to_get_rotary(llama_config)
+ llama_model.model.register_buffer("_cos_cached", cos.to(self.model.device))
+ llama_model.model.register_buffer("_sin_cached", sin.to(self.model.device))
diff --git a/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py b/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py
similarity index 97%
rename from examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py
rename to colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py
index 51d520ebbcf6..d758b467c730 100644
--- a/examples/inference/serving/ray_serve/Colossal_Inference_rayserve.py
+++ b/colossalai/legacy/inference/serving/ray_serve/Colossal_Inference_rayserve.py
@@ -68,7 +68,9 @@ def setup(self, world_size, rank, port):
self.model_path, pad_token_id=self.tokenizer.pad_token_id, torch_dtype=torch.float16
)
- shard_config = ShardConfig(enable_tensor_parallelism=True if world_size > 1 else False, inference_only=True)
+ shard_config = ShardConfig(
+ enable_tensor_parallelism=world_size > 1, extra_kwargs={"inference_only": True}
+ )
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
diff --git a/examples/inference/serving/ray_serve/README.md b/colossalai/legacy/inference/serving/ray_serve/README.md
similarity index 100%
rename from examples/inference/serving/ray_serve/README.md
rename to colossalai/legacy/inference/serving/ray_serve/README.md
diff --git a/examples/inference/serving/ray_serve/send_request.py b/colossalai/legacy/inference/serving/ray_serve/send_request.py
similarity index 100%
rename from examples/inference/serving/ray_serve/send_request.py
rename to colossalai/legacy/inference/serving/ray_serve/send_request.py
diff --git a/examples/inference/serving/ray_serve/send_requests.py b/colossalai/legacy/inference/serving/ray_serve/send_requests.py
similarity index 100%
rename from examples/inference/serving/ray_serve/send_requests.py
rename to colossalai/legacy/inference/serving/ray_serve/send_requests.py
diff --git a/colossalai/legacy/inference/serving/test_ci.sh b/colossalai/legacy/inference/serving/test_ci.sh
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py b/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py
similarity index 98%
rename from examples/inference/serving/torch_serve/Colossal_Inference_Handler.py
rename to colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py
index c0d30501efea..e07494b8a1a9 100644
--- a/examples/inference/serving/torch_serve/Colossal_Inference_Handler.py
+++ b/colossalai/legacy/inference/serving/torch_serve/Colossal_Inference_Handler.py
@@ -100,7 +100,9 @@ def initialize(self, ctx):
colossalai.launch(config={}, rank=rank, world_size=world_size, host=host, port=port, backend="nccl")
logger.info("Initializing TPInferEngine ...")
- shard_config = ShardConfig(enable_tensor_parallelism=True if self.tp_size > 1 else False, inference_only=True)
+ shard_config = ShardConfig(
+ enable_tensor_parallelism=self.tp_size > 1, extra_kwargs={"inference_only": True}
+ )
self.infer_engine = TPInferEngine(
self.model, shard_config, self.max_batch_size, self.max_input_len, self.max_output_len
)
diff --git a/examples/inference/serving/torch_serve/README.md b/colossalai/legacy/inference/serving/torch_serve/README.md
similarity index 100%
rename from examples/inference/serving/torch_serve/README.md
rename to colossalai/legacy/inference/serving/torch_serve/README.md
diff --git a/examples/inference/serving/torch_serve/config.properties b/colossalai/legacy/inference/serving/torch_serve/config.properties
similarity index 100%
rename from examples/inference/serving/torch_serve/config.properties
rename to colossalai/legacy/inference/serving/torch_serve/config.properties
diff --git a/examples/inference/serving/torch_serve/docker/Dockerfile b/colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile
similarity index 100%
rename from examples/inference/serving/torch_serve/docker/Dockerfile
rename to colossalai/legacy/inference/serving/torch_serve/docker/Dockerfile
diff --git a/examples/inference/serving/torch_serve/model-config.yaml b/colossalai/legacy/inference/serving/torch_serve/model-config.yaml
similarity index 100%
rename from examples/inference/serving/torch_serve/model-config.yaml
rename to colossalai/legacy/inference/serving/torch_serve/model-config.yaml
diff --git a/examples/inference/serving/torch_serve/sample_text.txt b/colossalai/legacy/inference/serving/torch_serve/sample_text.txt
similarity index 100%
rename from examples/inference/serving/torch_serve/sample_text.txt
rename to colossalai/legacy/inference/serving/torch_serve/sample_text.txt
diff --git a/colossalai/inference/tensor_parallel/__init__.py b/colossalai/legacy/inference/tensor_parallel/__init__.py
similarity index 100%
rename from colossalai/inference/tensor_parallel/__init__.py
rename to colossalai/legacy/inference/tensor_parallel/__init__.py
diff --git a/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
new file mode 100644
index 000000000000..f707a86df37e
--- /dev/null
+++ b/colossalai/legacy/inference/tensor_parallel/batch_infer_state.py
@@ -0,0 +1,118 @@
+# might want to consider combining with InferenceConfig in colossalai/ppinference/inference_config.py later
+from dataclasses import dataclass
+
+import torch
+from transformers.tokenization_utils_base import BatchEncoding
+
+from .kvcache_manager import MemoryManager
+
+
+# adapted from: lightllm/server/router/model_infer/infer_batch.py
+@dataclass
+class BatchInferState:
+ r"""
+ Information to be passed and used for a batch of inputs during
+ a single model forward
+ """
+ batch_size: int
+ max_len_in_batch: int
+
+ cache_manager: MemoryManager = None
+
+ block_loc: torch.Tensor = None
+ start_loc: torch.Tensor = None
+ seq_len: torch.Tensor = None
+ past_key_values_len: int = None
+
+ is_context_stage: bool = False
+ context_mem_index: torch.Tensor = None
+ decode_is_contiguous: bool = None
+ decode_mem_start: int = None
+ decode_mem_end: int = None
+ decode_mem_index: torch.Tensor = None
+ decode_layer_id: int = None
+
+ device: torch.device = torch.device("cuda")
+
+ @property
+ def total_token_num(self):
+ # return self.batch_size * self.max_len_in_batch
+ assert self.seq_len is not None and self.seq_len.size(0) > 0
+ return int(torch.sum(self.seq_len))
+
+ def set_cache_manager(self, manager: MemoryManager):
+ self.cache_manager = manager
+
+ # adapted from: https://github.com/ModelTC/lightllm/blob/28c1267cfca536b7b4f28e921e03de735b003039/lightllm/common/infer_utils.py#L1
+ @staticmethod
+ def init_block_loc(
+ b_loc: torch.Tensor, seq_len: torch.Tensor, max_len_in_batch: int, alloc_mem_index: torch.Tensor
+ ):
+ """in-place update block loc mapping based on the sequence length of the inputs in current bath"""
+ start_index = 0
+ seq_len_numpy = seq_len.cpu().numpy()
+ for i, cur_seq_len in enumerate(seq_len_numpy):
+ b_loc[i, max_len_in_batch - cur_seq_len : max_len_in_batch] = alloc_mem_index[
+ start_index : start_index + cur_seq_len
+ ]
+ start_index += cur_seq_len
+ return
+
+ @classmethod
+ def init_from_batch(
+ cls,
+ batch: torch.Tensor,
+ max_input_len: int,
+ max_output_len: int,
+ cache_manager: MemoryManager,
+ ):
+ if not isinstance(batch, (BatchEncoding, dict, list, torch.Tensor)):
+ raise TypeError(f"batch type {type(batch)} is not supported in prepare_batch_state")
+
+ input_ids_list = None
+ attention_mask = None
+
+ if isinstance(batch, (BatchEncoding, dict)):
+ input_ids_list = batch["input_ids"]
+ attention_mask = batch["attention_mask"]
+ else:
+ input_ids_list = batch
+ if isinstance(input_ids_list[0], int): # for a single input
+ input_ids_list = [input_ids_list]
+ attention_mask = [attention_mask] if attention_mask is not None else attention_mask
+
+ batch_size = len(input_ids_list)
+
+ seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+ start_index = 0
+
+ max_len_in_batch = -1
+ if isinstance(batch, (BatchEncoding, dict)):
+ for i, attn_mask in enumerate(attention_mask):
+ curr_seq_len = len(attn_mask)
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+ else:
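+ # raw id lists carry no attention mask, so every sequence is counted at the longest length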
+ length = max(len(input_id) for input_id in input_ids_list)
+ for i, input_ids in enumerate(input_ids_list):
+ curr_seq_len = length
+ seq_lengths[i] = curr_seq_len
+ seq_start_indexes[i] = start_index
+ start_index += curr_seq_len
+ max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
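+ # block_loc maps (sequence index, position) to a slot in the kv-cache buffers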
+ block_loc = torch.zeros((batch_size, max_input_len + max_output_len), dtype=torch.long, device="cuda")
+
+ return cls(
+ batch_size=batch_size,
+ max_len_in_batch=max_len_in_batch,
+ seq_len=seq_lengths.to("cuda"),
+ start_loc=seq_start_indexes.to("cuda"),
+ block_loc=block_loc,
+ decode_layer_id=0,
+ past_key_values_len=0,
+ is_context_stage=True,
+ cache_manager=cache_manager,
+ )
diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/legacy/inference/tensor_parallel/engine.py
similarity index 81%
rename from colossalai/inference/tensor_parallel/engine.py
rename to colossalai/legacy/inference/tensor_parallel/engine.py
index 216b134f5fab..a8fd3ca9e5ba 100644
--- a/colossalai/inference/tensor_parallel/engine.py
+++ b/colossalai/legacy/inference/tensor_parallel/engine.py
@@ -13,6 +13,8 @@
from .batch_infer_state import BatchInferState
from .kvcache_manager import MemoryManager
+# from dynamic_batching.infer_batch import InferBatch
+
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
_supported_models = [
@@ -42,7 +44,7 @@ class TPInferEngine:
>>> # define model and shard config for your inference
>>> model = ...
>>> generate_kwargs = ...
- >>> shard_config = ShardConfig(enable_tensor_parallelism=True, inference_only=True)
+ >>> shard_config = ShardConfig(enable_tensor_parallelism=True, extra_kwargs={"inference_only": True})
>>> infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
>>> outputs = infer_engine.generate(input_ids, **generate_kwargs)
"""
@@ -61,7 +63,6 @@ def __init__(
self.max_input_len = max_input_len
self.max_output_len = max_output_len
self.max_total_token_num = self.max_batch_size * (self.max_input_len + self.max_output_len)
-
# Constraints relatable with specs of devices and model
# This may change into an optional arg in the future
assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
@@ -76,13 +77,16 @@ def __init__(
)
self.layer_num = num_hidden_layers
- self.multi_query_group_num = 0
+ self.multi_query_group_num = model.config.num_attention_heads
+ # default to the number of attention heads
+ if hasattr(model.config, "multi_query_attention"):
+ self.multi_query_attention = getattr(model.config, "multi_query_attention")
if hasattr(model.config, "multi_query_group_num"):
- self.multi_query_group_num = model.config.multi_query_group_num
+ self.multi_query_group_num = getattr(model.config, "multi_query_group_num")
if hasattr(model.config, "num_key_value_heads"):
- self.multi_query_group_num = model.config.num_key_value_heads
+ self.multi_query_group_num = getattr(model.config, "num_key_value_heads")
self.tp_size = -1 # to be set with given shard config in self.prepare_shard_config
self.cache_manager = None
@@ -96,6 +100,8 @@ def __init__(
self.shard_config = shard_config
self.model = None
+ self.cache = {}
+
# optimize the original model by sharding with ShardFormer
self._optimize_model(model=model.to(device))
@@ -104,7 +110,7 @@ def _init_manager(self) -> None:
assert self.head_num % self.tp_size == 0, f"Cannot shard {self.head_num} heads with tp size {self.tp_size}"
self.head_num //= self.tp_size # update sharded number of heads
- if self.multi_query_group_num:
+ if hasattr(self, "multi_query_attention"):
# NOTE the logic of MQA tensor parallelism should be specified.
assert (
self.multi_query_group_num % self.tp_size == 0
@@ -175,7 +181,7 @@ def _optimize_model(self, model: nn.Module) -> None:
In further generation, use the sharded model instead of original model.
"""
# NOTE we will change to use an inference config later with additional attrs we want
- assert self.shard_config.inference_only is True
+ assert self.shard_config.extra_kwargs["inference_only"] is True
shardformer = ShardFormer(shard_config=self.shard_config)
self._prepare_with_shard_config(shard_config=self.shard_config)
self._shard_model_by(shardformer, model)
@@ -197,10 +203,10 @@ def _prepare_with_shard_config(self, shard_config: Optional[ShardConfig] = None)
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
- inference_only=True,
+ extra_kwargs={"inference_only": True},
)
else:
- shard_config.inference_only = True
+ shard_config.extra_kwargs = {"inference_only": True}
shard_config.pipeline_stage_manager = None
if shard_config.enable_tensor_parallelism:
self.tp_size = shard_config.tensor_parallel_size
@@ -215,13 +221,11 @@ def _shard_model_by(self, shardformer: ShardFormer, model: nn.Module) -> None:
), "Discrepancy between the tp size of TPInferEngine and the tp size of shard config"
model_name = model.__class__.__name__
assert model_name in self.supported_models, f"Unsupported model cls {model_name} for TP inference."
-
- model = model.model if self.shard_config.inference_gptq else model
-
- policy = get_autopolicy(model, inference_only=True)
+ if self.shard_config.extra_kwargs.get("inference_gptq", False):
+ model = model.model
+ policy = get_autopolicy(model, shard_config=self.shard_config)
self.model, _ = shardformer.optimize(model, policy)
-
- if self.shard_config.inference_gptq:
+ if self.shard_config.extra_kwargs.get("inference_gptq", False):
self._post_init_gptq_buffer(self.model)
self.model = self.model.cuda()
@@ -284,7 +288,6 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
attention_mask = [attention_mask] if attention_mask is not None else attention_mask
batch_size = len(input_ids_list)
-
seq_start_indexes = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
seq_lengths = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
start_index = 0
@@ -309,6 +312,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
seq_start_indexes[i] = start_index
start_index += curr_seq_len
max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch
+
block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda")
batch_infer_state = BatchInferState(batch_size, max_len_in_batch)
batch_infer_state.seq_len = seq_lengths.to("cuda")
@@ -318,6 +322,7 @@ def prepare_batch_state(self, inputs) -> BatchInferState:
batch_infer_state.past_key_values_len = 0
batch_infer_state.is_context_stage = True
batch_infer_state.set_cache_manager(self.cache_manager)
+
return batch_infer_state
@torch.no_grad()
@@ -381,6 +386,85 @@ def _update_batch_state(self, infer_state: Optional[BatchInferState]) -> None:
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device=device)
infer_state.seq_len += 1
+ @torch.no_grad()
+ def forward(self, batch_id, is_prefill):
+ """
+ Forward is used in Dynamic Batching Manager
+ """
+ batch = self.cache.pop(batch_id)
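+ # prefill consumes the full prompt ids; decode feeds only the newest token of each sequence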
+ if is_prefill:
+ input_ = torch.tensor(batch.all_input_ids).cuda()
+ else:
+ input_ = batch.input_ids.reshape(len(batch), 1)
+
+ batch_args = {
+ "batch_size": len(batch),
+ "max_len_in_batch": batch.nopad_max_len_in_batch,
+ "block_loc": batch.nopad_b_loc,
+ "start_loc": batch.nopad_b_start_loc,
+ "seq_len": batch.nopad_b_seq_len,
+ "cache_manager": batch.cache_manager,
+ "is_context_stage": is_prefill,
+ }
+
+ infer_state = BatchInferState(**batch_args)
+ model = self.model
+ if isinstance(model, LlamaForCausalLM):
+ model = self.model.model
+ elif isinstance(model, BloomForCausalLM):
+ model = self.model.transformer
+
+ setattr(model, "infer_state", infer_state)
+ output = self.model.forward(input_ids=input_)
+ logits = output.logits
+ # bsz, seq_len, vocab_size
+ prob_out = torch.softmax(logits[:, -1, :], dim=-1)
+ # prob_out: bsz, vocab_size
+ predict_ids = torch.argmax(prob_out, dim=-1, keepdim=True)
+ prob_out = torch.log(prob_out).detach().cpu().numpy()
+ predict_ids = predict_ids.detach().cpu().numpy()
+ # [ batch_size, 1 ]
+
+ output_dict = {}
+ new_input_ids = []
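+ # greedy decoding: append the argmax token to every request and record its logprob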
+ for i, (r, all_input_ids, next_token_id, next_token_logprob) in enumerate(
+ zip(batch.requests, batch.all_input_ids, predict_ids, prob_out)
+ ):
+ next_token_id = int(next_token_id)
+ next_token_logprob = next_token_logprob[next_token_id]
+ all_input_ids.append(next_token_id)
+ new_input_ids.append(next_token_id)
+ batch.all_input_ids[i] = all_input_ids
+ batch.input_lengths[i] += 1
+ batch.out_token_id_counts[i][next_token_id] += 1
+ metadata = {
+ "id": int(next_token_id),
+ "logprob": float(next_token_logprob),
+ }
+ output_dict[r["request_id"]] = (int(next_token_id), metadata)
+
+ batch.input_ids = torch.tensor(new_input_ids, dtype=torch.long).cuda()
+ batch.nopad_total_token_num += len(batch)
+ batch.nopad_max_len_in_batch += 1 # NOTE: we may replace this
+ self.cache[batch.batch_id] = batch
+ return output_dict
+
+ @torch.no_grad()
+ def _prefill_batch(self, batch_id):
+ return self.forward(batch_id, is_prefill=True)
+
+ @torch.no_grad()
+ def _decode_batch(self, batch_id):
+ return self.forward(batch_id, is_prefill=False)
+
# might want to create a sequence pool
# add a single request/sequence/input text at a time and record its length
# In other words, store the actual length of input tokens representing a single input text
diff --git a/colossalai/inference/tensor_parallel/kvcache_manager.py b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
similarity index 98%
rename from colossalai/inference/tensor_parallel/kvcache_manager.py
rename to colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
index c9e7aaae0844..91bb96a1f1f0 100644
--- a/colossalai/inference/tensor_parallel/kvcache_manager.py
+++ b/colossalai/legacy/inference/tensor_parallel/kvcache_manager.py
@@ -32,7 +32,7 @@ def __init__(
):
self.logger = logging.get_logger(__name__)
self.available_size = size
- self.past_key_values_length = 0
+ self.max_len_in_batch = 0
self._init_mem_states(size, device)
self._init_kv_buffers(size, device, dtype, head_num, head_dim, layer_num)
@@ -102,5 +102,5 @@ def free_all(self):
"""free all memory by updating memory states"""
self.available_size = len(self.mem_state)
self.mem_state[:] = 1
- self.past_key_values_length = 0
+ self.max_len_in_batch = 0
self.logger.info("freed all space of memory manager")
diff --git a/colossalai/inference/tensor_parallel/modeling/__init__.py b/colossalai/legacy/inference/tensor_parallel/modeling/__init__.py
similarity index 100%
rename from colossalai/inference/tensor_parallel/modeling/__init__.py
rename to colossalai/legacy/inference/tensor_parallel/modeling/__init__.py
diff --git a/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py
new file mode 100644
index 000000000000..068b64b4f829
--- /dev/null
+++ b/colossalai/legacy/inference/tensor_parallel/modeling/_utils.py
@@ -0,0 +1,67 @@
+"""
+Utils for model inference
+"""
+import os
+
+import torch
+
+from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest
+
+
+def copy_kv_to_mem_cache(layer_id, key_buffer, value_buffer, context_mem_index, mem_manager):
+ """
+ This function copies the key and value cache to the memory cache
+ Args:
+ layer_id : id of current layer
+ key_buffer : key cache
+ value_buffer : value cache
+ context_mem_index : index of memory cache in kv cache manager
+ mem_manager : cache manager
+ """
+ copy_kv_cache_to_dest(key_buffer, context_mem_index, mem_manager.key_buffer[layer_id])
+ copy_kv_cache_to_dest(value_buffer, context_mem_index, mem_manager.value_buffer[layer_id])
+
+
+def init_to_get_rotary(self, base=10000, use_elem=False):
+ """
+ This function initializes the rotary positional embedding; it is compatible with all models and is called in ShardFormer.
+ Args:
+ self : Model that holds the rotary positional embedding
+ base : exponential base used to compute the rotary frequencies
+ use_elem : set for chatglm-based models; halves the number of rotary elements
+ """
+ self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
+ if not hasattr(self.config, "rope_scaling"):
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
+
+ if hasattr(self.config, "max_sequence_length"):
+ max_seq_len = self.config.max_sequence_length
+ elif hasattr(self.config, "max_position_embeddings"):
+ max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
+ else:
+ max_seq_len = 2048 * rope_scaling_factor
+ base = float(base)
+
+ # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
+ ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
+
+ if ntk_alpha is not None:
+ ntk_alpha = float(ntk_alpha)
+ assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
+ if ntk_alpha > 1:
+ print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+ base = base * (ntk_alpha ** (self.config.head_dim_ / (self.config.head_dim_ - 2))) # Base change formula
+
+ n_elem = self.config.head_dim_
+ if use_elem:
+ n_elem //= 2
+
+ inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
+ t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
+ freqs = torch.outer(t, inv_freq)
+
+ self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
+ self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/legacy/inference/tensor_parallel/modeling/bloom.py
similarity index 93%
rename from colossalai/inference/tensor_parallel/modeling/bloom.py
rename to colossalai/legacy/inference/tensor_parallel/modeling/bloom.py
index 27a26caabefa..74fa5f470bf8 100644
--- a/colossalai/inference/tensor_parallel/modeling/bloom.py
+++ b/colossalai/legacy/inference/tensor_parallel/modeling/bloom.py
@@ -19,6 +19,15 @@
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd
+try:
+ from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_bloom_context_attention_fwd,
+ )
+
+ HAS_LIGHTLLM_KERNEL = True
+except ImportError:
+ HAS_LIGHTLLM_KERNEL = False
+
def generate_alibi(n_head, dtype=torch.float16):
"""
@@ -133,17 +142,11 @@ def bloom_model_forward(
assert hasattr(self, "infer_state")
infer_state = self.infer_state
- # Compute alibi tensor: check build_alibi_tensor documentation
- seq_length_with_past = seq_length
- past_key_values_length = 0
- # if self.cache_manager.past_key_values_length > 0:
- if infer_state.cache_manager.past_key_values_length > 0:
- # update the past key values length in cache manager,
- # NOTE use BatchInferState.past_key_values_length instead the one in cache manager
- past_key_values_length = infer_state.cache_manager.past_key_values_length
- seq_length_with_past = seq_length_with_past + past_key_values_length
-
# infer_state.cache_manager = self.cache_manager
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
if use_cache and seq_length != 1:
# prefill stage
@@ -160,21 +163,19 @@ def bloom_model_forward(
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
- infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
- print(
- f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
- )
+ print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
- infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ attention_mask = torch.ones((batch_size, infer_state.max_len_in_batch), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
@@ -195,6 +196,7 @@ def bloom_model_forward(
past_key_values_length=past_key_values_length,
)
+ infer_state.decode_layer_id = 0
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -228,6 +230,7 @@ def custom_forward(*inputs):
infer_state=infer_state,
)
+ infer_state.decode_layer_id += 1
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
@@ -247,7 +250,7 @@ def custom_forward(*inputs):
# and update these information in engine.generate after model foward called
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
- infer_state.decode_layer_id = 0
+ infer_state.max_len_in_batch += 1
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -453,9 +456,6 @@ def bloom_attention_forward(
mem_manager = infer_state.cache_manager
layer_id = infer_state.decode_layer_id
- if layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_length # += 1
-
if infer_state.is_context_stage:
# context process
max_input_len = q_length
@@ -469,7 +469,10 @@ def bloom_attention_forward(
# output = self.output[:batch_size*q_length, :, :]
output = torch.empty_like(q)
- bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
+ if HAS_LIGHTLLM_KERNEL:
+ lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len)
+ else:
+ bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
else:
@@ -506,15 +509,12 @@ def bloom_attention_forward(
b_loc,
b_start_loc,
b_seq_len,
- infer_state.cache_manager.past_key_values_length,
+ infer_state.max_len_in_batch,
alibi,
)
context_layer = output.view(batch_size, q_length, H * D_HEAD)
- # update layer id
- infer_state.decode_layer_id += 1
-
# NOTE: always set present as none for now, instead of returning past key value to the next decoding,
# we create the past key value pair from the cache manager
present = None
diff --git a/colossalai/inference/tensor_parallel/modeling/chatglm2.py b/colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py
similarity index 97%
rename from colossalai/inference/tensor_parallel/modeling/chatglm2.py
rename to colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py
index b8274d3c660f..b8fe8eb54855 100644
--- a/colossalai/inference/tensor_parallel/modeling/chatglm2.py
+++ b/colossalai/legacy/inference/tensor_parallel/modeling/chatglm2.py
@@ -19,8 +19,11 @@
from ._utils import copy_kv_to_mem_cache
try:
- from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_llama2_context_attention_fwd
from lightllm.models.chatglm2.triton_kernel.rotary_emb import rotary_emb_fwd as chatglm2_rotary_emb_fwd
+ from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_llama2_context_attention_fwd,
+ )
+
HAS_LIGHTLLM_KERNEL = True
except:
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
@@ -118,13 +121,12 @@ def chatglm_for_conditional_generation_forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
- past_key_values_length = 0
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
- # NOT READY FOR PRIME TIME
- # dummy but work, revise it
- past_key_values_length = infer_state.cache_manager.past_key_values_length
seq_length_with_past = seq_length + past_key_values_length
- infer_state.seq_length_with_past = seq_length_with_past
# prefill stage at first
if use_cache and seq_length != 1:
@@ -272,7 +274,6 @@ def chatglm_model_forward(
infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
infer_state.max_len_in_batch += 1
- infer_state.cache_manager.past_key_values_length += seq_length
if not return_dict:
return tuple(
@@ -394,9 +395,9 @@ def chatglm_flash_attn_kvcache_forward(
assert use_cache is True, "use_cache should be set to True using this chatglm attention"
# hidden_states: original :[sq, b, h] --> this [b, sq, h]
batch_size = hidden_states.shape[0]
+ hidden_size = hidden_states.shape[-1]
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer = self.query_key_value(hidden_states)
-
if self.multi_query_attention:
(query_layer, key_layer, value_layer) = mixed_x_layer.split(
[
@@ -436,7 +437,6 @@ def chatglm_flash_attn_kvcache_forward(
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-
cos, sin = infer_state.position_cos, infer_state.position_sin
chatglm2_rotary_emb_fwd(
@@ -465,10 +465,10 @@ def chatglm_flash_attn_kvcache_forward(
value_layer = value_layer.reshape(
-1, self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head
)
+
if infer_state.is_context_stage:
# first token generation:
# copy key and value calculated in current step to memory manager
-
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
key_layer,
@@ -476,8 +476,7 @@ def chatglm_flash_attn_kvcache_forward(
infer_state.context_mem_index,
infer_state.cache_manager,
)
-
- attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+ attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
# NOTE: no bug in context attn fwd (del it )
lightllm_llama2_context_attention_fwd(
@@ -487,7 +486,7 @@ def chatglm_flash_attn_kvcache_forward(
attn_output.view(-1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head),
infer_state.start_loc,
infer_state.seq_len,
- infer_state.seq_length_with_past,
+ infer_state.max_len_in_batch,
)
else:
@@ -513,7 +512,7 @@ def chatglm_flash_attn_kvcache_forward(
)
# second token and follows
- attn_output = torch.empty_like(query_layer.view(-1, self.projection_size))
+ attn_output = torch.empty_like(query_layer.contiguous().view(-1, self.projection_size))
cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id][
: infer_state.decode_mem_end, :, :
]
@@ -541,6 +540,6 @@ def chatglm_flash_attn_kvcache_forward(
# =================
# Output:[b,sq, h]
# =================
+ output = self.dense(attn_output).reshape(batch_size, -1, hidden_size)
- output = self.dense(attn_output).reshape(batch_size, -1, self.projection_size)
return output, kv_cache
diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py
similarity index 73%
rename from colossalai/inference/tensor_parallel/modeling/llama.py
rename to colossalai/legacy/inference/tensor_parallel/modeling/llama.py
index a3937f6f10ba..448943b12c9e 100644
--- a/colossalai/inference/tensor_parallel/modeling/llama.py
+++ b/colossalai/legacy/inference/tensor_parallel/modeling/llama.py
@@ -1,32 +1,17 @@
+import math
from typing import List, Optional, Tuple
import torch
from transformers.modeling_outputs import BaseModelOutputWithPast
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState
from colossalai.kernel.triton import llama_context_attn_fwd, token_attention_fwd
from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards
-
from ._utils import copy_kv_to_mem_cache
-
-try:
- from vllm import layernorm_ops, pos_encoding_ops
-
- rms_norm = layernorm_ops.rms_norm
- rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
- HAS_VLLM_KERNERL = True
-except:
- print("fall back to original rotary_embedding_neox of huggingface")
- print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
- print(
- "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch"
- )
- HAS_VLLM_KERNERL = False
-
try:
- from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import (
- context_attention_fwd as lightllm_llama2_context_attention_fwd,
+ from lightllm.models.llama.triton_kernel.context_flashattention_nopad import (
+ context_attention_fwd as lightllm_llama_context_attention_fwd,
)
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd
@@ -35,6 +20,14 @@
print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm")
HAS_LIGHTLLM_KERNEL = False
+try:
+ from flash_attn import flash_attn_with_kvcache
+
+ HAS_FLASH_KERNEL = True
+except ImportError:
+ HAS_FLASH_KERNEL = False
+ print("please install flash attention from https://github.com/Dao-AILab/flash-attention")
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -55,6 +48,64 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed
+def llama_triton_context_attention(
+ query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1
+):
+ # fall back to the ColossalAI triton kernel when the lightllm one is unavailable
+ if not HAS_LIGHTLLM_KERNEL:
+ llama_context_attn_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+ else:
+ lightllm_llama_context_attention_fwd(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+
+def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1):
+ assert HAS_LIGHTLLM_KERNEL, "You have to install the lightllm kernel to run token attention for llama models"
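+ # MHA (one kv group) uses the fused token-attention kernel; grouped kv heads need the llama2 variant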
+ if num_key_value_groups == 1:
+ token_attention_fwd(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ )
+
+ else:
+ Llama2TokenAttentionForwards.token_attn(
+ query_states,
+ infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
+ infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
+ attn_output,
+ infer_state.block_loc,
+ infer_state.start_loc,
+ infer_state.seq_len,
+ infer_state.max_len_in_batch,
+ infer_state.other_kv_index,
+ )
+
+
class LlamaInferenceForwards:
"""
This class holds forwards for llama inference.
@@ -74,12 +125,11 @@ def llama_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
- batch_size = input_ids.shape[0] # input_ids.shape[0]
-
infer_state = self.infer_state
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
@@ -90,15 +140,10 @@ def llama_model_forward(
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
- seq_length_with_past = seq_length
- past_key_values_length = 0
-
- if past_key_values is not None:
- # NOT READY FOR PRIME TIME
- # dummy but work, revise it
- past_key_values_length = infer_state.cache_manager.past_key_values_length
- # past_key_values_length = past_key_values[0][0].shape[2]
- seq_length_with_past = seq_length_with_past + past_key_values_length
+ if infer_state.is_context_stage:
+ past_key_values_length = 0
+ else:
+ past_key_values_length = infer_state.max_len_in_batch - 1
# NOTE: differentiate with prefill stage
# block_loc require different value-assigning method for two different stage
@@ -118,23 +163,23 @@ def llama_model_forward(
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
- infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
else:
print(f" *** Encountered allocation non-contiguous")
- print(
- f" infer_state.cache_manager.past_key_values_length: {infer_state.cache_manager.past_key_values_length}"
- )
+ print(f" infer_state.max_len_in_batch : {infer_state.max_len_in_batch}")
infer_state.decode_is_contiguous = False
alloc_mem = infer_state.cache_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
# infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
# infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
- infer_state.block_loc[:, seq_length_with_past - 1] = infer_state.decode_mem_index
+ infer_state.block_loc[:, infer_state.max_len_in_batch - 1] = infer_state.decode_mem_index
+
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
+ position_ids = position_ids.repeat(batch_size, 1)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
@@ -146,11 +191,12 @@ def llama_model_forward(
infer_state.position_sin = torch.index_select(self._sin_cached, 0, position_ids.view(-1)).view(
position_ids.view(-1).shape[0], -1
)
+
else:
seq_len = infer_state.seq_len
infer_state.position_cos = torch.index_select(self._cos_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
infer_state.position_sin = torch.index_select(self._sin_cached, 0, seq_len - 1).view(seq_len.shape[0], -1)
- infer_state.other_kv_index = infer_state.block_loc[0, seq_length_with_past - 1].item()
+ infer_state.other_kv_index = infer_state.block_loc[0, infer_state.max_len_in_batch - 1].item()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -158,7 +204,7 @@ def llama_model_forward(
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ (batch_size, infer_state.max_len_in_batch), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
@@ -173,7 +219,6 @@ def llama_model_forward(
next_decoder_cache = () if use_cache else None
infer_state.decode_layer_id = 0
-
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
# NOTE: modify here for passing args to decoder layer
@@ -197,8 +242,9 @@ def llama_model_forward(
# update indices
# infer_state.block_loc[:, infer_state.max_len_in_batch-1] = infer_state.total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
- infer_state.start_loc = infer_state.start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
+ infer_state.start_loc += torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
infer_state.seq_len += 1
+ infer_state.max_len_in_batch += 1
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -224,7 +270,6 @@ def llama_decoder_layer_forward(
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
-
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@@ -280,11 +325,8 @@ def llama_flash_attn_kvcache_forward(
# NOTE might want to revise
# need some way to record the length of past key values cache
# since we won't return past_key_value_cache right now
- if infer_state.decode_layer_id == 0: # once per model.forward
- infer_state.cache_manager.past_key_values_length += q_len # seq_len
cos, sin = infer_state.position_cos, infer_state.position_sin
- # print("shape ", cos.shape, query_states.view(-1, self.num_heads, self.head_dim).shape, )
llama_rotary_embedding_fwd(query_states.view(-1, self.num_heads, self.head_dim), cos, sin)
llama_rotary_embedding_fwd(key_states.view(-1, self.num_key_value_heads, self.head_dim), cos, sin)
@@ -295,7 +337,6 @@ def llama_flash_attn_kvcache_forward(
if infer_state.is_context_stage:
# first token generation
-
# copy key and value calculated in current step to memory manager
copy_kv_to_mem_cache(
infer_state.decode_layer_id,
@@ -304,29 +345,16 @@ def llama_flash_attn_kvcache_forward(
infer_state.context_mem_index,
infer_state.cache_manager,
)
-
attn_output = torch.empty_like(query_states)
- if self.num_key_value_groups == 1:
- llama_context_attn_fwd(
- query_states,
- key_states,
- value_states,
- attn_output,
- infer_state.start_loc,
- infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length,
- )
- else:
- lightllm_llama2_context_attention_fwd(
- query_states,
- key_states,
- value_states,
- attn_output,
- infer_state.start_loc,
- infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length,
- )
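+            # fused context-stage (prefill) attention; the helper dispatches on num_key_value_groups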
+ llama_triton_context_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_output,
+ infer_state,
+ num_key_value_groups=self.num_key_value_groups,
+ )
else:
if infer_state.decode_is_contiguous:
# if decode is contiguous, then we copy to key cache and value cache in cache manager directly
@@ -349,33 +377,26 @@ def llama_flash_attn_kvcache_forward(
infer_state.cache_manager,
)
- # second token and follows
- # kv = torch.stack((key_states, value_states), dim=2)
- # (batch_size, seqlen, nheads, headdim)
- attn_output = torch.empty_like(query_states)
-
- if self.num_key_value_groups == 1:
- token_attention_fwd(
- query_states,
- infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
- infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
- attn_output,
- infer_state.block_loc,
- infer_state.start_loc,
- infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length,
+ if HAS_LIGHTLLM_KERNEL:
+ attn_output = torch.empty_like(query_states)
+ llama_triton_token_attention(
+ query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups
)
else:
- Llama2TokenAttentionForwards.token_attn(
- query_states,
- infer_state.cache_manager.key_buffer[infer_state.decode_layer_id],
- infer_state.cache_manager.value_buffer[infer_state.decode_layer_id],
- attn_output,
- infer_state.block_loc,
- infer_state.start_loc,
- infer_state.seq_len,
- infer_state.cache_manager.past_key_values_length,
- infer_state.other_kv_index,
+                    # lightllm kernels unavailable: fall back to flash-attn's kv-cache attention
+ cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id]
+ cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id]
+
+ query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim)
+ copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim)
+ copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim)
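+                    # flash_attn_with_kvcache expects (batch, seqlen, nheads, headdim) layouts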
+
+ attn_output = flash_attn_with_kvcache(
+ q=query_states,
+ k_cache=copy_cache_k,
+ v_cache=copy_cache_v,
+ softmax_scale=1 / math.sqrt(self.head_dim),
+ causal=True,
)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
@@ -384,23 +405,3 @@ def llama_flash_attn_kvcache_forward(
# return past_key_value as None
return attn_output, None, None
-
-
-def get_llama_vllm_rmsnorm_forward():
- if HAS_VLLM_KERNERL:
-
- def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
- x = hidden_states
- out = torch.empty_like(x)
- rms_norm(
- out,
- x,
- self.weight.data,
- self.variance_epsilon,
- )
-
- return out
-
- return _vllm_rmsnorm_forward
- else:
- return None
diff --git a/colossalai/inference/tensor_parallel/policies/__init__.py b/colossalai/legacy/inference/tensor_parallel/policies/__init__.py
similarity index 100%
rename from colossalai/inference/tensor_parallel/policies/__init__.py
rename to colossalai/legacy/inference/tensor_parallel/policies/__init__.py
diff --git a/colossalai/inference/tensor_parallel/policies/bloom.py b/colossalai/legacy/inference/tensor_parallel/policies/bloom.py
similarity index 51%
rename from colossalai/inference/tensor_parallel/policies/bloom.py
rename to colossalai/legacy/inference/tensor_parallel/policies/bloom.py
index 3d6df2097000..f980bdb53add 100644
--- a/colossalai/inference/tensor_parallel/policies/bloom.py
+++ b/colossalai/legacy/inference/tensor_parallel/policies/bloom.py
@@ -4,7 +4,6 @@
from torch.nn import LayerNorm
import colossalai.shardformer.layer as col_nn
-from colossalai.shardformer.modeling.bloom import build_bloom_alibi_tensor_fn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomForCausalLMPolicy
@@ -38,35 +37,39 @@ def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomAttention, BloomBlock, BloomForCausalLM, BloomModel
policy = super().module_policy()
- if self.shard_config.inference_gptq:
+
+ if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
- policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
- "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
- "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
- },
- sub_module_replacement=[
- SubModuleReplacementDescription(
- suffix="self_attention.query_key_value",
- target_module=ColCaiQuantLinear,
- kwargs={'split_num': 3}),
- SubModuleReplacementDescription(
- suffix="self_attention.dense",
- target_module=RowCaiQuantLinear,
- kwargs={'split_num': 1}),
- SubModuleReplacementDescription(
- suffix="self_attention.attention_dropout",
- target_module=col_nn.DropoutForParallelInput,
- ),
- SubModuleReplacementDescription(
- suffix="mlp.dense_h_to_4h",
- target_module=ColCaiQuantLinear,
- kwargs={'split_num': 1}),
- SubModuleReplacementDescription(
- suffix="mlp.dense_4h_to_h",
- target_module=RowCaiQuantLinear,
- kwargs={'split_num': 1}),
- ])
+
+ policy[BloomBlock] = ModulePolicyDescription(
+ attribute_replacement={
+ "self_attention.hidden_size": self.model.config.hidden_size
+ // self.shard_config.tensor_parallel_size,
+ "self_attention.split_size": self.model.config.hidden_size
+ // self.shard_config.tensor_parallel_size,
+ "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+ },
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attention.query_key_value",
+ target_module=ColCaiQuantLinear,
+ kwargs={"split_num": 3},
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attention.dense", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attention.attention_dropout",
+ target_module=col_nn.DropoutForParallelInput,
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dense_h_to_4h", target_module=ColCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.dense_4h_to_h", target_module=RowCaiQuantLinear, kwargs={"split_num": 1}
+ ),
+ ],
+ )
# NOTE set inference mode to shard config
self.shard_config._infer()
diff --git a/colossalai/inference/tensor_parallel/policies/chatglm2.py b/colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py
similarity index 90%
rename from colossalai/inference/tensor_parallel/policies/chatglm2.py
rename to colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py
index 90f8b4fd2d7e..60dc511f5e96 100644
--- a/colossalai/inference/tensor_parallel/policies/chatglm2.py
+++ b/colossalai/legacy/inference/tensor_parallel/policies/chatglm2.py
@@ -48,7 +48,10 @@ def module_policy(self):
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=SelfAttention
)
-
+ if self.shard_config.enable_tensor_parallelism:
+ policy[GLMBlock].attribute_replacement["self_attention.num_multi_query_groups_per_partition"] = (
+ self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+ )
# for rmsnorm and others, we need to check the shape
return policy
diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/legacy/inference/tensor_parallel/policies/llama.py
similarity index 94%
rename from colossalai/inference/tensor_parallel/policies/llama.py
rename to colossalai/legacy/inference/tensor_parallel/policies/llama.py
index 7e163efe0173..896d55712254 100644
--- a/colossalai/inference/tensor_parallel/policies/llama.py
+++ b/colossalai/legacy/inference/tensor_parallel/policies/llama.py
@@ -9,10 +9,11 @@
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
from ..modeling._utils import init_to_get_rotary
-from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward
+from ..modeling.llama import LlamaInferenceForwards
try:
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward
+
HAS_TRITON_RMSNORM = True
except:
print("you should install triton from https://github.com/openai/triton")
@@ -21,6 +22,7 @@
def get_triton_rmsnorm_forward():
if HAS_TRITON_RMSNORM:
+
def _triton_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor):
return lightllm_rmsnorm_forward(hidden_states, self.weight.data, self.variance_epsilon)
@@ -36,7 +38,7 @@ def __init__(self) -> None:
def module_policy(self):
policy = super().module_policy()
- if self.shard_config.inference_gptq:
+ if self.shard_config.extra_kwargs.get("inference_gptq", False):
from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear
decoder_attribute_replacement = {
@@ -105,9 +107,6 @@ def module_policy(self):
infer_forward = None
if HAS_TRITON_RMSNORM:
infer_forward = get_triton_rmsnorm_forward()
- else:
- # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123
- infer_forward = get_llama_vllm_rmsnorm_forward()
if infer_forward is not None:
method_replacement = {"forward": partial(infer_forward)}
diff --git a/colossalai/legacy/initialize.py b/colossalai/legacy/initialize.py
index ce9c626553bf..4035bd6b54ef 100644
--- a/colossalai/legacy/initialize.py
+++ b/colossalai/legacy/initialize.py
@@ -16,7 +16,6 @@
from torch.utils.data import DataLoader
from colossalai.context import Config, ConfigException
-from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.interface import OptimizerWrapper
from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
from colossalai.legacy.amp.naive_amp import NaiveAMPModel
@@ -36,7 +35,6 @@
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
def get_default_parser():
@@ -323,8 +321,6 @@ def initialize(
if not use_zero:
if is_using_sequence():
sync_model_param(model, ParallelMode.SEQUENCE_DP)
- elif MOE_CONTEXT.is_initialized:
- sync_moe_model_param(model)
elif is_using_ddp():
sync_model_param(model, ParallelMode.DATA)
else:
@@ -377,14 +373,6 @@ def initialize(
"added even though not specified in the configuration",
ranks=[0],
)
- elif is_using_ddp() and MOE_CONTEXT.is_initialized:
- gradient_handler_cfg = [dict(type="MoeGradientHandler")]
- if verbose:
- logger.info(
- "Data parallel training is detected with moe parallel, MoeGradientHandler is automatically "
- "added even though not specified in the configuration",
- ranks=[0],
- )
elif is_using_sequence():
model = DDP(
model,
diff --git a/colossalai/legacy/nn/layer/parallel_1d/layers.py b/colossalai/legacy/nn/layer/parallel_1d/layers.py
index 8304cd2e1eb7..b6ec5347f2e2 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/layers.py
@@ -22,7 +22,7 @@
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..colossalai_layer._utils import ColossalaiModule
diff --git a/colossalai/legacy/nn/layer/parallel_2d/layers.py b/colossalai/legacy/nn/layer/parallel_2d/layers.py
index 3b2e032e5127..f81c5334ad77 100644
--- a/colossalai/legacy/nn/layer/parallel_2d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2d/layers.py
@@ -18,7 +18,7 @@
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
diff --git a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
index fc2e35f36cbc..b451a4031c25 100644
--- a/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_2p5d/layers.py
@@ -19,7 +19,7 @@
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
diff --git a/colossalai/legacy/nn/layer/parallel_3d/layers.py b/colossalai/legacy/nn/layer/parallel_3d/layers.py
index 196679994197..16e515f87da3 100644
--- a/colossalai/legacy/nn/layer/parallel_3d/layers.py
+++ b/colossalai/legacy/nn/layer/parallel_3d/layers.py
@@ -27,7 +27,7 @@
partition_tensor_parallel_state_dict,
)
from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (
diff --git a/colossalai/legacy/nn/layer/vanilla/layers.py b/colossalai/legacy/nn/layer/vanilla/layers.py
index 12965a4a6409..590ad5ff6085 100644
--- a/colossalai/legacy/nn/layer/vanilla/layers.py
+++ b/colossalai/legacy/nn/layer/vanilla/layers.py
@@ -10,7 +10,7 @@
from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ..utils import to_2tuple
diff --git a/colossalai/legacy/utils/activation_checkpoint.py b/colossalai/legacy/utils/activation_checkpoint.py
index 387e1c54ec87..9a8051ae937f 100644
--- a/colossalai/legacy/utils/activation_checkpoint.py
+++ b/colossalai/legacy/utils/activation_checkpoint.py
@@ -7,7 +7,7 @@
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils import get_current_device
+from colossalai.utils.device import autocast, get_current_device
def copy_to_device(obj, device):
@@ -110,7 +110,7 @@ def backward(ctx, *args):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
- with torch.enable_grad(), torch.cuda.amp.autocast():
+ with torch.enable_grad(), autocast():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
@@ -226,7 +226,7 @@ def inner_unpack(packed):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
- with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+ with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack
):
_unused = function(*args)
diff --git a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
index 19f77d4305af..e336717f4164 100644
--- a/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
+++ b/colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
@@ -3,7 +3,7 @@
from time import time
from typing import List
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
diff --git a/colossalai/moe/__init__.py b/colossalai/moe/__init__.py
new file mode 100644
index 000000000000..721da69d0741
--- /dev/null
+++ b/colossalai/moe/__init__.py
@@ -0,0 +1,17 @@
+from .checkpoint import MoECheckpintIO
+from .experts import MLPExperts
+from .layers import SparseMLP
+from .routers import MoeRouter, Top1Router, Top2Router, TopKRouter
+from .utils import NormalNoiseGenerator, UniformNoiseGenerator
+
+__all__ = [
+ "MLPExperts",
+ "MoeRouter",
+ "Top1Router",
+ "Top2Router",
+ "TopKRouter",
+ "NormalNoiseGenerator",
+ "UniformNoiseGenerator",
+ "SparseMLP",
+ "MoECheckpintIO",
+]
diff --git a/colossalai/moe/_operation.py b/colossalai/moe/_operation.py
new file mode 100644
index 000000000000..c71e6c1f40c7
--- /dev/null
+++ b/colossalai/moe/_operation.py
@@ -0,0 +1,336 @@
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.distributed import ProcessGroup
+
+MOE_KERNEL = None
+
+
+def load_moe():
+ global MOE_KERNEL
+ from colossalai.kernel.op_builder import MOEBuilder
+
+ MOE_KERNEL = MOEBuilder().load()
+
+
+class AllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ group: Optional[ProcessGroup] = None,
+ overlap: bool = False,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ assert ctx is not None or not overlap
+
+ if ctx is not None:
+ ctx.comm_grp = group
+
+ comm_size = dist.get_world_size(group)
+ if comm_size == 1:
+ return inputs.unsqueeze(0), None
+
+ buffer_shape = (comm_size,) + inputs.shape
+ outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
+ buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
+ if not overlap:
+ dist.all_gather(buffer_list, inputs, group=group)
+ return outputs, None
+ else:
+ handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
+ return outputs, handle
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
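+        # The adjoint of all-gather is reduce-scatter; reuse ReduceScatter.forward with ctx=None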
+ return (
+ ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
+ None,
+ None,
+ )
+
+
+class ReduceScatter(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ group: ProcessGroup,
+ overlap: bool = False,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ assert ctx is not None or not overlap
+
+ if ctx is not None:
+ ctx.comm_grp = group
+
+ comm_size = dist.get_world_size(group)
+ if comm_size == 1:
+ return inputs.squeeze(0), None
+
+ if not inputs.is_contiguous():
+ inputs = inputs.contiguous()
+
+ output_shape = inputs.shape[1:]
+ outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
+ buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
+ if not overlap:
+ dist.reduce_scatter(outputs, buffer_list, group=group)
+ return outputs, None
+ else:
+ handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
+ return outputs, handle
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
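+        # The adjoint of reduce-scatter is all-gather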
+ # TODO: support async backward
+ return (
+ AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
+ None,
+ None,
+ )
+
+
+class AllToAll(torch.autograd.Function):
+ """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
+ operation in torch.distributed.
+ """
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ group: ProcessGroup,
+ overlap: bool = False,
+ ) -> Tuple[Tensor, Any]:
+ """
+ Returns:
+ outputs: Tensor
+ handle: Optional[Work], if overlap is True
+ """
+ assert ctx is not None or not overlap
+
+ if ctx is not None:
+ ctx.comm_grp = group
+ if not inputs.is_contiguous():
+ inputs = inputs.contiguous()
+ if dist.get_world_size(group) == 1:
+ return inputs, None
+ output = torch.empty_like(inputs)
+ if not overlap:
+ dist.all_to_all_single(output, inputs, group=group)
+ return output, None
+ else:
+ handle = dist.all_to_all_single(output, inputs, group=group, async_op=True)
+ return output, handle
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
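+        # all_to_all_single with equal splits is its own adjoint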
+ return (
+ AllToAll.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
+ None,
+ None,
+ )
+
+
+class HierarchicalAllToAll(torch.autograd.Function):
+
+ @staticmethod
+ def forward(
+ ctx: Any,
+ inputs: Tensor,
+ groups: Tuple[ProcessGroup, ProcessGroup],
+ src_rank: int
+ ) -> Tensor:
+ """
+ Returns:
+ outputs: Tensor
+ """
+ # TODO: we can reduce comm volume by removing empty capacity
+ if ctx is not None:
+ ctx.comm_grps = groups
+ ctx.src_rank = src_rank
+ intra_node_group, inter_node_group = groups
+
+ local_world_size = dist.get_world_size(intra_node_group)
+ num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
+ world_size = local_world_size * num_group
+ outputs = torch.empty_like(inputs)
+
+ if dist.get_rank() == src_rank:
+ # intra-node gather
+ intra_output = [torch.empty_like(inputs) for _ in range(local_world_size)]
+ dist.gather(inputs, intra_output, dst=src_rank, group=intra_node_group)
+
+ intra_output = [v.chunk(world_size, dim=0) for v in intra_output]
+ intra_output = torch.cat(sum(zip(*intra_output), ()))
+
+ # inter-node all-to-all
+ if inter_node_group is not None:
+ inter_output = torch.empty_like(intra_output)
+ dist.all_to_all_single(inter_output, intra_output, group=inter_node_group)
+
+ # layout transform
+ inter_output = inter_output.chunk(num_group, dim=0)
+ inter_output = [v.chunk(local_world_size, dim=0) for v in inter_output]
+ intra_output = torch.cat(sum(zip(*inter_output), ()))
+
+ # intra-node scatter
+ intra_output = list(intra_output.chunk(local_world_size, dim=0))
+ dist.scatter(outputs, intra_output, src=src_rank, group=intra_node_group)
+
+ else:
+ dist.gather(inputs, dst=src_rank, group=intra_node_group)
+ dist.scatter(outputs, src=src_rank, group=intra_node_group)
+
+ return outputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
+ return (
+ HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
+ None,
+ None,
+ )
+
+
+class MoeDispatch(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, tokens, mask, dest_idx, ec):
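+        # tokens: [s, h]; ec is the total number of expert slots (num_experts * capacity, cf. MoeCombine)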
+ s = tokens.size(0)
+ h = tokens.size(1)
+ dtype = tokens.dtype
+
+ if MOE_KERNEL is None:
+ load_moe()
+ if tokens.dtype != torch.float32:
+ tokens = tokens.to(torch.float32)
+ expert_input = MOE_KERNEL.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
+ if expert_input.dtype != dtype:
+ expert_input = expert_input.to(dtype)
+ ctx.save_for_backward(mask, dest_idx)
+ ctx.s = s
+ ctx.h = h
+ ctx.ec = ec
+ ctx.dtype = dtype
+
+ return expert_input
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, output_grad):
+ mask, dest_idx = ctx.saved_tensors
+ if output_grad.dtype != torch.float32:
+ output_grad = output_grad.to(torch.float32)
+ d_tokens = MOE_KERNEL.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+ if d_tokens.dtype != ctx.dtype:
+ d_tokens = d_tokens.to(ctx.dtype)
+ return d_tokens, None, None, None
+
+
+class MoeCombine(torch.autograd.Function):
+ @staticmethod
+ @custom_fwd
+ def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
+ assert logits.dtype == torch.float32
+
+ s = logits.size(0)
+ e = logits.size(1)
+ c = ec // e
+ h = expert_tokens.size(-1)
+ dtype = expert_tokens.dtype
+
+ if expert_tokens.dtype != torch.float32:
+ expert_tokens = expert_tokens.to(torch.float32)
+ if MOE_KERNEL is None:
+ load_moe()
+ output = MOE_KERNEL.combine_forward(s, e, c, h, expert_tokens, logits, mask, dest_idx)
+ if output.dtype != dtype:
+ output = output.to(dtype)
+
+ ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
+ ctx.s = s
+ ctx.e = e
+ ctx.c = c
+ ctx.h = h
+ ctx.dtype = dtype
+
+ return output
+
+ @staticmethod
+ @custom_bwd
+ def backward(ctx, tokens_grad):
+ expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
+ if tokens_grad.dtype != torch.float32:
+ tokens_grad = tokens_grad.to(torch.float32)
+
+        d_expert, d_logits = MOE_KERNEL.combine_backward(
+            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
+        )
+ if d_expert.dtype != ctx.dtype:
+ d_expert = d_expert.to(ctx.dtype)
+
+ return d_expert, d_logits, None, None, None
+
+
+def moe_cumsum(inputs: Tensor, use_kernel: bool = False):
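+    # 0-based running count per column: for a one-hot routing mask, entry (i, e)
+    # is token i's position within expert e's queue,
+    # e.g. [[1, 0], [0, 1], [1, 0]] -> [[0, -1], [0, 0], [1, 0]]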
+ dim0 = inputs.size(0)
+ flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
+ if flag and use_kernel:
+ if MOE_KERNEL is None:
+ load_moe()
+ return MOE_KERNEL.cumsum_sub_one(inputs)
+ else:
+ return torch.cumsum(inputs, dim=0) - 1
+
+
+class MoeInGradScaler(torch.autograd.Function):
+ """
+    Scale the gradient up by ep_size in the backward pass,
+    compensating for the increased effective batch size in the MoE stage.
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
+ if ctx is not None:
+ ctx.ep_size = ep_size
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.ep_size != 1:
+ grad = grad * ctx.ep_size
+ return grad, None
+
+
+class MoeOutGradScaler(torch.autograd.Function):
+ """
+    Scale the gradient down by ep_size in the backward pass,
+    the inverse of MoeInGradScaler.
+ """
+
+ @staticmethod
+ def forward(ctx: Any, inputs: Tensor, ep_size: int) -> Tensor:
+ ctx.ep_size = ep_size
+ return inputs
+
+ @staticmethod
+ def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
+ assert len(grad_outputs) == 1
+ grad = grad_outputs[0]
+ if ctx.ep_size != 1:
+ grad = grad / ctx.ep_size
+ return grad, None
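+
+
+# Usage sketch (illustrative; expert_fn is a placeholder, not part of this module):
+# bracket the expert computation so gradients stay correctly scaled w.r.t. ep_size:
+#   x = MoeInGradScaler.apply(x, ep_size)
+#   y = expert_fn(x)
+#   y = MoeOutGradScaler.apply(y, ep_size)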
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
new file mode 100644
index 000000000000..a8c50eab66e3
--- /dev/null
+++ b/colossalai/moe/checkpoint.py
@@ -0,0 +1,781 @@
+import copy
+import logging
+import os
+from pathlib import Path
+from shutil import rmtree
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.distributed import ProcessGroup
+
+from colossalai.checkpoint_io import CheckpointIndexFile, HybridParallelCheckpointIO
+from colossalai.checkpoint_io.utils import (
+ StateDictSharder,
+ gather_distributed_param,
+ get_model_base_filenames,
+ get_optimizer_base_filenames,
+ is_safetensors_available,
+ load_shard_state_dict,
+ load_state_dict,
+ load_state_dict_into_model,
+ load_states_into_optimizer,
+ save_config_file,
+ save_param_groups,
+ save_state_dict,
+ save_state_dict_shards,
+ sharded_optimizer_loading_epilogue,
+)
+from colossalai.interface import OptimizerWrapper
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import (
+ get_dp_group,
+ get_dp_rank,
+ get_dp_size,
+ get_ep_group,
+ get_ep_rank,
+ get_ep_size,
+ is_moe_tensor,
+)
+
+
+class MoECheckpintIO(HybridParallelCheckpointIO):
+ def __init__(
+ self,
+ dp_group: ProcessGroup,
+ pp_group: ProcessGroup,
+ tp_group: ProcessGroup,
+ zero_stage: int,
+ ) -> None:
+ assert zero_stage in [
+ 0,
+ 1,
+ 2,
+ ], f"zero_stage should be 0 or 1 or 2, got {zero_stage}"
+ super().__init__(dp_group, pp_group, tp_group, zero_stage)
+ self.parallel = MOE_MANAGER.parallel
+
+ def pre_load_model(self, model: nn.Module, state_dict: dict) -> dict:
+ """
+ Preprocess state_dict before loading and slice the state_dict of MOE tensors.
+ """
+ for name, param in state_dict.items():
+ if ".experts." in name:
+ if name in dict(model.named_parameters()):
+ model_param = dict(model.named_parameters())[name]
+ if is_moe_tensor(model_param):
+ ep_rank = get_ep_rank(model_param)
+ ep_size = get_ep_size(model_param)
+ expert_num = param.shape[0] // ep_size
+ assert param.shape[0] % ep_size == 0
+ param = param[ep_rank * expert_num : (ep_rank + 1) * expert_num]
+ state_dict[name] = param
+ dist.barrier()
+ return state_dict
+
+ def _model_sharder(
+ self,
+        state_dict: dict,
+ prefix: str = "",
+ keep_vars: bool = False,
+ size_per_shard: int = 1024,
+ ) -> Iterator[Tuple[OrderedDict, int]]:
+        # An internal method that breaks the model state_dict into shards within a limited size.
+ state_dict_sharder = StateDictSharder(size_per_shard)
+
+ for name, param in state_dict.items():
+ if param is None:
+ continue
+ # Gather tensor pieces when using tensor parallel.
+ param_ = gather_distributed_param(param, keep_vars=False)
+ block, block_size = state_dict_sharder.append_param(prefix + name, param_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool) -> None:
+ state_dict = torch.load(checkpoint)
+ state_dict = self.pre_load_model(model, state_dict)
+ model.load_state_dict(state_dict, strict=strict if self.pp_size == 1 else False)
+
+ def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+ """
+ Load sharded model with the given path to index file of checkpoint folder.
+
+ Args:
+ model (nn.Module): The model to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+            strict (bool, optional): Whether to strictly enforce name matching when loading the state_dict. Defaults to False,
+                since params kept by the same device might be stored in different files.
+ """
+
+ # Check whether the checkpoint uses safetensors.
+ use_safetensors = False
+ if "safetensors" in checkpoint_index_file.name:
+ use_safetensors = True
+
+ if use_safetensors and not is_safetensors_available():
+ raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ strict = False
+
+ # Load params & buffers to model.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+
+ def _load(name: str):
+ if name not in weight_map:
+ raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
+ filename = weight_map[name]
+
+ # If this param/buffer has been loaded before, directly return.
+ if filename in loaded_file:
+ return
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+ state_dict = self.pre_load_model(model, state_dict)
+ missing_keys = []
+
+ load_state_dict_into_model(
+ model,
+ state_dict,
+ missing_keys=missing_keys,
+ strict=strict,
+ load_sub_module=True,
+ )
+ loaded_file.add(filename)
+
+ # Load parameters.
+ for name, _ in model.named_parameters():
+ _load(name)
+
+ if self.verbose:
+ logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+ def pre_save_model(self, model: nn.Module) -> dict:
+ state_dict = model.state_dict()
+ for name, param in model.named_parameters():
+ if ".experts." in name and is_moe_tensor(param):
+ ep_group = get_ep_group(param)
+ ep_rank = get_ep_rank(param)
+ ep_size = get_ep_size(param)
+ dp_rank = get_dp_rank(param)
+ if dp_rank == 0:
+ param = param.data.cuda()
+ all_param = [torch.zeros_like(param) for _ in range(ep_size)]
+ # gather param from every ep rank
+ dist.all_gather(all_param, param, group=ep_group)
+ if ep_rank == 0:
+ all_param = torch.cat(all_param, dim=0)
+ state_dict[name] = all_param.cpu()
+ if self.pp_size > 1:
+ if self.dp_rank == 0:
+ out = [None for _ in range(self.pp_size)]
+ dist.all_gather_object(out, state_dict, group=self.pp_group)
+ if self.pp_rank == 0:
+ new_state_dict = {}
+ for o in out:
+ new_state_dict.update(o)
+ state_dict = new_state_dict
+ dist.barrier()
+ return state_dict
+
+ def save_unsharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool,
+ use_safetensors: bool,
+ ):
+ state_dict = self.pre_save_model(model)
+ if dist.get_rank() == 0:
+ torch.save(state_dict, checkpoint)
+ dist.barrier()
+
+ def save_sharded_model(
+ self,
+ model: nn.Module,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ use_safetensors: bool = False,
+ ) -> None:
+ """
+ Save sharded model checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
+ - Multiple files that store state tensors of models.
+ The filenames are in the form of "pytorch_model.-000XX.bin"
+
+ Args:
+ model (nn.Module): Model on local device to be saved.
+ checkpoint (str): Checkpointing path which should be a directory path.
+ gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
+            prefix (str, optional): Prefix of files to save. Defaults to None.
+ size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
+ use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
+ """
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Then collect the sharded parameters & buffers along tp_group.
+ # Only devices with tp_rank == 0 are responsible for model saving.
+ state_dict = self.pre_save_model(model)
+
+ if dist.get_rank() == 0:
+ state_dict_shard = self._model_sharder(state_dict, size_per_shard=size_per_shard)
+
+ # Devices along the same dp_group share the same copies of model.
+ # So only let the device with dp_rank == 0 save the model.
+ if self.dp_rank != 0:
+ return
+
+ weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.tp_rank == 0
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=weights_name,
+ is_master=control_saving,
+ use_safetensors=use_safetensors,
+ )
+ if control_saving:
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ save_config_file(model, checkpoint)
+ if self.verbose:
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+ dist.barrier()
+
+ # ========================================================
+ # Abstract methods for optimizer loading/saving implementation
+ # ========================================================
+
+ def pre_load_optim(
+ self,
+ state: OrderedDict,
+ working_param,
+ current_shape: torch.Size,
+ original_shape: torch.Size,
+ device: torch.device,
+ inplace: bool,
+ ) -> OrderedDict:
+ """
+ With complete optimizer states of a specific parameter loaded from checkpoint,
+ slice out the sharded optimizer states kept by current device.
+
+ Args:
+            state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
+            working_param (torch.Tensor): The working parameter corresponding to the loaded states.
+            current_shape (torch.Size): The size of parameter after sharding.
+ original_shape (torch.Size): The size of parameter before sharding.
+ device (torch.device): The destination device of loaded optimizer states.
+            inplace (bool): If True, the values of argument 'state' are updated in place; otherwise a copy of the state is made.
+
+ Returns:
+ OrderedDict: The sharded optimizer state of the given parameter.
+ """
+ state_ = state if inplace else copy.deepcopy(state)
+ is_moe_tensor_flag = is_moe_tensor(working_param)
+ if is_moe_tensor_flag:
+ ep_rank = get_ep_rank(working_param)
+ ep_size = get_ep_size(working_param)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ if is_moe_tensor_flag:
+ with torch.no_grad():
+ expert_num = v.shape[0] // ep_size
+ assert v.shape[0] % ep_size == 0
+ v = v[ep_rank * expert_num : (ep_rank + 1) * expert_num]
+ else:
+ # Shard state along data parallel group when using Zero.
+ padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
+ with torch.no_grad():
+ v = v.flatten()
+ if padding_size > 0:
+ v = torch.nn.functional.pad(v, [0, padding_size])
+ slice_size = v.numel() // self.dp_size
+ v = v.split(slice_size, dim=0)[self.dp_rank]
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
+ """
+ Load sharded optimizer with the given path to index file of checkpoint folder.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+ checkpoint_index_file (str): Path to the index file of checkpointing folder.
+ prefix (str): Not used.
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ return optimizer.param_info["param2id"][id(working_param)]
+
+ # id_map is a mapping from param ids kept by current pipeline, to their corresponding parameter objects.
+ # When Zero is used, the mapped parameter objects should be fp32 master parameters.
+ # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
+ id_map = {}
+ master_to_working_map = optimizer.get_master_to_working_map()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ id_map[param_id] = param
+
+ # Read checkpoint index file.
+ ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+ ckpt_root_path = ckpt_index_file.root_path
+ weight_map = ckpt_index_file.weight_map
+ weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
+
+ # Load param_groups
+ param_group_path = ckpt_index_file.get_param_group_filename()
+ if param_group_path is None:
+ raise RuntimeError(
+ f"Invalid index file path {checkpoint_index_file} for an optimizer. \
+ Lacking param group file under current directory."
+ )
+ saved_groups = torch.load(param_group_path)
+
+ updated_groups = []
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ # obtain updated param group
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # The parameters in the same group shouln't change.
+ updated_groups.append(new_pg)
+ # ep extra group
+ if MOE_MANAGER.parallel == "EP":
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = optimizer.optim.param_groups[-1][
+ "params"
+ ] # Only keep the parameters kept by current pipeline stage.
+ for param in new_pg["params"]:
+ param.data = param.data.to(torch.float32)
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer.
+ # Keep a record of loaded files so that file will not be repeatedly loaded.
+ loaded_file = set()
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ if param is None:
+ continue
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id not in weight_map:
+ continue
+ filename = weight_map[param_id]
+
+                # If this param's states have been loaded before, skip the file directly.
+ if filename in loaded_file:
+ continue
+
+ file_path = os.path.join(ckpt_root_path, filename)
+ state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+ load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+ loaded_file.add(filename)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ device = param.device
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.pre_load_optim(
+ state,
+ param,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True,
+ )
+ optimizer.optim.state[param] = sharded_state
+
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+ dist.barrier()
+
+ def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+ """
+ Load optimizer from a file with given path.
+
+ Args:
+ optimizer (OptimizerWrapper): The optimizer to be loaded.
+            checkpoint (str): Path to the checkpoint file.
+ """
+
+ def _get_param_id_from_optimizer_param(
+ param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
+ ):
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ if id(working_param) in optimizer.param_info["param2id"]:
+ return optimizer.param_info["param2id"][id(working_param)]
+ else:
+                return None
+
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+
+ # Complete optimizer state_dict loaded from checkpoint, need to be processed later.
+ state_dict = load_state_dict(checkpoint)
+
+ # Load param_groups.
+ updated_groups = []
+ saved_groups = state_dict["param_groups"]
+ for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = old_pg["params"] # Only keep the parameters kept by current pipeline stage.
+ updated_groups.append(new_pg)
+ # ep extra group
+ if MOE_MANAGER.parallel == "EP":
+ new_pg = copy.deepcopy(saved_pg)
+ new_pg["params"] = optimizer.optim.param_groups[-1][
+ "params"
+ ] # Only keep the parameters kept by current pipeline stage.
+ for param in new_pg["params"]:
+ param.data = param.data.to(torch.float32)
+ updated_groups.append(new_pg)
+ optimizer.optim.__dict__.update({"param_groups": updated_groups})
+
+ # Load saved states to optimizer. First discard those states not belonging to current pipeline stage.
+ master_to_working_map = optimizer.get_master_to_working_map()
+ id_map = {}
+ for pg in optimizer.optim.param_groups:
+ for param in pg["params"]:
+ param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
+ if param_id is not None:
+ id_map[param_id] = param
+ load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
+
+ # Then shard the loaded optimizer states if using tp/zero.
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+ device = param.device
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+ original_shape = optimizer.param_info["param2shape"][id(working_param)]
+ sharded_state = self.pre_load_optim(
+ state,
+ param,
+ current_shape=working_param.shape,
+ original_shape=original_shape,
+ device=device,
+ inplace=True,
+ )
+ optimizer.optim.state[param] = sharded_state
+ sharded_optimizer_loading_epilogue(optimizer.optim)
+ dist.barrier()
+
+ def pre_save_optim(
+ self,
+ state: OrderedDict,
+ param: torch.Tensor,
+ inplace: bool,
+ device: torch.device = torch.device("cpu"),
+ ) -> OrderedDict:
+ """
+        With a given parameter and its optimizer states, gather the complete optimizer state for saving.
+
+        Args:
+            state (OrderedDict): Optimizer states of the given parameter, possibly distributed among tp/dp groups when using TP/Zero.
+            param (torch.Tensor): The given parameter. It should be the working_param when using Zero.
+            inplace (bool): If True, the values of argument 'state' are updated in place; otherwise a copy of the state is made.
+            device (torch.device): The destination device of the gathered optimizer states. Defaults to torch.device('cpu').
+
+ Returns:
+            OrderedDict: The complete optimizer state of the given parameter.
+ """
+ if is_moe_tensor(param):
+ moe_dp_group = get_dp_group(param)
+ moe_dp_size = get_dp_size(param)
+ moe_ep_group = get_ep_group(param)
+ moe_ep_size = get_ep_size(param)
+ state_ = state if inplace else copy.deepcopy(state)
+
+ for k, v in state_.items():
+ if isinstance(v, torch.Tensor) and k != "step":
+ # moe param
+ if is_moe_tensor(param):
+ # dp gather
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(moe_dp_size)]
+ dist.all_gather(gather_tensor, v, group=moe_dp_group)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+ # ep gather
+ gather_tensor = [torch.zeros_like(v) for _ in range(moe_ep_size)]
+ dist.all_gather(gather_tensor, v, group=moe_ep_group)
+ v = torch.cat(gather_tensor, dim=0)
+ else:
+ # global dp
+ v = v.cuda()
+ gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(self.dp_group))]
+ dist.all_gather(gather_tensor, v, group=self.dp_group)
+ v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
+
+ state_[k] = v.detach().clone().to(device)
+
+ return state_
+
+ def _optimizer_sharder(
+ self,
+ optimizer: OptimizerWrapper,
+ size_per_shard: int = 1024,
+ ):
+        # An internal method that breaks the optimizer state_dict into shards within a limited size.
+
+ state_dict_sharder = StateDictSharder(size_per_shard)
+ param_info = optimizer.param_info
+ master_to_working_map = optimizer.get_master_to_working_map()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ param_id = param_info["param2id"][id(working_param)]
+ state_ = self.pre_save_optim(
+ state,
+ working_param,
+ inplace=False,
+ device=torch.device("cuda"),
+ )
+
+ block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
+ if block is not None:
+ yield block, block_size
+
+ # Return the last block in sharder.
+ yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
+
+ def save_sharded_optimizer(
+ self,
+ optimizer: OptimizerWrapper,
+ checkpoint: str,
+ gather_dtensor: bool = True,
+ prefix: Optional[str] = None,
+ size_per_shard: int = 1024,
+ ):
+ """
+ Save sharded optimizer checkpoint under the given checkpointing path.
+ The following files will be created under the path:
+ - An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
+ - A group file (pytorch_optim_group.bin) recording information of param_groups
+ - Multiple files that store state tensors of optimizers.
+ If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.-stage-000XX-shard-000XX.bin".
+ If pipeline parallelism is not used, "pytorch_optim.-000XX.bin"
+
+ Args:
+ optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
+ checkpoint (str): Path to save optimizer state_dict
+ gather_dtensor (bool): Whether to gather_dtensor, not used
+            prefix (str): Prefix of files to save
+ size_per_shard (int): Max file size of each file shard that store state tensors
+ """
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+ if os.path.isfile(checkpoint):
+ logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+ return
+
+ Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+ # Devices along the same dp_group share the same copies of states when zero is not used.
+ # In this case only let the device with dp_rank == 0 save the model.
+ if not self.use_zero and self.dp_rank != 0:
+ return
+
+ # Then collect the sharded states along dp_group(if using zero)/tp_group.
+ # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
+ state_dict_shard = self._optimizer_sharder(
+ optimizer,
+ size_per_shard=size_per_shard,
+ )
+ states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+ index_file = CheckpointIndexFile(checkpoint)
+ control_saving = self.dp_rank == 0 and self.tp_rank == 0
+ if self.pp_size == 1:
+ # When pipeline is not used, save the optimizer shards as in general checkpointIO
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ )
+
+ if control_saving:
+ # Store param groups.
+ index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+ # Store index file.
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The optimizer is going to be split to checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ else:
+ # When pipeline is used, each stage produces its own shard files and index files.
+ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
+ # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
+
+ final_index_file_path = copy.deepcopy(save_index_file)
+ tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
+ Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
+
+ # Manage filenames of sharded weights and index file for each pipeline stage.
+ states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
+ save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
+ save_index_file = os.path.join("tmp_index_files", save_index_file)
+
+ total_size = save_state_dict_shards(
+ sharded_state_dict=state_dict_shard,
+ checkpoint=checkpoint,
+ index_file=index_file,
+ base_filename=states_name,
+ is_master=control_saving,
+ use_pp_format=True,
+ )
+
+ if control_saving:
+ assert (
+ self.dp_rank == 0 and self.tp_rank == 0
+ ), "The saving process should have both dp_rank and tp_rank as 0."
+ index_file.append_meta_data("total_size", total_size)
+ index_file.write_index_file(save_index_file)
+ else:
+ return
+
+ dist.barrier(self.pp_group)
+
+ # The global master rank integrates the index files and clean the folder.
+ if self.pp_rank == 0:
+ final_index_file = CheckpointIndexFile(checkpoint)
+ final_index_file.append_meta_data("total_size", 0)
+
+ for filename in os.listdir(tmp_index_file_folder):
+ stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
+ final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
+ for param_id, state_filename in stage_index_file.weight_map.items():
+ final_index_file.append_weight_map(param_id, state_filename)
+
+ # Store param groups.
+ final_index_file.append_meta_data("param_groups", param_group_file)
+ group_file_path = os.path.join(checkpoint, param_group_file)
+ save_param_groups(optimizer.param_info, group_file_path)
+
+ final_index_file.write_index_file(final_index_file_path)
+ rmtree(tmp_index_file_folder)
+
+ if self.verbose and self.coordinator.is_master():
+ logging.info(
+ f"The model is split into checkpoint shards. "
+ f"You can find where each parameters has been saved in the "
+ f"index located at {final_index_file_path}."
+ )
+
+ def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+ """
+ Save optimizer state dict to a file with given path.
+
+ Args:
+            optimizer (OptimizerWrapper): Optimizer whose state_dict is to be saved.
+ checkpoint (str): Path to save optimizer state_dict.
+ gather_dtensor (bool): Whether to gather_dtensor, not used.
+ """
+ if self.coordinator.is_master():
+ logging.warning("Please avoid using unsharded checkpointing methods when dealing with large models!")
+
+ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
+
+ # optimizer states of parameters kept by local device('s pipeline stage)
+ local_states = dict()
+
+ for param, state in optimizer.optim.state.items():
+ if param is None:
+ continue
+
+ # working param is needed for obtaining correct param_id
+ master_to_working_map = optimizer.get_master_to_working_map()
+ if master_to_working_map is not None and id(param) in master_to_working_map:
+ working_param = master_to_working_map[id(param)]
+ else:
+ working_param = param
+
+ # gather complete state from tp shards & dp shards
+ param_id = optimizer.param_info["param2id"][id(working_param)]
+ local_states[param_id] = self.pre_save_optim(
+ state,
+ working_param,
+ inplace=False,
+ device=torch.device("cuda"),
+ )
+
+ if self.pp_size == 1:
+ # When pipeline is not used, let master rank directly save the collected state_dict.
+ state_dict = {"param_groups": optimizer.optim.param_groups, "state": local_states}
+ if self.coordinator.is_master():
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ else:
+ # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
+ states_list = [None for _ in range(self.pp_size)]
+ dist.barrier(self.pp_group)
+ dist.all_gather_object(states_list, local_states, self.pp_group)
+
+            # Only the master rank does the saving.
+ if self.coordinator.is_master():
+ state_dict = {"param_groups": optimizer.optim.param_groups, "state": dict()}
+ for _states in states_list:
+ state_dict["state"].update(_states)
+ save_state_dict(state_dict, checkpoint, use_safetensors=False)
+ dist.barrier()
diff --git a/colossalai/moe/experts.py b/colossalai/moe/experts.py
new file mode 100644
index 000000000000..477b76547c7e
--- /dev/null
+++ b/colossalai/moe/experts.py
@@ -0,0 +1,157 @@
+import math
+from typing import Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
+from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.moe_tensor.api import get_ep_rank, get_ep_size, set_moe_tensor_info
+
+if HAS_TRITON:
+ from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
+
+class MLPExperts(nn.Module):
+ """
+    MLPExperts holds the expert MLPs of a sparse MoE layer, with optional expert parallelism.
+
+ Args:
+ num_experts (int): The number of experts
+ hidden_size (int): The hidden size of MLP
+ intermediate_size (int): The intermediate size of MLP
+ expert_parallel (str, optional): The parallelism of experts. Now we have None, EP and TP.
+ activation (optional): The activation function of MLP
+ drop_rate (float, optional): The drop rate of MLP
+ gated (bool, optional): Whether to use gated MLP
+ use_kernel (bool, optional): Whether to use kernel optimization
+ """
+
+ def __init__(
+ self,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ expert_parallel: Optional[str] = None,
+ activation: Optional[Callable] = None,
+        drop_rate: float = 0.0,
+        gated: bool = False,
+        use_kernel: bool = False,
+ ):
+ super().__init__()
+ assert expert_parallel in ["EP", "TP", None]
+ self.expert_parallel = expert_parallel
+ self.num_total_experts = num_experts
+ self.gated = gated
+ self.use_kernel = use_kernel
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+
+ # get expert parallel info
+ if expert_parallel is not None:
+ self.num_local_experts, self.moe_info = MOE_MANAGER.get_info(
+                num_experts, use_tp=(expert_parallel == "TP")
+ )
+ # get settings for different parallel
+ self.ep_size = get_ep_size(self)
+ if expert_parallel == "TP":
+ intermediate_size = intermediate_size // self.ep_size
+ num_experts = self.num_total_experts
+ else:
+ num_experts = self.num_local_experts
+ else:
+ self.num_local_experts = self.num_total_experts
+ self.ep_size = 1
+
+ if gated:
+ self.wi_gate = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size * 2))
+ self.wi_up = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+ else:
+ self.wi = nn.Parameter(torch.empty(num_experts, hidden_size, intermediate_size))
+ self.wo = nn.Parameter(torch.empty(num_experts, intermediate_size, hidden_size))
+
+ self.act_name = activation
+ self.act = get_activation(activation)
+ self.drop = nn.Dropout(p=drop_rate)
+
+ if expert_parallel is not None:
+ for param in self.parameters():
+ set_moe_tensor_info(param, self.moe_info)
+
+ # init param
+ self.reset_parameters()
+
+ @torch.no_grad()
+ def reset_parameters(self):
+ # expert param should be different
+ if self.expert_parallel is not None:
+ seed_ctx = Randomizer(get_ep_rank(self)).fork_rng(enable_cpu=True)
+ else:
+ seed_ctx = Randomizer(42).fork_rng(enable_cpu=True)
+ with seed_ctx:
+ if self.gated:
+ torch.nn.init.normal_(self.wi_gate, std=math.sqrt(0.1 / self.hidden_size))
+ torch.nn.init.normal_(self.wi_up, std=math.sqrt(0.1 / self.hidden_size))
+ else:
+ torch.nn.init.normal_(self.wi, std=math.sqrt(0.1 / self.hidden_size))
+ torch.nn.init.normal_(self.wo, std=math.sqrt(0.1 / self.intermediate_size))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ param_slice: Tuple[slice] = (slice(None),),
+ use_sparse: bool = True,
+ ) -> torch.Tensor:
+ """
+ forward: hidden_size --> intermediate_size --> hidden_size
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (num_groups, num_experts, capacity, hidden_size)
+
+ Returns:
+ torch.Tensor: The output tensor of shape (num_groups, num_experts, capacity, hidden_size)
+ """
+ x = MoeInGradScaler.apply(x, self.ep_size)
+
+ e = x.size(1)
+ h = x.size(-1)
+
+ x = x.transpose(0, 1)
+ inshape = x.shape
+ x = x.reshape(e, -1, h)
+
+ if self.use_kernel and use_sparse:
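+            # drop padded all-zero token slots per expert: the first hidden
+            # feature acts as a cheap non-zero indicator, so only the leading
+            # mask[i] real tokens are kept for expert i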
+ seq_len = x.shape[1]
+ with torch.no_grad():
+ mask = x[:, :, 0] != 0.0
+ mask = torch.sum(mask, dim=-1)
+ x_list = []
+ for i in range(e):
+ x_list.append(x[i, : mask[i]])
+ x = x_list
+
+ if self.gated:
+ x_gate = [torch.mm(x[i], self.wi_gate[param_slice][i]) for i in range(e)]
+ x_up = [torch.mm(x[i], self.wi_up[param_slice][i]) for i in range(e)]
+ if self.use_kernel and HAS_TRITON and self.act_name == "swiglu":
+ x = [LlamaActCombine.apply(x_gate[i], x_up[i]) for i in range(e)]
+ else:
+ x = [self.act(x_gate[i]) * x_up[i] for i in range(e)]
+ else:
+ x = [torch.mm(x[i], self.wi[param_slice][i]) for i in range(e)]
+ x = [self.act(x[i]) for i in range(e)]
+ x = [self.drop(x[i]) for i in range(e)]
+ x = [torch.mm(x[i], self.wo[param_slice][i]) for i in range(e)]
+
+ if self.use_kernel and use_sparse:
+ for i in range(e):
+ x[i] = torch.nn.functional.pad(x[i], (0, 0, 0, seq_len - x[i].shape[0]), mode="constant", value=0)
+
+ x = torch.cat([x[i].unsqueeze(0) for i in range(e)], dim=0)
+ x = x.reshape(inshape)
+ x = x.transpose(0, 1).contiguous()
+ x = MoeOutGradScaler.apply(x, self.ep_size)
+ return x
diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py
new file mode 100644
index 000000000000..b768fb94a585
--- /dev/null
+++ b/colossalai/moe/layers.py
@@ -0,0 +1,392 @@
+import dataclasses
+import math
+from typing import Any, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter
+from colossalai.moe.experts import MLPExperts
+from colossalai.moe.load_balance import LoadBalancer
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.routers import MoeRouter, get_router_cls
+from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
+
+
+class SparseMLP(nn.Module):
+ """A class for users to create MoE modules in their models.
+
+ Args:
+        num_experts (int): The number of experts
+        hidden_size (int): The hidden size of the model
+        intermediate_size (int): The intermediate size of the expert MLPs
+        router_top_k (int, optional): The number of experts each token is dispatched to
+        router_capacity_factor_train (float, optional): Capacity factor in routing during training
+        router_capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
+        router_min_capacity (int, optional): The minimum capacity of each expert
+        router_noisy_policy (str, optional): The policy of the noisy function applied to router logits.
+            Now we have 'Jitter' and 'Gaussian'.
+            'Jitter' can be found in the `Switch Transformer paper`_.
+            'Gaussian' can be found in the `ViT-MoE paper`_.
+        router_drop_tks (bool, optional): Whether to drop tokens that exceed expert capacity during evaluation
+        mlp_activation (str, optional): The activation function of the expert MLPs
+        mlp_gated (bool, optional): Whether to use gated expert MLPs
+        enable_load_balance (bool, optional): Whether to balance expert load by swapping experts across devices
+        load_balance_tolerance (float, optional): Convergence tolerance of the load-balance search
+        load_balance_beam_width (int, optional): Beam width of the load-balance beam search
+        load_balance_group_swap_factor (float, optional): Penalty factor that discourages swapping the same group repeatedly
+        enable_kernel (bool, optional): Whether to use kernel optimization
+        enable_comm_overlap (bool, optional): Whether to overlap expert communication with computation
+        enable_hierarchical_comm (bool, optional): Whether to use hierarchical all-to-all communication
+
+    .. _Switch Transformer paper:
+        https://arxiv.org/abs/2101.03961
+    .. _ViT-MoE paper:
+        https://arxiv.org/abs/2106.05974
+ """
+
+ def __init__(
+ self,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size: int,
+ router_top_k: int = 1,
+ router_capacity_factor_train: float = 1.25,
+ router_capacity_factor_eval: float = 2.0,
+ router_min_capacity: int = 4,
+ router_noisy_policy: Optional[str] = None,
+ router_drop_tks: bool = True,
+ mlp_activation: Optional[str] = None,
+ mlp_gated: bool = False,
+ enable_load_balance: bool = False,
+ load_balance_tolerance: float = 0.1,
+ load_balance_beam_width: int = 8,
+ load_balance_group_swap_factor: float = 0.4,
+ enable_kernel: bool = False,
+ enable_comm_overlap: bool = False,
+ enable_hierarchical_comm: bool = False,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_experts = num_experts
+ self.gated = mlp_gated
+ self.enable_kernel = enable_kernel
+ self.enable_comm_overlap = enable_comm_overlap
+ self.expert_parallel = MOE_MANAGER.get_parallel()
+
+ # moe router
+ noisy_func = get_noise_generator(router_noisy_policy, num_experts)
+ router_cls = get_router_cls(router_top_k)
+ self.topk = router_top_k
+ self.router: MoeRouter = router_cls(
+ capacity_factor_train=router_capacity_factor_train,
+ capacity_factor_eval=router_capacity_factor_eval,
+ min_capacity=router_min_capacity,
+ noisy_func=noisy_func,
+ drop_tks=router_drop_tks,
+ )
+
+ # gate
+ self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, self.hidden_size))
+
+ # moe experts
+ self.experts = MLPExperts(
+ num_experts=self.num_experts,
+ expert_parallel=self.expert_parallel,
+ hidden_size=self.hidden_size,
+ intermediate_size=self.intermediate_size,
+ activation=mlp_activation,
+ gated=mlp_gated,
+ use_kernel=self.enable_kernel,
+ )
+
+ # get parallel settings
+ if self.expert_parallel is not None:
+ self.ep_group = get_ep_group(self.experts)
+ self.ep_size = get_ep_size(self.experts)
+ self.ep_hierarchical_group = None
+ if enable_hierarchical_comm:
+ self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+ get_ep_group_ranks(self.experts)
+ )
+ self.dp_group = get_dp_group(self.experts)
+ else:
+ self.ep_group = None
+ self.dp_group = None
+ self.num_local_experts = self.experts.num_local_experts
+
+ # load balance
+ self.enable_load_balance = enable_load_balance
+        if self.enable_load_balance:
+ self.load_balancer = LoadBalancer(
+ experts=self.experts,
+ gate=self.gate_weight,
+ local_expert_num=self.num_local_experts,
+ expert_num=self.num_experts,
+ ep_group=self.ep_group,
+ dp_group=self.dp_group,
+ tolerance=load_balance_tolerance,
+ beam_width=load_balance_beam_width,
+ group_swap_factor=load_balance_group_swap_factor,
+ )
+
+ # init param
+ self.reset_parameters()
+
+ @torch.no_grad()
+ def reset_parameters(self):
+ torch.nn.init.normal_(self.gate_weight, std=math.sqrt(0.1 / self.hidden_size))
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ inputs (torch.Tensor): The input tensor of shape (batch_size, seq_len, hidden_size)
+
+ Returns:
+ torch.Tensor: The output tensor of shape (batch_size, seq_len, hidden_size)
+ """
+ # reshape the input tokens
+ tokens = inputs.reshape(-1, self.hidden_size)
+
+ # the data type of the inputs in the gating should be fp32
+ fp32_input = tokens.to(torch.float)
+ fp32_weight = self.gate_weight.to(torch.float)
+ gate_output = F.linear(fp32_input, fp32_weight)
+
+ # update expert load
+        if self.enable_load_balance:
+ with torch.no_grad():
+ # TODO: optimize computation
+ expert_load = torch.topk(gate_output, k=self.topk, dim=-1)[1]
+ # TODO: bincount introduces synchronize, fix it
+ expert_load = torch.bincount(expert_load.view(-1))
+ self.load_balancer.update_load(expert_load)
+
+ # the result from the router
+ used_capacity, *route_result_list = self.router(
+ inputs=gate_output, use_kernel=self.enable_kernel, ep_group=self.ep_group)
+
+ # dispatch_data: (num_experts, capacity, hidden_size)
+ if self.enable_kernel:
+ dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
+ dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.hidden_size)
+ else:
+ sec_mask_f = route_result_list[1].type_as(inputs)
+ dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
+
+ # expert_output: (num_groups, num_experts, capacity, hidden_size)
+ if self.expert_parallel == "EP":
+ expert_output = self._ep_process(
+ dispatch_data,
+ used_capacity,
+ overlap=self.enable_comm_overlap
+ )
+ elif self.expert_parallel == "TP":
+ expert_output = self._tp_process(
+ dispatch_data,
+ used_capacity,
+ overlap=self.enable_comm_overlap
+ )
+ elif self.expert_parallel is None:
+ expert_output = self._local_process(dispatch_data)
+ else:
+            raise NotImplementedError("This kind of expert parallelism has not been implemented yet.\n"
+                                      "expert_parallel must be 'EP', 'TP' or None.")
+
+ if self.enable_kernel:
+ expert_output = expert_output.reshape(-1, self.hidden_size)
+ ans = MoeCombine.apply(expert_output, *route_result_list)
+ else:
+ combine_weights = route_result_list[0].type_as(inputs)
+ combine_weights = combine_weights.view(combine_weights.shape[0], -1)
+ expert_output = expert_output.view(-1, expert_output.shape[-1])
+ ans = torch.matmul(combine_weights, expert_output)
+
+ ans = ans.reshape(inputs.shape)
+ return ans
+
+ def _local_process(self, expert_in: torch.Tensor) -> torch.Tensor:
+ expert_in = expert_in.unsqueeze(0)
+ expert_out = self.experts(expert_in)
+ return expert_out
+
+ def _ep_process(
+ self,
+ dispatch_data: torch.Tensor,
+ used_capacity: torch.Tensor,
+ overlap: bool = False
+ ) -> torch.Tensor:
+ """
+ Expert Parallel
+
+ Args:
+ dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
+
+ Returns:
+ torch.Tensor: (num_experts, capacity, hidden_size)
+ """
+ if not overlap or dist.get_world_size(self.ep_group) == 1:
+ if self.ep_hierarchical_group is not None:
+ expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
+ expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+ expert_output = self.experts(expert_input)
+ expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
+ return expert_output
+ else:
+ expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
+ expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
+ expert_output = self.experts(expert_input)
+ expert_output = AllToAll.apply(expert_output, self.ep_group, False)[0]
+ return expert_output
+ else:
+
+ @dataclasses.dataclass
+ class Capsule:
+ data: torch.Tensor
+ handle: Any = None
+
+ NUM_CHUNK = 4
+ NUM_STAGES = 4
+
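+            # Software pipeline over NUM_CHUNK chunks, NUM_STAGES deep: each
+            # iteration drains a finished all-to-all output, launches the
+            # all-to-all for the last computed chunk, prefetches the next input
+            # chunk with an async all-to-all, and runs the experts on the chunk
+            # whose input communication has completed.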
+ assert (dispatch_data.shape[1] % NUM_CHUNK == 0), "arbitrary chunk num is not supported yet"
+ chunk_size = dispatch_data.shape[1] // NUM_CHUNK
+ input_shape = (self.ep_size, self.num_local_experts, -1, self.hidden_size)
+ dispatch_data = dispatch_data.reshape(*input_shape)
+ chunk_data = torch.split(dispatch_data, chunk_size, dim=2)
+ output = torch.empty_like(dispatch_data)
+
+ offset = 0
+ _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
+
+ for i in range(NUM_CHUNK + NUM_STAGES - 1):
+ if expert_out is not None:
+ expert_out.handle.wait()
+ output[:, :, offset:offset + chunk_size, :] = expert_out.data
+ offset += chunk_size
+ expert_out = None
+
+ # all2all last output
+ if _expert_out is not None:
+ expert_out = Capsule(*AllToAll.apply(_expert_out.data, self.ep_group, True),)
+ _expert_out = None
+
+ # all2all next input
+ if 0 <= i < NUM_CHUNK:
+ _expert_in = Capsule(*AllToAll.apply(chunk_data[i].contiguous(), self.ep_group, True))
+
+ # compute
+ if expert_in is not None:
+ expert_in.handle.wait()
+ _expert_out = Capsule(data=self.experts(expert_in.data), handle=None)
+ expert_in = None
+
+ if _expert_in is not None:
+ expert_in = _expert_in
+ _expert_in = None
+
+ return output
+
+ def _tp_process(
+ self,
+ dispatch_data: torch.Tensor,
+ used_capacity: torch.Tensor,
+ overlap: bool = False
+ ) -> torch.Tensor:
+ """
+ without overlap:
+ | C |
+ | A | | R |
+
+ with overlap:
+ | C1 || C2 || C3 || C4 |
+ | A1 || A2 | | R1 | A3 || R2 | A4 || R3 | | R4 |
+
+ where C is computation, A is all gather, R is reduce scatter.
+
+ Args:
+ dispatch_data (torch.Tensor): (num_experts, capacity, hidden_size)
+
+ Returns:
+ torch.Tensor: (num_experts, capacity, hidden_size)
+ """
+ if not overlap or dist.get_world_size(self.ep_group) == 1:
+ expert_in = AllGather.apply(dispatch_data, self.ep_group, False)[0]
+ expert_out = self.experts(expert_in)
+ expert_out = ReduceScatter.apply(expert_out, self.ep_group, False)[0]
+ return expert_out
+ else:
+
+ @dataclasses.dataclass
+ class Capsule:
+ data: torch.Tensor
+ handle: Any
+ indices: Tuple
+
+ NUM_CHUNK = 4
+ NUM_STAGES = 4
+
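+            # Same pipelining idea as in _ep_process, matching the timeline in
+            # the docstring: all-gather the next chunk (A), compute on the
+            # gathered chunk (C), and reduce-scatter the finished chunk (R),
+            # overlapping A and R with C.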
+ assert dispatch_data.shape[0] % NUM_CHUNK == 0, \
+ "arbitrary chunk num is not supported yet, please use chunk num that can divide num_experts"
+ chunk_size = dispatch_data.shape[0] // NUM_CHUNK
+ chunk_data = torch.split(dispatch_data, chunk_size, dim=0)
+ output = torch.empty_like(dispatch_data)
+
+ def get_chunk_slice(idx: int, chunk_size: int) -> Tuple[slice]:
+ return (slice(idx * chunk_size, (idx + 1) * chunk_size),)
+
+ _expert_in, expert_in, _expert_out, expert_out = None, None, None, None
+
+ for i in range(NUM_CHUNK + NUM_STAGES - 1):
+ if expert_out is not None:
+ expert_out.handle.wait()
+ output[expert_out.indices] = expert_out.data
+ expert_out = None
+
+ # reduce scatter last output
+ if _expert_out is not None:
+ expert_out = Capsule(
+ *ReduceScatter.apply(_expert_out.data, self.ep_group, True),
+ indices=_expert_out.indices,
+ )
+ _expert_out = None
+
+ # all gather next input
+ if 0 <= i < NUM_CHUNK:
+ _expert_in = Capsule(
+ *AllGather.apply(chunk_data[i].contiguous(), self.ep_group, True),
+ indices=get_chunk_slice(i, chunk_size),
+ )
+
+ # compute
+ if expert_in is not None:
+ expert_in.handle.wait()
+ _expert_out = Capsule(
+ self.experts(expert_in.data, expert_in.indices),
+ handle=None,
+ indices=expert_in.indices,
+ )
+ expert_in = None
+
+ if _expert_in is not None:
+ expert_in = _expert_in
+ _expert_in = None
+
+ return output
+
+
+def apply_load_balance(model: nn.Module, optim: Any) -> None:
+ """
+    Apply load balancing to every SparseMLP module in the model.
+ """
+
+ def _apply_recursive(module: nn.Module):
+ for _, sub_module in module.named_children():
+ if isinstance(sub_module, SparseMLP):
+            if sub_module.enable_load_balance:
+ sub_module.load_balancer.balance_load(optim)
+ _apply_recursive(sub_module)
+
+ torch.cuda.empty_cache()
+ _apply_recursive(model)
+ torch.cuda.empty_cache()
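+
+
+# Illustrative usage sketch (assumes MOE_MANAGER.setup(...) has been called;
+# the sizes below are examples only):
+#   moe_layer = SparseMLP(num_experts=8, hidden_size=768, intermediate_size=3072)
+#   out = moe_layer(torch.randn(2, 128, 768))  # (batch_size, seq_len, hidden_size)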
diff --git a/colossalai/moe/load_balance.py b/colossalai/moe/load_balance.py
new file mode 100644
index 000000000000..85c12d73fa52
--- /dev/null
+++ b/colossalai/moe/load_balance.py
@@ -0,0 +1,442 @@
+from copy import deepcopy
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor, nn
+from torch.distributed import ProcessGroup
+
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.moe.experts import MLPExperts
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.zero.low_level import LowLevelZeroOptimizer
+
+
+class LoadBalancer:
+ def __init__(
+ self,
+ experts: MLPExperts,
+ gate: nn.Parameter,
+ local_expert_num: int,
+ expert_num: int,
+ ep_group: ProcessGroup,
+ dp_group: ProcessGroup,
+ tolerance: Optional[float] = 0.1,
+ beam_width: Optional[int] = 8,
+ group_swap_factor: Optional[float] = 0.4,
+ ) -> None:
+ self.experts: MLPExperts = experts
+ self.gate: nn.Parameter = gate
+ self.moe_ep_group: ProcessGroup = ep_group
+ self.moe_ep_ranks = MOE_MANAGER.parallel_info_dict[dist.get_world_size(self.moe_ep_group)].ep_group_ranks
+ self.moe_dp_group: ProcessGroup = dp_group
+ self.tolerance = tolerance
+ self.beam_width = beam_width
+ self.group_swap_factor = group_swap_factor
+ self.local_expert_num = local_expert_num
+ self.expert_num = expert_num
+ self.local_load = None
+ # TODO: use a global process group mesh
+ pp_size = 1 if MOE_MANAGER.pp_size is None else MOE_MANAGER.pp_size
+ global_dp_group = ProcessGroupMesh(pp_size, dist.get_world_size() // pp_size)
+ self.global_dp_group = global_dp_group.get_group_along_axis(1)
+ self.global_dp_rank = dist.get_rank(self.global_dp_group)
+ self.global_dp_size = dist.get_world_size(self.global_dp_group)
+
+ def _clear_load(self) -> None:
+ self.local_load = None
+
+ def _sync_load(self) -> Tensor:
+ new_load = self.local_load.clone().detach()
+ # all reduce load between ep group
+ dist.all_reduce(new_load, group=self.moe_ep_group)
+ # all reduce load between dp group
+ dist.all_reduce(new_load, group=self.moe_dp_group)
+ return new_load
+
+ @staticmethod
+ def _get_diff_from_avg(data: List, group: int, avg: float) -> float:
+ return abs(sum(data[group]) / len(data[group]) - avg)
+
+ @staticmethod
+ def _swap_data(data: List, group_i: int, index_i: int, group_j: int, index_j: int) -> None:
+ data[group_i][index_i], data[group_j][index_j] = (
+ data[group_j][index_j],
+ data[group_i][index_i],
+ )
+
+ @staticmethod
+ def _normalize_data(data: List) -> List:
+ max_value = max(max(sublist) for sublist in data)
+ data = [[i / max_value for i in sublist] for sublist in data]
+ return data
+
+ @staticmethod
+ def _get_swap_loss(
+ group_swap_factor: float,
+ swap_list: List,
+ group_i: int,
+ index_i: int,
+ group_j: int,
+ index_j: int,
+ ) -> float:
+ """
+        Get swap loss. The swap loss is used to avoid the situation where
+        the same index is swapped twice or the same group is swapped multiple times.
+ """
+ swap_loss = 0
+ for swap in swap_list:
+ for group_id, index_id in zip([group_i, group_j], [index_i, index_j]):
+ # the group has been swapped
+ if group_id in [swap[0], swap[2]]:
+ # the index has been swapped
+ # we want to avoid the situation that the same index is swapped twice
+ if index_id in [swap[1], swap[3]]:
+ swap_loss += 1e5
+ # the index has not been swapped
+                    # this is acceptable but should happen as rarely as possible
+ else:
+ swap_loss += group_swap_factor
+ return swap_loss
+
+ @staticmethod
+ def _check_convergence(data: List, avg: float, tolerance: float):
+ """
+ Check whether the data is converged after swap.
+ """
+ for sublist in data:
+ if abs(sum(sublist) / len(sublist) - avg) > tolerance * avg:
+ return False
+ return True
+
+ def _beam_search(
+ self,
+ inputs: Tuple[List, float, List],
+ beam_width: int,
+ avg: float,
+ group_swap_factor: float,
+ ) -> List:
+ """
+ Beam search for the best swap combination.
+ Specifically, we swap two elements from two groups and calculate the score.
+ The score is the difference between the origin group sum and the new group sum.
+ The larger the score, the better the swap combination.
+
+ Args:
+ inputs (Tuple): (data, origin_score, swap_list)
+ beam_width (int): beam width for beam search
+ avg (float): average value of the data
+            group_swap_factor (float): penalty factor used in the group swap loss
+
+ Returns:
+ List: results list
+ """
+ data, origin_score, swap_list = inputs
+ results = []
+ group_num = len(data)
+ group_size = len(data[0])
+ origin_diff_list = [self._get_diff_from_avg(data, i, avg) for i in range(group_num)]
+
+ for group_num_i in range(group_num):
+ for group_size_i in range(group_size):
+ for group_num_j in range(group_num_i + 1, group_num):
+ for group_size_j in range(group_size):
+ new_data = deepcopy(data)
+ # calculate origin group sum
+ origin_diff = origin_diff_list[group_num_i] + origin_diff_list[group_num_j]
+ # swap data
+ self._swap_data(
+ new_data,
+ group_num_i,
+ group_size_i,
+ group_num_j,
+ group_size_j,
+ )
+ # calculate new group sum
+ new_diff = self._get_diff_from_avg(new_data, group_num_i, avg) + self._get_diff_from_avg(
+ new_data, group_num_j, avg
+ )
+                        # calculate score
+ new_score = origin_diff - new_diff
+ if new_score > 0:
+ new_score = origin_score + new_score
+ # get swap loss
+ swap_loss = self._get_swap_loss(
+ group_swap_factor,
+ swap_list,
+ group_num_i,
+ group_size_i,
+ group_num_j,
+ group_size_j,
+ )
+ new_score = new_score - swap_loss
+ # update swap list
+ new_swap_list = swap_list + [(group_num_i, group_size_i, group_num_j, group_size_j)]
+ results.append((new_data, new_score, new_swap_list))
+ # sort results
+ results.sort(key=lambda x: x[1], reverse=True)
+ # select top k results
+ results = results[:beam_width]
+ return results
+
+ def _load_to_list(self, load: Tensor) -> List:
+ load_len = len(load)
+ assert load_len % self.local_expert_num == 0
+ load_list = []
+ tmp_list = []
+ for i in range(len(load)):
+ tmp_list.append(float(load[i]))
+ if (i + 1) % self.local_expert_num == 0:
+ load_list.append(tmp_list)
+ tmp_list = []
+ return load_list
+
+ def _search_balance(
+ self,
+ data: List,
+ tolerance: Optional[float] = 0.1,
+ beam_width: Optional[int] = 8,
+ group_swap_factor: Optional[float] = 0.4,
+ return_swapped_data: Optional[bool] = False,
+ ) -> Tuple[List, List]:
+ """
+        Search for the best swap combinations to balance the data within the
+        specified tolerance, recording them in a swap list. Each entry of the
+        swap list is a tuple describing one swap operation.
+
+ Args:
+ data (List): expert load list.
+ E.g. [[9.2, 8.3], [2.3, 10.0], [6.1, 7.2], [5.3, 3.2]]
+                This means there are 4 devices and each device has 2 experts.
+ The value is the load of the expert.
+ tolerance (float): tolerance for balance.
+ beam_width (int): beam width for beam search.
+ group_swap_factor (float): group swap factor for group swap loss.
+                The bigger it is, the fewer times a group will be swapped.
+ return_swapped_data (bool): whether to return the swapped data.
+
+        Returns:
+            The swap list, or a tuple of (balanced data, swap list) when
+                return_swapped_data is True. The swap list is a list of tuples;
+                each tuple is one swap operation. E.g. [(0, 0, 1, 0), (...), (...)].
+                The first tuple means the first expert of the first device is
+                swapped with the first expert of the second device.
+ """
+ norm_data = self._normalize_data(data)
+ avg = sum(sum(sublist) / len(sublist) for sublist in norm_data) / len(norm_data)
+ results = [(norm_data, 0, [])]
+ stop_flag = False
+
+        while not stop_flag:
+ new_results = []
+ best_score = results[0][1]
+ for i in range(len(results)):
+ new_results.extend(self._beam_search(results[i], beam_width, avg, group_swap_factor))
+ if len(new_results) == 0:
+ stop_flag = True
+ break
+ new_results.sort(key=lambda x: x[1], reverse=True)
+ new_best_score = new_results[0][1]
+ if new_best_score == best_score:
+ stop_flag = True
+ break
+ new_results = new_results[:beam_width]
+ results = new_results
+            if self._check_convergence(results[0][0], avg, tolerance):
+                stop_flag = True
+
+ swap_list = results[0][2]
+ if return_swapped_data:
+ out = deepcopy(data)
+ for swap in swap_list:
+ self._swap_data(out, *swap)
+ return out, swap_list
+ else:
+ return swap_list
+
+ @staticmethod
+ def _swap_expert_single_tensor(
+ weight: nn.Parameter,
+ expert_idx: int,
+ comm_group: ProcessGroup,
+ send_first: bool,
+ comm_rank: int,
+ ):
+ # exchange weight
+ local_weight = weight.data[expert_idx]
+ new_weight = torch.empty_like(local_weight)
+ if send_first:
+ dist.send(local_weight, dst=comm_rank, group=comm_group)
+ dist.recv(new_weight, src=comm_rank, group=comm_group)
+ else:
+ dist.recv(new_weight, src=comm_rank, group=comm_group)
+ dist.send(local_weight, dst=comm_rank, group=comm_group)
+ weight.data[expert_idx] = new_weight
+
+ def _swap_expert_param_and_optim(
+ self,
+ weight: nn.Parameter,
+ expert_idx: int,
+ comm_group: ProcessGroup,
+ send_first: bool,
+ comm_rank: int,
+ optim: LowLevelZeroOptimizer,
+ ):
+        # if the working param itself is tracked in the optimizer state, there is
+        # no separate master param; otherwise both master and working params are updated
+ if weight in optim.optim.state:
+ master_weight_ptr = None
+ working_weight_ptr = weight
+ exp_avg_ptr = optim.optim.state[working_weight_ptr]["exp_avg"]
+ exp_avg_sq_ptr = optim.optim.state[working_weight_ptr]["exp_avg_sq"]
+ else:
+ master_weight_ptr = optim._param_store.working_to_master_param[id(weight)]
+ working_weight_ptr = weight
+ exp_avg_ptr = optim.optim.state[master_weight_ptr]["exp_avg"]
+ exp_avg_sq_ptr = optim.optim.state[master_weight_ptr]["exp_avg_sq"]
+
+ # exchange weight
+ self._swap_expert_single_tensor(
+ working_weight_ptr,
+ expert_idx,
+ comm_group,
+ send_first,
+ comm_rank,
+ )
+ if master_weight_ptr is not None:
+ # TODO: exchange master weight, skip for now
+ # master weight is shared by dp group
+ tmp = working_weight_ptr.view(-1).split(
+ working_weight_ptr.numel() // dist.get_world_size(self.moe_dp_group)
+ )[dist.get_rank(self.moe_dp_group)]
+ master_weight_ptr.data.copy_(tmp.clone().detach().to(master_weight_ptr.device).to(master_weight_ptr.dtype))
+ # exchange optim
+ self._swap_expert_single_tensor(exp_avg_ptr, expert_idx, comm_group, send_first, comm_rank)
+ self._swap_expert_single_tensor(exp_avg_sq_ptr, expert_idx, comm_group, send_first, comm_rank)
+
+ def _gather_global_dp_group(self, data: Tensor) -> Tensor:
+ data_list = [torch.zeros_like(data) for _ in range(self.global_dp_size)]
+ dist.all_gather(data_list, data, group=self.global_dp_group)
+ data_list = torch.cat(data_list, dim=0)
+ return data_list
+
+ def _swap_moe_param(self, swap_list: List, optim: LowLevelZeroOptimizer) -> None:
+ """
+ Swap moe param and optim.
+ We use different strategies to swap expert and gate.
+ For expert, we exchange the param and optim of the expert by p2p.
+        For the gate, we all-gather the full gate, swap its rows, and take back the local shard.
+
+ Args:
+ swap_list (List)
+ optim (LowLevelZeroOptimizer)
+ """
+ # get all experts weights
+ local_rank = dist.get_rank(self.moe_ep_group)
+ if self.experts.gated:
+ weight_list = [self.experts.wi_up, self.experts.wi_gate]
+ else:
+ weight_list = [self.experts.wi]
+ weight_list.append(self.experts.wo)
+
+ # gate optim should be obtained first
+ gate_shape = self.gate.shape
+ # get master weight and optim
+ master_gate_weight = optim._param_store.working_to_master_param[id(self.gate)]
+ gate_exp_avg = optim.optim.state[master_gate_weight]["exp_avg"]
+ gate_exp_avg_sq = optim.optim.state[master_gate_weight]["exp_avg_sq"]
+ # gather
+ global_master_gate_weight = self._gather_global_dp_group(master_gate_weight).view(gate_shape)
+ global_gate_exp_avg = self._gather_global_dp_group(gate_exp_avg).view(gate_shape)
+ global_gate_exp_avg_sq = self._gather_global_dp_group(gate_exp_avg_sq).view(gate_shape)
+ assert (
+ self.gate.shape
+ == global_master_gate_weight.shape
+ == global_gate_exp_avg.shape
+ == global_gate_exp_avg_sq.shape
+ )
+
+ for swap in swap_list:
+ source_group, source_idx, target_group, target_idx = swap
+ source_rank = self.moe_ep_ranks[source_group]
+ target_rank = self.moe_ep_ranks[target_group]
+ # exchange expert
+ if local_rank in [source_group, target_group]:
+ for weight in weight_list:
+ if local_rank == source_group:
+ self._swap_expert_param_and_optim(
+ weight,
+ source_idx,
+ self.moe_ep_group,
+ True,
+ target_rank,
+ optim,
+ )
+ elif local_rank == target_group:
+ self._swap_expert_param_and_optim(
+ weight,
+ target_idx,
+ self.moe_ep_group,
+ False,
+ source_rank,
+ optim,
+ )
+ # exchange gate
+ source_expert_pos = source_group * self.local_expert_num + source_idx
+ target_expert_pos = target_group * self.local_expert_num + target_idx
+ for gate in [
+ self.gate,
+ global_master_gate_weight,
+ global_gate_exp_avg,
+ global_gate_exp_avg_sq,
+ ]:
+ origin_source = gate.data[source_expert_pos].clone().detach()
+ origin_target = gate.data[target_expert_pos].clone().detach()
+ gate.data[source_expert_pos], gate.data[target_expert_pos] = (
+ origin_target,
+ origin_source,
+ )
+
+ # update gate
+ global_master_gate_weight = global_master_gate_weight.view(-1).split(
+ global_master_gate_weight.numel() // self.global_dp_size
+ )[self.global_dp_rank]
+ master_gate_weight.data.copy_(global_master_gate_weight)
+ global_gate_exp_avg = global_gate_exp_avg.view(-1).split(global_gate_exp_avg.numel() // self.global_dp_size)[
+ self.global_dp_rank
+ ]
+ gate_exp_avg.data.copy_(global_gate_exp_avg)
+ global_gate_exp_avg_sq = global_gate_exp_avg_sq.view(-1).split(
+ global_gate_exp_avg_sq.numel() // self.global_dp_size
+ )[self.global_dp_rank]
+ gate_exp_avg_sq.data.copy_(global_gate_exp_avg_sq)
+
+ @torch.no_grad()
+ def update_load(self, load: Tensor) -> None:
+ if len(load) != self.expert_num:
+ padding_size = self.expert_num - len(load)
+ padding = torch.zeros(padding_size, dtype=load.dtype, device=load.device)
+ load = torch.cat((load, padding), dim=0)
+ if self.local_load is None:
+ self.local_load = load
+ else:
+ self.local_load += load
+
+ @torch.no_grad()
+ def balance_load(self, optim: LowLevelZeroOptimizer) -> None:
+ # prepare load
+ load = self._sync_load()
+ load = self._load_to_list(load)
+ # search balance
+ swap_list = self._search_balance(load)
+ if dist.get_rank() == 0:
+            if len(swap_list) > 0:
+                print("[Load Balance] Applying expert swap...")
+            else:
+                print("[Load Balance] No valid swap found, skipping...")
+ # swap expert and gate
+ self._swap_moe_param(swap_list, optim)
+ # clear load
+ self._clear_load()
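+
+
+# Illustrative flow (normally driven by SparseMLP and apply_load_balance in
+# layers.py; `balancer` and `per_expert_token_counts` are hypothetical names):
+#   balancer.update_load(per_expert_token_counts)  # accumulate on every forward
+#   balancer.balance_load(optim)                   # run periodically on all ranks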
diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/moe/loss.py
similarity index 92%
rename from colossalai/nn/loss/loss_moe.py
rename to colossalai/moe/loss.py
index 40cea788c3c3..75624510b452 100644
--- a/colossalai/nn/loss/loss_moe.py
+++ b/colossalai/moe/loss.py
@@ -1,11 +1,9 @@
import torch.nn as nn
from torch.nn.modules.loss import _Loss
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.registry import LOSSES
+from colossalai.moe.manager import MOE_MANAGER
-@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
r"""torch.nn.CrossEntropyLoss added with auxiliary loss.
@@ -45,11 +43,10 @@ def forward(self, *args):
`Cross_entropy `_.
"""
main_loss = self.loss(*args)
- aux_loss = MOE_CONTEXT.get_loss()
+ aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
-@LOSSES.register_module
class MoeLoss(_Loss):
"""A wrapper class for any loss module to add with auxiliary loss.
@@ -77,5 +74,5 @@ def forward(self, *args, **kwargs):
The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
"""
main_loss = self.loss_fn(*args, **kwargs)
- aux_loss = MOE_CONTEXT.get_loss()
+ aux_loss = MOE_MANAGER.get_loss()
return main_loss + self.aux_weight * aux_loss
diff --git a/colossalai/moe/manager.py b/colossalai/moe/manager.py
new file mode 100644
index 000000000000..3e64d796cce7
--- /dev/null
+++ b/colossalai/moe/manager.py
@@ -0,0 +1,163 @@
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.moe_tensor.api import get_moe_info
+from colossalai.tensor.moe_tensor.moe_info import MoeParallelInfo
+
+
+class MoEManager(metaclass=SingletonMeta):
+ """MoE manager. This class manages different
+ parallel groups in MoE context and MoE loss in training.
+ """
+
+ def __init__(self):
+ self.parallel = None
+ self.mode = None
+ self.use_ep_inside = None
+ self.world_size = None
+ self._parallel_info_dict = dict()
+
+ # router
+ self.router_aux_loss = []
+ self.router_z_loss = []
+
+ # fixed mode
+ self.pp_size = None
+ self.dp_size = None
+ self.ep_size = None
+
+ # dynamic mode
+ # Users may want to set maximum expert parallel size smaller than the world size
+ # since very low bandwidth across nodes may constrain the performance of MoE
+ # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
+ self.max_ep_size = None
+
+ self.has_setup = False
+
+ @property
+ def parallel_info_dict(self):
+ return self._parallel_info_dict
+
+ @property
+ def is_initialized(self):
+ return self.has_setup
+
+ def setup(
+ self,
+ parallel: str = None,
+ mode: str = "dynamic",
+ max_ep_size: int = 8,
+ fixed_dp_size: int = 0,
+ fixed_ep_size: int = 0,
+ fixed_pp_size: int = 0,
+ use_ep_inside: bool = True,
+ ) -> None:
+ """
+ Setup MoE distributed context.
+
+ Args:
+            parallel (str, optional): Parallel mode, should be "EP", "TP" or None. Defaults to None.
+            mode (str, optional): Should be "fixed" or "dynamic". Defaults to "dynamic".
+                In fixed mode, the ep size and dp size are fixed.
+                In dynamic mode, the ep size and dp size are derived from the number of experts.
+ max_ep_size (int, optional): Max ep size in dynamic mode. Defaults to 8.
+ fixed_dp_size (int, optional): Fixed dp size in fixed mode. Defaults to 0.
+ fixed_ep_size (int, optional): Fixed ep size in fixed mode. Defaults to 0.
+ fixed_pp_size (int, optional): Fixed pp size in fixed mode. Defaults to 0.
+            use_ep_inside (bool, optional): Use ep inside dp if True, dp inside ep if False. Defaults to True.
+ """
+ assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+ assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
+
+ self.parallel = parallel
+ self.use_ep_inside = use_ep_inside
+ self.world_size = dist.get_world_size()
+
+ # init by mode
+ self.mode = mode
+ assert self.mode in ["fixed", "dynamic"], "mode should be fixed or dynamic"
+ if self.mode == "dynamic":
+ self.max_ep_size = min(max_ep_size, self.world_size)
+ else:
+ assert (
+ fixed_dp_size > 0 and fixed_ep_size > 0 and fixed_pp_size > 0
+ ), "dp_size, ep_size and pp_size should be greater than 0"
+ assert (
+ isinstance(fixed_dp_size, int) and isinstance(fixed_ep_size, int) and isinstance(fixed_pp_size, int)
+ ), "dp_size, ep_size and pp_size should be int"
+ self.ep_size = fixed_ep_size
+ self.dp_size = fixed_dp_size
+ self.pp_size = fixed_pp_size
+
+ self.has_setup = True
+
+ def get_info(self, num_experts: int, use_tp: bool = False) -> Tuple[int, MoeParallelInfo]:
+ """Calculate the Data Parallel Group and Expert Parallel Group.
+
+        Args:
+            num_experts (int): The number of experts
+
+        Returns:
+            Tuple[int, MoeParallelInfo]: the number of local experts and the
+                MoeParallelInfo of the current ep_size
+ """
+
+ if self.mode == "dynamic":
+            gt_flag = num_experts % self.max_ep_size == 0  # num_experts is a multiple of max_ep_size
+            lt_flag = self.max_ep_size % num_experts == 0  # max_ep_size is a multiple of num_experts
+            assert gt_flag or lt_flag, (
+                "Automatic expert placement requires the number of experts to be"
+                " a multiple of the max ep size, or vice versa."
+            )
+ dp_size = 1 if gt_flag else self.world_size // num_experts
+ ep_size = min(self.world_size // dp_size, self.max_ep_size)
+ dp_size = self.world_size // ep_size
+ pp_size = 1
+ else:
+ dp_size = self.dp_size
+ ep_size = self.ep_size
+ pp_size = self.pp_size
+
+ # Calculate the number of experts for each GPU
+ if use_tp:
+ num_local_experts = num_experts
+ else:
+ if self.mode == "dynamic":
+ num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+ else:
+ num_local_experts = num_experts // ep_size
+
+        if ep_size not in self.parallel_info_dict:
+ self.parallel_info_dict[ep_size] = get_moe_info(ep_size, dp_size, pp_size, ep_inside=self.use_ep_inside)
+ if dist.get_rank() == 0:
+ if self.use_ep_inside:
+ print(f"MoE Parallel: pp {pp_size}, dp {dp_size}, ep {ep_size}")
+ else:
+ print(f"MoE Parallel: pp {pp_size}, ep {ep_size}, dp {dp_size}")
+
+ return num_local_experts, self.parallel_info_dict[ep_size]
+
+ def reset_loss(self):
+ self.router_aux_loss, self.router_z_loss = [], []
+
+ def add_loss(self, aux_loss: float = 0.0, z_loss: float = 0.0):
+ self.router_aux_loss.append(aux_loss)
+ self.router_z_loss.append(z_loss)
+
+ def get_loss(self):
+ cur_loss = self.router_aux_loss, self.router_z_loss
+ return cur_loss
+
+ def get_parallel(self):
+ return self.parallel
+
+
+MOE_MANAGER = MoEManager()
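+
+# Illustrative setup sketch (values are examples, not defaults):
+#   MOE_MANAGER.setup(parallel="EP", mode="dynamic", max_ep_size=8)
+#   num_local_experts, moe_info = MOE_MANAGER.get_info(num_experts=32)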
diff --git a/colossalai/moe/routers.py b/colossalai/moe/routers.py
new file mode 100644
index 000000000000..c5bb508621b2
--- /dev/null
+++ b/colossalai/moe/routers.py
@@ -0,0 +1,427 @@
+import math
+from abc import ABC
+from typing import Callable, Optional, Tuple, Type
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed import ProcessGroup
+
+from colossalai.moe._operation import moe_cumsum
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.utils import get_current_device
+
+
+class MoeRouter(nn.Module, ABC):
+ """Base class for all MoE routers.
+ Args:
+ k_value (int): The value of top_k.
+ capacity_factor_train (float): Capacity factor in routing of training.
+ capacity_factor_eval (float): Capacity factor in routing of evaluation.
+ min_capacity (int): The minimum number of the capacity of each expert.
+ noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation
+ """
+
+ def __init__(self,
+ k_value: int,
+ capacity_factor_train: float,
+ capacity_factor_eval: float,
+ min_capacity: int,
+ noisy_func: Optional[Callable] = None,
+ drop_tks: bool = True,
+ use_kernel: bool = False):
+ super().__init__()
+ self.k_value = k_value
+ self.capacity_factor_train = capacity_factor_train
+ self.capacity_factor_eval = capacity_factor_eval
+ self.min_capacity = min_capacity
+ self.noisy_func = noisy_func
+ self.drop_tks = drop_tks
+ self._aux_loss = None
+ self._z_loss = None
+ self.use_kernel = use_kernel
+
+ def get_capacity(self, logits_shape):
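+        # capacity = floor(k * capacity_factor * num_tokens / num_experts),
+        # rounded up to an even number and clamped below by min_capacity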
+ capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
+ capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
+ capacity += capacity % 2
+ capacity = max(capacity, self.min_capacity)
+ assert capacity > 0
+ return int(capacity)
+
+ def set_aux_loss(self, router_probs: torch.Tensor, expert_indices: torch.Tensor, num_experts: int) -> None:
+ """Computes auxiliary load balancing loss as in Switch Transformer.
+
+ See Switch Transformer (https://arxiv.org/abs/2101.03961). This function
+ implements the loss function presented in equations (4) - (6). It aims to
+ penalize those cases where the routing between experts is unbalanced.
+
+ Args:
+ router_probs: Probability assigned to each expert per token. Shape:
+ [num_groups, tokens_per_group, num_experts].
+ expert_indices: [num_groups, tokens_per_group, num_selected_experts]
+ indices identifying the top num_selected_experts for a given token.
+ """
+ assert self._aux_loss is None
+ if router_probs.dim() == expert_indices.dim() == 2:
+ router_probs = router_probs.unsqueeze(0)
+ expert_indices = expert_indices.unsqueeze(0)
+ assert router_probs.dim() == expert_indices.dim() == 3, \
+ "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+
+ # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
+ expert_mask = F.one_hot(expert_indices, num_experts)
+ # For a given token, determine if it was routed to a given expert.
+ # Shape: [num_groups, tokens_per_group, num_experts]
+ expert_mask = expert_mask.max(dim=-2)[0]
+
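+        # aux_loss = E * sum_e(f_e * P_e), where f_e is the fraction of tokens
+        # routed to expert e and P_e is its mean router probability; note that
+        # E^2 * mean over experts equals E * sum (eqs. (4)-(6) of Switch Transformer)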
+ tokens_per_group_and_expert = torch.mean(expert_mask.float(), dim=-2)
+ router_prob_per_group_and_expert = torch.mean(router_probs.float(), dim=-2)
+ aux_loss = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
+ self._aux_loss = aux_loss
+
+ def set_z_loss(self, router_logits: torch.Tensor):
+ """Compute router z-loss.
+
+ The router z-loss was introduced in Designing Effective Sparse Expert Models
+ (https://arxiv.org/abs/2202.08906). It encourages router logits to remain
+ small in an effort to improve stability.
+
+ Args:
+ router_logits: [num_groups, tokens_per_group, num_experts] router logits.
+ """
+ assert self._z_loss is None
+ if router_logits.dim() == 2:
+ router_logits = router_logits.unsqueeze(0)
+ assert router_logits.dim() == 3, "router_logits must be 3D tensor"
+ num_groups, tokens_per_group, _ = router_logits.shape
+ log_z = torch.logsumexp(router_logits, dim=-1)
+ z_loss = torch.sum(log_z**2, dtype=torch.float32) / (num_groups * tokens_per_group)
+ self._z_loss = z_loss
+
+ def pop_router_loss(self) -> torch.Tensor:
+ assert self._aux_loss is not None
+ MOE_MANAGER.add_loss(self._aux_loss, self._z_loss)
+ self._aux_loss = None
+ self._z_loss = None
+
+
+class Top1Router(MoeRouter):
+ """Top1 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+ and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
+ function can be found in the paper about Switch Transformer of Google.
+
+ Args:
+ capacity_factor_train (float, optional): Capacity factor in routing of training.
+ capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+ min_capacity (int, optional): The minimum number of the capacity of each expert.
+ select_policy (str, optional): The policy about tokens selection.
+ noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation
+ """
+
+ def __init__(self,
+ capacity_factor_train: float = 1.25,
+ capacity_factor_eval: float = 2.0,
+ min_capacity: int = 4,
+ select_policy: str = "first",
+ noisy_func: Optional[Callable] = None,
+ drop_tks: bool = True):
+ super().__init__(k_value=1,
+ capacity_factor_train=capacity_factor_train,
+ capacity_factor_eval=capacity_factor_eval,
+ min_capacity=min_capacity,
+ noisy_func=noisy_func,
+ drop_tks=drop_tks)
+ self.select_policy = select_policy
+ assert select_policy in {"first", "random"}
+ if select_policy == "random":
+ self.uniform = torch.distributions.uniform.Uniform(
+ low=torch.tensor(0.0, device=get_current_device()),
+ high=torch.tensor(1.0, device=get_current_device())
+ ).rsample
+
+ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+ """
+ Args:
+ inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
+
+ Returns:
+ 1. use_kernel is False:
+ The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
+ The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
+ 2. use_kernel is True:
+ ...
+ """
+ if self.noisy_func is not None and self.training:
+ inputs = self.noisy_func(inputs)
+
+ assert inputs.dtype == torch.float
+ probs = F.softmax(inputs, dim=-1)
+ num_experts = probs.size(-1)
+ capacity = self.get_capacity(inputs.shape)
+
+ top1_idx = torch.argmax(inputs, dim=-1)
+ mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+
+ # calculate router loss
+ self.set_aux_loss(probs, top1_idx.unsqueeze(-1), num_experts)
+ self.set_z_loss(inputs)
+ self.pop_router_loss()
+
+ if not self.training and not self.drop_tks and ep_group is not None:
+ max_num = torch.max(torch.sum(mask, dim=0))
+ dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+ capacity = max_num.item()
+
+ if self.select_policy == "random":
+ rand_mask = mask * self.uniform(mask.shape)
+ _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
+ mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
+ ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
+ elif self.select_policy == "first":
+ ranks = moe_cumsum(mask, use_kernel=self.use_kernel)
+ mask = mask * torch.lt(ranks, capacity)
+ else:
+ raise NotImplementedError("Not support such select policy yet.")
+
+ ranks = torch.sum(mask * ranks, dim=-1)
+ used_capacity = mask.sum(dim=0)
+
+ if use_kernel:
+ mask = torch.sum(mask, dim=-1)
+ mask = torch.stack([mask], dim=0).to(torch.int32)
+ dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
+ return used_capacity, probs, mask, dest_idx, num_experts * capacity
+ else:
+ ranks = F.one_hot(ranks, num_classes=capacity)
+ weight = mask * probs.type_as(inputs)
+ combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
+ sec_mask = combine_weights.bool()
+ return used_capacity, combine_weights, sec_mask
+
+
+class Top2Router(MoeRouter):
+ """Top2 router that returns the dispatch mask (batch_size * seq_len, num_experts, capacity)
+ and combine weight (batch_size * seq_len, num_experts, capacity) for routing usage. More detailed
+ function can be found in the paper about ViT-MoE.
+
+ Args:
+ capacity_factor_train (float, optional): Capacity factor in routing of training.
+ capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
+ min_capacity (int, optional): The minimum number of the capacity of each expert
+ noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
+        drop_tks (bool, optional): Whether to drop tokens in evaluation.
+ """
+
+ def __init__(self,
+ capacity_factor_train: float = 1.25,
+ capacity_factor_eval: float = 2.0,
+ min_capacity: int = 4,
+ noisy_func: Optional[Callable] = None,
+ drop_tks: bool = True):
+ super().__init__(k_value=2,
+ capacity_factor_train=capacity_factor_train,
+ capacity_factor_eval=capacity_factor_eval,
+ min_capacity=min_capacity,
+ noisy_func=noisy_func,
+ drop_tks=drop_tks)
+
+ def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
+ """
+ Args:
+ inputs (torch.Tensor): The input tensor of shape (batch_size * seq_len, num_experts).
+
+ Returns:
+ 1. use_kernel is False:
+ The combine weight tensor of shape (batch_size * seq_len, num_experts, capacity).
+ The dispatch mask tensor of shape (batch_size * seq_len, num_experts, capacity).
+ 2. use_kernel is True:
+ ...
+ """
+ if self.noisy_func is not None and self.training:
+ inputs = self.noisy_func(inputs)
+
+ assert inputs.dtype == torch.float
+ probs = F.softmax(inputs, dim=-1)
+ num_experts = probs.size(-1)
+ capacity = self.get_capacity(inputs.shape)
+
+ top1_idx = torch.argmax(probs, dim=-1)
+ mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
+ logits_except1 = probs.masked_fill(mask1.bool(), float("-inf"))
+ top2_idx = torch.argmax(logits_except1, dim=-1)
+ mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
+
+ cmask = (mask1 + mask2) # loss: [s, e]
+ cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
+
+ # calculate loss
+ expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
+ self.set_aux_loss(probs, expert_indices, num_experts)
+ self.set_z_loss(inputs)
+ self.pop_router_loss()
+
+ if not self.training and not self.drop_tks and ep_group is not None:
+ max_num = torch.max(torch.sum(cmask, dim=0))
+ dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
+ capacity = max_num.item()
+
+ rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
+ rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
+ rank2 += torch.sum(mask1, dim=-2, keepdim=True)
+
+ mask1 *= torch.lt(rank1, capacity)
+ mask2 *= torch.lt(rank2, capacity)
+ used_capacity = mask1.sum(dim=0) + mask2.sum(dim=0)
+
+ rank1 = torch.sum(mask1 * rank1, dim=-1)
+ rank2 = torch.sum(mask2 * rank2, dim=-1)
+
+ if use_kernel:
+ mask1 = torch.sum(mask1, dim=-1)
+ mask2 = torch.sum(mask2, dim=-1)
+
+ mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
+ dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
+
+ return used_capacity, probs, mask, dest_idx, num_experts * capacity
+ else:
+ """
+ The following code is equivalent to:
+
+ ```
+ weight1 = mask1 * probs.type_as(inputs)
+ weight2 = mask2 * probs.type_as(inputs)
+ rank1_sc = F.one_hot(rank1, num_classes=capacity)
+ rank2_sc = F.one_hot(rank2, num_classes=capacity)
+
+ cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
+ cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
+ cb_weight = cb_weight1 + cb_weight2
+ sec_mask = cb_weight.bool()
+ ```
+ """
+
+ weight1 = mask1 * probs.type_as(inputs)
+ weight2 = mask2 * probs.type_as(inputs)
+
+ cb_weight = torch.zeros(inputs.shape + (capacity,), device=inputs.device)
+ sec_mask = torch.zeros_like(cb_weight, dtype=torch.bool)
+ indices = torch.arange(0, inputs.shape[0], device=inputs.device)
+ cb_weight[indices, top1_idx[indices], rank1[indices]] += weight1[indices, top1_idx[indices]]
+ cb_weight[indices, top2_idx[indices], rank2[indices]] += weight2[indices, top2_idx[indices]]
+ sec_mask[indices, top1_idx[indices], rank1[indices]] |= mask1.bool()[indices, top1_idx[indices]]
+ sec_mask[indices, top2_idx[indices], rank2[indices]] |= mask2.bool()[indices, top2_idx[indices]]
+
+ return used_capacity, cb_weight, sec_mask
+
+
+class TopKRouter(MoeRouter):
+ """Masked matmul router using tokens choose top-k experts assignment.
+
+ NOTE: this is modified from flaxformer.
+ This router uses the same mechanism as in Switch Transformer
+ (https://arxiv.org/abs/2101.03961) and V-MoE
+ (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are
+ sorted by router_probs and then routed to their choice of expert until the
+ expert's expert_capacity is reached. There is no guarantee that each token is
+ processed by an expert, or that each expert receives at least one token.
+
+ Attributes:
+ num_selected_experts: Maximum number of experts to which each token is
+ routed. Tokens may be routed to fewer experts if particular experts are
+ oversubscribed / reach capacity.
+ """
+
+ def __init__(self,
+ num_selected_experts: int,
+ capacity_factor_train: float = 1.25,
+ capacity_factor_eval: float = 2.0,
+ min_capacity: int = 4,
+ noisy_func: Optional[Callable] = None,
+ drop_tks: bool = True):
+ super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
+ drop_tks)
+
+ def forward(
+ self,
+ router_probs: torch.Tensor,
+ expert_capacity: int,
+ ) -> Tuple:
+ """Computes masks for the top-k experts per token.
+
+ Args:
+ router_probs: [num_groups, tokens_per_group, num_experts]
+ probabilities used to determine the routing of tokens to the experts.
+
+ Returns:
+ Dispatch and combine arrays for routing with masked matmuls.
+ """
+ # TODO: FIXME: add parallel group
+ num_groups, _, num_experts = router_probs.shape
+
+ # Top-k router probability and corresponding expert indices for each token.
+ # Shape: [num_groups, tokens_per_group, num_selected_experts].
+ expert_gate, expert_index = torch.topk(router_probs, self.k_value)
+
+ self.set_aux_loss(router_probs, expert_index, num_experts)
+ self.pop_router_loss()
+
+ # Make num_selected_experts the leading axis to ensure that top-1 choices
+ # have priority over top-2 choices, which have priority over top-3 choices,
+ # etc.
+ expert_index = torch.transpose(expert_index, 1, 2)
+ # Shape: [num_groups, num_selected_experts * tokens_per_group]
+ expert_index = expert_index.reshape(num_groups, -1)
+
+ # Create mask out of indices.
+ # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
+ expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
+
+ # Experts have a fixed capacity that we cannot exceed. A token's priority
+ # within the expert's buffer is given by the masked, cumulative capacity of
+ # its target expert.
+ # Shape: [num_groups, tokens_per_group * num_selected_experts, num_experts].
+ token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
+ # Shape: [num_groups, num_selected_experts, tokens_per_group, num_experts].
+ token_priority = token_priority.reshape((num_groups, self.k_value, -1, num_experts))
+ # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
+ token_priority = torch.transpose(token_priority, 1, 2)
+ # For each token, across all selected experts, select the only non-negative
+ # (unmasked) priority. Now, for group G routing to expert E, token T has
+ # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
+ # is its targeted expert.
+ # Shape: [num_groups, tokens_per_group, num_experts].
+ token_priority = torch.max(token_priority, dim=2)[0]
+
+        # Token T can only be routed to expert E if its priority is non-negative
+        # (i.e. >= 0) and less than the expert capacity. The one-hot matrix ignores
+        # indices outside the range [0, expert_capacity).
+ # Shape: [num_groups, tokens_per_group, num_experts, expert_capacity].
+ valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
+ token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
+ dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
+ valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, -1, expert_capacity)
+ dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
+
+ # The combine array will be used for combining expert outputs, scaled by the
+ # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
+ # expert_capacity].
+ combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+
+ return combine_array, dispatch_mask
+
+
+def get_router_cls(top_k: int, grouped: bool = False) -> MoeRouter:
+ if not grouped:
+ if top_k == 1:
+ return Top1Router
+ elif top_k == 2:
+ return Top2Router
+ else:
+ raise NotImplementedError("top_k > 2 is not supported yet")
+ else:
+ return TopKRouter
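For intuition, the cumsum-based priority assignment in `TopKRouter.forward` can be replayed on a toy input. The sketch below is illustrative only (it is not part of this patch) and assumes one group of four tokens, two experts, and k = 2:

```python
import torch
import torch.nn.functional as F

# One group, four tokens, two experts, k = 2.
router_probs = torch.tensor([[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]])
num_groups, _, num_experts = router_probs.shape
k = 2

_, expert_index = torch.topk(router_probs, k)           # [1, 4, 2]
expert_index = torch.transpose(expert_index, 1, 2)      # top-1 choices lead
expert_index = expert_index.reshape(num_groups, -1)     # [1, 8]
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

# Position of each token inside its target expert's buffer.
token_priority = torch.cumsum(expert_mask, dim=1) * expert_mask - 1
token_priority = token_priority.reshape(num_groups, k, -1, num_experts)
token_priority = torch.transpose(token_priority, 1, 2)
token_priority = torch.max(token_priority, dim=2)[0]

print(token_priority)  # tensor([[[0, 1], [1, 2], [3, 0], [2, 3]]])
# With expert_capacity = 2, entries >= 2 would be dropped by the valid_mask.
```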
diff --git a/colossalai/moe/utils.py b/colossalai/moe/utils.py
new file mode 100644
index 000000000000..5a17a6e0d769
--- /dev/null
+++ b/colossalai/moe/utils.py
@@ -0,0 +1,224 @@
+import contextlib
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
+from colossalai.utils import get_current_device
+
+
+class ForceFP32Parameter(torch.nn.Parameter):
+
+ def half(self, memory_format=None):
+ return self.data.clone()
+
+
+class NormalNoiseGenerator:
+ """Generates a random noisy mask for logits tensor.
+
+ All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
+ `E = the number of experts`.
+
+ Args:
+ num_experts (int): The number of experts.
+ """
+
+ def __init__(self, num_experts: int):
+ self.normal = torch.distributions.normal.Normal(
+ loc=torch.tensor(0.0, device=get_current_device()),
+ scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
+ ).rsample
+
+ def __call__(self, inputs: torch.Tensor):
+ noisy = self.normal(inputs.shape)
+ return inputs + noisy
+
+
+class UniformNoiseGenerator:
+ """Generates a random noisy mask for logits tensor.
+ copied from mesh tensorflow:
+ Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
+ Makes models more resilient to rounding errors introduced by bfloat16.
+ This seems particularly important for logits.
+
+ Args:
+        eps (float, optional): Epsilon in generator, defaults to 1e-2.
+ """
+
+ def __init__(self, eps: float = 1e-2):
+ self.uniform = torch.distributions.uniform.Uniform(
+ low=torch.tensor(1.0 - eps, device=get_current_device()),
+ high=torch.tensor(1.0 + eps, device=get_current_device()),
+ ).rsample
+
+ def __call__(self, inputs: torch.Tensor):
+ noisy = self.uniform(inputs.shape)
+ return inputs * noisy
+
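As a quick illustration of the two noise policies (not part of the patch), the same distributions can be sampled directly with plain torch:

```python
import torch

num_experts, eps = 4, 1e-2
logits = torch.randn(8, num_experts)  # [tokens, experts]

# "Gaussian" policy: additive noise, as in NormalNoiseGenerator above.
gaussian = torch.distributions.Normal(0.0, 1.0 / num_experts**2)
noisy_additive = logits + gaussian.rsample(logits.shape)

# "Jitter" policy: multiplicative noise in [1 - eps, 1 + eps], as in UniformNoiseGenerator.
jitter = torch.distributions.Uniform(1.0 - eps, 1.0 + eps)
noisy_multiplicative = logits * jitter.rsample(logits.shape)
```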
+
+def autocast_softmax(logit: torch.Tensor, dim: int):
+    return F.softmax(logit, dim=dim, dtype=torch.float32)
+
+
+def get_noise_generator(noise_type: Optional[str], num_experts: int) -> Optional[Callable]:
+ if noise_type is None:
+ return None
+ elif noise_type == "Jitter":
+ noisy_func = UniformNoiseGenerator()
+ elif noise_type == "Gaussian":
+ noisy_func = NormalNoiseGenerator(num_experts)
+ else:
+ raise NotImplementedError("Unsupported input noisy policy")
+ return noisy_func
+
+
+def get_activation(act: Optional[str]) -> Callable:
+ if act is None or act == "relu":
+ return torch.nn.ReLU()
+ elif act == "gelu":
+ return torch.nn.GELU()
+ elif act == "swiglu":
+ return SwiGLU
+ else:
+ raise NotImplementedError("Unsupported activation function")
+
+
+def SwiGLU(x):
+ """Gated linear unit activation function.
+ Args:
+ x : input array
+ axis: the axis along which the split should be computed (default: -1)
+ """
+ size = x.shape[-1]
+ assert size % 2 == 0, "axis size must be divisible by 2"
+ x1, x2 = torch.split(x, size // 2, -1)
+ return x1 * (x2 * torch.sigmoid(x2))
+
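Note that `x2 * torch.sigmoid(x2)` is exactly SiLU, so the function above can be cross-checked against `F.silu` (an illustrative sanity check, not part of the patch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 8)
x1, x2 = torch.split(x, x.shape[-1] // 2, -1)
assert torch.allclose(x1 * (x2 * torch.sigmoid(x2)), x1 * F.silu(x2))
```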
+
+@contextlib.contextmanager
+def skip_init():
+ """
+ skip param random init
+ """
+
+ def _skip_init(*args, **kwargs):
+ pass
+
+ init_func = {
+ "constant_": torch.nn.init.constant_,
+ "uniform_": torch.nn.init.uniform_,
+ "normal_": torch.nn.init.normal_,
+ "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
+ "kaiming_normal_": torch.nn.init.kaiming_normal_,
+ "xavier_normal_": torch.nn.init.xavier_normal_,
+ "xavier_uniform_": torch.nn.init.xavier_uniform_,
+ "trunc_normal_": torch.nn.init.trunc_normal_,
+ }
+
+    for method_name in init_func:
+        setattr(torch.nn.init, method_name, _skip_init)
+
+    try:
+        yield
+    finally:
+        # Restore the original initializers even if the wrapped code raises.
+        for method_name, original_init in init_func.items():
+            setattr(torch.nn.init, method_name, original_init)
+
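A typical use of `skip_init` (a sketch, assuming the module path `colossalai.moe.utils` introduced by this diff) is to skip wasted random initialization when the weights will be overwritten immediately, for example by a checkpoint load:

```python
import torch.nn as nn
from colossalai.moe.utils import skip_init  # module added in this diff

with skip_init():
    model = nn.Linear(1024, 1024)  # parameters allocated, random init skipped
# model.load_state_dict(checkpoint)  # hypothetical: weights overwritten here anyway
```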
+
+def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
+ """Returns a parameter dictionary, the key of which is the expert parallel
+ size of every parameter. Since the parameters in data parallelism is replicated
+ in each GPU, we set their ep_size to 1.
+
+ Args:
+ model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
+ """
+ epsize_param_dict = dict()
+ for param in model.parameters():
+ if not is_moe_tensor(param):
+ ep_size = 1 # set ep_size to 1 for dp parameters
+ else:
+ ep_size = get_ep_size(param)
+ if ep_size not in epsize_param_dict:
+ epsize_param_dict[ep_size] = []
+ epsize_param_dict[ep_size].append(param)
+
+ return epsize_param_dict
+
+
+def sync_moe_model_param(model: nn.Module):
+ """Make sure model parameters are consistent in MoE parallel context.
+
+ Args:
+        model (:class:`torch.nn.Module`): A PyTorch model whose parameters will be synchronized.
+ """
+ param_dict = get_moe_epsize_param_dict(model)
+
+ # synchronize the parameters whose dp_group is the whole world
+ if 1 in param_dict:
+ for param in param_dict[1]:
+ dist.broadcast(param, src=0)
+
+ for ep_size in param_dict:
+ # When ep_size = world_size, communication is not needed
+ if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+ for param in param_dict[ep_size]:
+ src_rank = get_dp_group_ranks(param)[0]
+ dist.broadcast(param, src=src_rank, group=get_dp_group(param))
+
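A minimal sketch of how this synchronization might be invoked after model construction; `build_moe_model` is a hypothetical constructor and the launch is assumed to go through torchrun:

```python
import torch.distributed as dist
from colossalai.moe.utils import sync_moe_model_param  # module added in this diff

dist.init_process_group(backend="nccl")
model = build_moe_model()  # hypothetical: returns a module containing MoE tensors
# Broadcast replicated (dp) params from rank 0 and EP params within each dp group.
sync_moe_model_param(model)
```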
+
+def set_moe_args(config: Any, args: dict):
+ for k, v in args.items():
+ setattr(config, k, v)
+
+
+def create_ep_hierarchical_group(
+ ep_group_ranks: List[int],
+ nproc_per_node: Optional[int] = None,
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
+ """
+ e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
+ Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
+ """
+ assert dist.is_initialized(), "Please initialize torch.distributed first."
+ rank = dist.get_rank()
+ if nproc_per_node is None:
+ nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
+ assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
+ nproc_per_node = int(nproc_per_node)
+ else:
+ assert dist.get_world_size() % nproc_per_node == 0, \
+ "nproc_per_node should be a divisor of world_size."
+ num_node = dist.get_world_size() // nproc_per_node
+
+ intra_src_rank = None
+ ep_intra_node_group = None
+ for i in range(num_node):
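+        # NOTE: assumes the EP group occupies the same local rank offsets on every node.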
+ ep_intra_ranks = [
+ i * nproc_per_node + j
+ for j in range(nproc_per_node)
+ if j in ep_group_ranks
+ ]
+ group = dist.new_group(ep_intra_ranks)
+ if rank in ep_intra_ranks:
+ assert ep_intra_node_group is None
+ ep_intra_node_group = group
+ intra_src_rank = ep_intra_ranks[0]
+
+ ep_inter_node_group = None
+ ep_inter_ranks = [
+ ep_group_ranks[0] + i * nproc_per_node
+ for i in range(num_node)
+ ]
+ if len(ep_inter_ranks) > 1:
+ group = dist.new_group(ep_inter_ranks)
+ if rank in ep_inter_ranks:
+ ep_inter_node_group = group
+
+ return intra_src_rank, ep_intra_node_group, ep_inter_node_group
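The rank arithmetic in the docstring can be replayed without a process group (an illustrative sketch, not part of the patch):

```python
# Mirrors the grouping logic for ep_group_ranks = [1, 2, 5, 6], nproc_per_node = 4.
ep_group_ranks, nproc_per_node, num_node = [1, 2, 5, 6], 4, 2

intra = [[i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         for i in range(num_node)]
inter = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

print(intra)  # [[1, 2], [5, 6]] -- one intra-node EP group per node
print(inter)  # [1, 5]           -- inter-node group of each node's first EP rank
```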
diff --git a/colossalai/nn/layer/__init__.py b/colossalai/nn/layer/__init__.py
index 9aeab9f44a6d..16281fe0b66d 100644
--- a/colossalai/nn/layer/__init__.py
+++ b/colossalai/nn/layer/__init__.py
@@ -1,2 +1 @@
-# from .moe import *
from .utils import *
diff --git a/colossalai/nn/layer/moe/__init__.py b/colossalai/nn/layer/moe/__init__.py
deleted file mode 100644
index 6a5ccff510be..000000000000
--- a/colossalai/nn/layer/moe/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from .checkpoint import load_moe_model, save_moe_model
-from .experts import Experts, FFNExperts, TPExperts
-from .layers import MoeLayer, MoeModule
-from .routers import MoeRouter, Top1Router, Top2Router
-from .utils import NormalNoiseGenerator, UniformNoiseGenerator, build_ffn_experts
-
-__all__ = [
- "Experts",
- "FFNExperts",
- "TPExperts",
- "Top1Router",
- "Top2Router",
- "MoeLayer",
- "NormalNoiseGenerator",
- "UniformNoiseGenerator",
- "build_ffn_experts",
- "MoeModule",
- "MoeRouter",
- "save_moe_model",
- "load_moe_model",
-]
diff --git a/colossalai/nn/layer/moe/_operation.py b/colossalai/nn/layer/moe/_operation.py
deleted file mode 100644
index 2f0b7e43673a..000000000000
--- a/colossalai/nn/layer/moe/_operation.py
+++ /dev/null
@@ -1,171 +0,0 @@
-from typing import Any, Optional, Tuple
-
-import torch
-import torch.distributed as dist
-from torch import Tensor
-from torch.distributed import ProcessGroup
-
-COL_MOE_KERNEL_FLAG = False
-
-try:
- from colossalai._C import moe
-except:
- moe = None
-
-
-def build_moe_if_not_prebuilt():
- # load moe kernel during runtime if not pre-built
- global moe
- if moe is None:
- from colossalai.kernel.op_builder import MOEBuilder
-
- moe = MOEBuilder().load()
-
-
-class AllGather(torch.autograd.Function):
- @staticmethod
- def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
- global moe
-
- if moe is None:
- from colossalai.kernel.op_builder import MOEBuilder
-
- moe = MOEBuilder().load()
-
- if ctx is not None:
- ctx.comm_grp = group
-
- comm_size = dist.get_world_size(group)
- if comm_size == 1:
- return inputs.unsqueeze(0)
-
- buffer_shape = (comm_size,) + inputs.shape
- outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
- buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
- dist.all_gather(buffer_list, inputs, group=group)
- return outputs
-
- @staticmethod
- def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
- return ReduceScatter.forward(None, grad_outputs, ctx.comm_grp), None
-
-
-class ReduceScatter(torch.autograd.Function):
- @staticmethod
- def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
- if ctx is not None:
- ctx.comm_grp = group
-
- comm_size = dist.get_world_size(group)
- if comm_size == 1:
- return inputs.squeeze(0)
-
- if not inputs.is_contiguous():
- inputs = inputs.contiguous()
-
- output_shape = inputs.shape[1:]
- outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
- buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
- dist.reduce_scatter(outputs, buffer_list, group=group)
- return outputs
-
- @staticmethod
- def backward(ctx: Any, grad_outputs: Tensor) -> Tuple[Tensor, None]:
- return AllGather.forward(None, grad_outputs, ctx.comm_grp), None
-
-
-class AllToAll(torch.autograd.Function):
- """Dispatches input tensor [e, c, h] to all experts by all_to_all_single
- operation in torch.distributed.
- """
-
- @staticmethod
- def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
- if ctx is not None:
- ctx.comm_grp = group
- if not inputs.is_contiguous():
- inputs = inputs.contiguous()
- if dist.get_world_size(group) == 1:
- return inputs
- output = torch.empty_like(inputs)
- dist.all_to_all_single(output, inputs, group=group)
- return output
-
- @staticmethod
- def backward(ctx: Any, *grad_outputs: Tensor) -> Tuple[Tensor, None]:
- return AllToAll.forward(None, *grad_outputs, ctx.comm_grp), None
-
-
-class MoeDispatch(torch.autograd.Function):
- @staticmethod
- def forward(ctx, tokens, mask, dest_idx, ec):
- s = tokens.size(0)
- h = tokens.size(1)
-
- # load moe kernel during runtime if not pre-built
- build_moe_if_not_prebuilt()
-
- expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)
-
- ctx.save_for_backward(mask, dest_idx)
- ctx.s = s
- ctx.h = h
- ctx.ec = ec
-
- return expert_input
-
- @staticmethod
- def backward(ctx, output_grad):
- mask, dest_idx = ctx.saved_tensors
- d_tokens = moe.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
- return d_tokens, None, None, None
-
-
-class MoeCombine(torch.autograd.Function):
- @staticmethod
- def forward(ctx, expert_tokens, logits, mask, dest_idx, ec):
- assert logits.dtype == torch.float32
-
- s = logits.size(0)
- e = logits.size(1)
- c = ec // e
- h = expert_tokens.size(-1)
-
- # load moe kernel during runtime if not pre-built
- build_moe_if_not_prebuilt()
-
- fp16_flag = expert_tokens.dtype == torch.float16
- cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
- ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
- output = ctokens.to(torch.float16) if fp16_flag else ctokens
-
- ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
- ctx.s = s
- ctx.e = e
- ctx.c = c
- ctx.h = h
- ctx.fp16_flag = fp16_flag
-
- return output
-
- @staticmethod
- def backward(ctx, tokens_grad):
- expert_tokens, logits, mask, dest_idx = ctx.saved_tensors
-
- cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 else tokens_grad
- cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
- d_expert, d_logits = moe.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits, mask, dest_idx)
- d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
-
- return d_expert, d_logits, None, None, None
-
-
-def moe_cumsum(inputs: Tensor):
- dim0 = inputs.size(0)
- flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
- if flag and COL_MOE_KERNEL_FLAG:
- # load moe kernel during runtime if not pre-built
- build_moe_if_not_prebuilt()
- return moe.cumsum_sub_one(inputs)
- else:
- return torch.cumsum(inputs, dim=0) - 1
diff --git a/colossalai/nn/layer/moe/checkpoint.py b/colossalai/nn/layer/moe/checkpoint.py
deleted file mode 100644
index adad19d581ef..000000000000
--- a/colossalai/nn/layer/moe/checkpoint.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-
-from .experts import MoeExperts
-
-
-def save_moe_model(model: nn.Module, save_path: str):
- state_dict = model.state_dict()
- if dist.get_rank() == 0:
- torch.save(state_dict, save_path)
- dist.barrier()
-
-
-def load_moe_model(model: nn.Module, load_path: str):
- state_dict = torch.load(load_path)
-
- for prefix, module in model.named_modules():
- if prefix.endswith(".moe_layer.experts"):
- # this module should be an Experts instance
- assert isinstance(module, MoeExperts)
-
- ep_rank = dist.get_rank(module.dist_info.ep_group)
- num_local = module.num_local_experts
- for i in range(num_local):
- expert_id = ep_rank * num_local + i
- for name, _ in module.experts[i].named_parameters():
- cur_key = f"{prefix}.experts.{i}.{name}"
- param_key = f"{prefix}.experts.{expert_id}.{name}"
- load_param = state_dict[param_key]
- state_dict[cur_key] = load_param
-
- for name, _ in module.experts[0].named_parameters():
- pop_pre = f"{prefix}.experts."
- pop_suf = f".{name}"
- for i in range(num_local, module.num_total_experts):
- pop_key = f"{pop_pre}{i}{pop_suf}"
- state_dict.pop(pop_key)
-
- model.load_state_dict(state_dict)
diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
deleted file mode 100644
index 4b2ecb241702..000000000000
--- a/colossalai/nn/layer/moe/experts.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import math
-from copy import deepcopy
-from typing import Type
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.context import ParallelMode, seed
-from colossalai.legacy.zero.init_ctx import no_shard_zero_decrator
-from colossalai.utils import get_current_device
-
-
-class MoeExperts(nn.Module):
- """Basic class for experts in MoE. It stores what kind of communication experts use
- to exchange tokens, how many experts in a single GPU and parallel information such as
- expert parallel size, data parallel size and their distributed communication groups.
- """
-
- def __init__(self, comm_name: str, num_experts: int):
- super().__init__()
- assert comm_name in {
- "all_to_all",
- "all_gather",
- }, "This kind of communication has not been implemented yet.\n Please use Experts build function."
- self.comm_name = comm_name
- self.num_total_experts = num_experts
- # Get the configuration of experts' deployment and parallel information from moe context
- self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
-
-
-@no_shard_zero_decrator(is_replicated=False)
-class Experts(MoeExperts):
- """A wrapper class to create experts. It will create E experts across the
- moe model parallel group, where E is the number of experts. Every expert
- is a instance of the class, 'expert' in initialization parameters.
-
- Args:
- expert_cls (:class:`torch.nn.Module`): The class of all experts
- num_experts (int): The number of experts
- expert_args: Args used to initialize experts, the args could be found in corresponding expert class
- """
-
- def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
- super().__init__("all_to_all", num_experts)
-
- # Use seed to make every expert different from others
- with seed(ParallelMode.TENSOR):
- self.experts = nn.ModuleList([expert_cls(**expert_args) for _ in range(self.num_local_experts)])
-
- # Attach parallel information for all parameters in Experts
- for exp in self.experts:
- for param in exp.parameters():
- param.__setattr__("moe_info", self.dist_info)
-
- def forward(self, inputs: torch.Tensor):
- # Split inputs for each expert
- expert_input = torch.chunk(inputs, self.num_local_experts, dim=1)
- expert_output = []
-
- # Get outputs from each expert
- for i in range(self.num_local_experts):
- expert_output.append(self.experts[i](expert_input[i]))
-
- # Concatenate all outputs together
- output = torch.cat(expert_output, dim=1).contiguous()
- return output
-
- def state_dict(self, destination=None, prefix="", keep_vars=False):
- assert keep_vars == False, "Only support keep_vars=False now"
- dp_rank = dist.get_rank(self.dist_info.dp_group)
- ep_rank = dist.get_rank(self.dist_info.ep_group)
- submodule_dict = dict()
- example_submodule = None
- for name, subm in self.experts.named_modules():
- if subm is self.experts:
- continue
- module_number = self.num_local_experts * ep_rank + int(name)
- submodule_dict[module_number] = subm
- example_submodule = subm
-
- if dp_rank == 0:
- local_prefix = prefix + "experts."
- buffer_module = deepcopy(example_submodule)
- for i in range(self.num_total_experts):
- source_rank = i // self.num_local_experts
- current_prefix = local_prefix + str(i) + "."
- comm_module = submodule_dict.get(i, buffer_module)
- for name, param in comm_module.named_parameters():
- dist.broadcast(param.data, src=source_rank, group=self.dist_info.ep_group)
- if ep_rank == 0:
- destination[current_prefix + name] = param.data.cpu()
-
- dist.barrier()
-
-
-class FFNExperts(MoeExperts):
- """Use torch.bmm to speed up for multiple experts."""
-
- def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
- super().__init__("all_to_all", num_experts)
-
- self.w1 = nn.Parameter(torch.empty(self.num_local_experts, d_model, d_ff, device=get_current_device()))
- self.b1 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_ff, device=get_current_device()))
-
- self.w2 = nn.Parameter(torch.empty(self.num_local_experts, d_ff, d_model, device=get_current_device()))
- self.b2 = nn.Parameter(torch.empty(self.num_local_experts, 1, d_model, device=get_current_device()))
-
- s1 = math.sqrt(0.1 / d_model)
- s2 = math.sqrt(0.1 / d_ff)
-
- with seed(ParallelMode.TENSOR):
- nn.init.trunc_normal_(self.w1, std=s1)
- nn.init.trunc_normal_(self.b1, std=s1)
- nn.init.trunc_normal_(self.w2, std=s2)
- nn.init.trunc_normal_(self.b2, std=s2)
-
- self.act = nn.GELU() if activation is None else activation
- self.drop = nn.Dropout(p=drop_rate)
-
- for param in self.parameters():
- param.__setattr__("moe_info", self.dist_info)
-
- def forward(self, inputs): # inputs [g, el, c, h]
- el = inputs.size(1)
- h = inputs.size(-1)
-
- inputs = inputs.transpose(0, 1)
- inshape = inputs.shape
- inputs = inputs.reshape(el, -1, h)
-
- out_ff = torch.baddbmm(self.b1, inputs, self.w1)
- out_act = self.act(out_ff)
- with seed(ParallelMode.TENSOR):
- out_inter = self.drop(out_act)
-
- out_model = torch.baddbmm(self.b2, out_inter, self.w2)
- with seed(ParallelMode.TENSOR):
- outputs = self.drop(out_model) # outputs [el, gc, h]
-
- outputs = outputs.reshape(inshape)
- outputs = outputs.transpose(0, 1).contiguous()
- return outputs
-
-
-class TPExperts(MoeExperts):
- """Use tensor parallelism to split each expert evenly, which can deploy experts in
- case that the number of experts can't be divide by maximum expert parallel size or
- maximum expert parallel size can't be divide by the number of experts.
- """
-
- def __init__(self, num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
- super().__init__("all_gather", MOE_CONTEXT.max_ep_size)
-
- assert d_ff % MOE_CONTEXT.max_ep_size == 0, "d_ff should be divide by maximum expert parallel size"
-
- p_ff = d_ff // MOE_CONTEXT.max_ep_size
-
- self.w1 = nn.Parameter(torch.empty(num_experts, d_model, p_ff, device=get_current_device()))
- self.b1 = nn.Parameter(torch.empty(num_experts, 1, p_ff, device=get_current_device()))
-
- self.w2 = nn.Parameter(torch.empty(num_experts, p_ff, d_model, device=get_current_device()))
- self.b2 = nn.Parameter(torch.empty(num_experts, 1, d_model, device=get_current_device()))
-
- s1 = math.sqrt(0.1 / d_model)
- s2 = math.sqrt(0.1 / d_ff)
-
- with seed(ParallelMode.TENSOR):
- nn.init.trunc_normal_(self.w1, std=s1)
- nn.init.trunc_normal_(self.b1, std=s1)
- nn.init.trunc_normal_(self.w2, std=s2)
-
- nn.init.trunc_normal_(self.b2, std=s2)
-
- self.act = nn.GELU() if activation is None else activation
- self.drop = nn.Dropout(p=drop_rate)
-
- self.w1.__setattr__("moe_info", self.dist_info)
- self.w2.__setattr__("moe_info", self.dist_info)
- self.b1.__setattr__("moe_info", self.dist_info)
-
- def forward(self, inputs): # inputs [g, e, c, h]
- e = inputs.size(1)
- h = inputs.size(-1)
-
- inputs = inputs.transpose(0, 1)
- inshape = inputs.shape
- inputs = inputs.reshape(e, -1, h)
-
- out_ff = torch.baddbmm(self.b1, inputs, self.w1)
- out_act = self.act(out_ff)
- with seed(ParallelMode.TENSOR):
- out_inter = self.drop(out_act)
-
- out_model = torch.baddbmm(self.b2, out_inter, self.w2)
- outputs = self.drop(out_model) # outputs [e, gc, h]
-
- outputs = outputs.reshape(inshape)
- outputs = outputs.transpose(0, 1).contiguous()
- return outputs # outputs [g, e, c, h]
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
deleted file mode 100644
index 23d483e6a17a..000000000000
--- a/colossalai/nn/layer/moe/layers.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import math
-from typing import Optional, Tuple, Type
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
-from colossalai.nn.layer.moe._operation import (
- COL_MOE_KERNEL_FLAG,
- AllGather,
- AllToAll,
- MoeCombine,
- MoeDispatch,
- ReduceScatter,
-)
-from colossalai.nn.layer.moe.experts import Experts, MoeExperts
-from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
-from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
-from colossalai.utils import get_current_device
-
-
-@no_shard_zero_decrator(is_replicated=True)
-class MoeLayer(nn.Module):
- """A MoE layer, that puts its input tensor to its gate and uses the output logits
- to router all tokens, is mainly used to exchange all tokens for every expert across
- the moe tensor group by all to all communication. Then it will get the output of all
- experts and exchange the output. At last returns the output of the moe system.
-
- Args:
- dim_model (int): Dimension of model.
- num_experts (int): The number of experts.
- router (MoeRouter): Instance of router used in routing.
- experts (MoeExperts): Instance of experts generated by Expert.
- """
-
- def __init__(self, dim_model: int, num_experts: int, router: MoeRouter, experts: MoeExperts):
- super().__init__()
- self.d_model = dim_model
- self.num_experts = num_experts
- self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
- self.router: MoeRouter = router
- self.experts: MoeExperts = experts
- self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
- self.ep_group = experts.dist_info.ep_group
- self.ep_size = experts.dist_info.ep_size
- self.num_local_experts = experts.num_local_experts
-
- nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
-
- def a2a_process(self, dispatch_data: torch.Tensor):
- expert_input = AllToAll.apply(dispatch_data, self.ep_group)
- input_shape = expert_input.shape
- expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
- expert_output = self.experts(expert_input)
- expert_output = expert_output.reshape(input_shape)
- expert_output = AllToAll.apply(expert_output, self.ep_group)
- return expert_output
-
- def tp_process(self, dispatch_data: torch.Tensor):
- expert_in = AllGather.apply(dispatch_data, self.ep_group)
- expert_out = self.experts(expert_in)
- expert_out = ReduceScatter.apply(expert_out, self.ep_group)
- return expert_out
-
- def forward(self, inputs: torch.Tensor) -> Tuple:
- # reshape the input tokens
- tokens = inputs.reshape(-1, self.d_model)
-
- # the data type of the inputs in the gating should be fp32
- fp32_input = tokens.to(torch.float)
- fp32_weight = self.gate_weight.to(torch.float)
- gate_output = F.linear(fp32_input, fp32_weight)
-
- # the result from the router
- route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
-
- if self.use_kernel:
- dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
- dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
- else:
- sec_mask_f = route_result_list[1].type_as(inputs)
- dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
-
- # dispatch_data [e, c, h]
- if self.experts.comm_name == "all_to_all":
- expert_output = self.a2a_process(dispatch_data)
- elif self.experts.comm_name == "all_gather":
- expert_output = self.tp_process(dispatch_data)
- else:
- raise NotImplementedError(
- "This kind of communication has not been implemented yet.\n Please use Experts " "build function."
- )
- # expert_output [e, c, h]
- if self.use_kernel:
- expert_output = expert_output.reshape(-1, self.d_model)
- ans = MoeCombine.apply(expert_output, *route_result_list)
- else:
- combine_weights = route_result_list[0].type_as(inputs)
- combine_weights = combine_weights.view(combine_weights.shape[0], -1)
- expert_output = expert_output.view(-1, expert_output.shape[-1])
- ans = torch.matmul(combine_weights, expert_output)
-
- ans = ans.reshape(inputs.shape)
- l_aux = self.router.pop_routing_loss()
- return ans, l_aux
-
-
-class MoeModule(nn.Module):
- """A class for users to create MoE modules in their models.
-
- Args:
- dim_model (int): Hidden dimension of training model
- num_experts (int): The number experts
- top_k (int, optional): The number of experts for dispatchment of each token
- capacity_factor_train (float, optional): Capacity factor in routing during training
- capacity_factor_eval (float, optional): Capacity factor in routing during evaluation
- min_capacity (int, optional): The minimum number of the capacity of each expert
- noisy_policy (str, optional): The policy of noisy function. Now we have 'Jitter' and 'Gaussian'.
- 'Jitter' can be found in `Switch Transformer paper`_.
- 'Gaussian' can be found in `ViT-MoE paper`_.
- drop_tks (bool, optional): Whether drops tokens in evaluation
- use_residual (bool, optional): Makes this MoE layer a Residual MoE.
- More information can be found in `Microsoft paper`_.
- residual_instance (nn.Module, optional): The instance of residual module in Residual MoE
- expert_instance (MoeExperts, optional): The instance of experts module in MoeLayer
- expert_cls (Type[nn.Module], optional): The class of each expert when no instance is given
- expert_args (optional): The args of expert when no instance is given
-
- .. _Switch Transformer paper:
- https://arxiv.org/abs/2101.03961
- .. _ViT-MoE paper:
- https://arxiv.org/abs/2106.05974
- .. _Microsoft paper:
- https://arxiv.org/abs/2201.05596
- """
-
- def __init__(
- self,
- dim_model: int,
- num_experts: int,
- top_k: int = 1,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- noisy_policy: Optional[str] = None,
- drop_tks: bool = True,
- use_residual: bool = False,
- residual_instance: Optional[nn.Module] = None,
- expert_instance: Optional[MoeExperts] = None,
- expert_cls: Optional[Type[nn.Module]] = None,
- **expert_args,
- ):
- super().__init__()
-
- noisy_func = None
- if noisy_policy is not None:
- if noisy_policy == "Jitter":
- noisy_func = UniformNoiseGenerator()
- elif noisy_policy == "Gaussian":
- noisy_func = NormalNoiseGenerator(num_experts)
- else:
- raise NotImplementedError("Unsupported input noisy policy")
-
- if top_k == 1:
- moe_router_cls = Top1Router
- elif top_k == 2:
- moe_router_cls = Top2Router
- else:
- raise NotImplementedError("top_k > 2 is not supported yet")
-
- self.moe_router = moe_router_cls(
- capacity_factor_train=capacity_factor_train,
- capacity_factor_eval=capacity_factor_eval,
- min_capacity=min_capacity,
- noisy_func=noisy_func,
- drop_tks=drop_tks,
- )
- self.use_residual = use_residual
- if use_residual:
- if residual_instance is not None:
- self.residual_module = residual_instance
- else:
- assert expert_cls is not None, "Expert class can't be None when residual instance is not given"
- self.residual_module = expert_cls(**expert_args)
-
- with no_shard_zero_context():
- self.residual_combine = nn.Linear(dim_model, 2, device=get_current_device())
-
- if expert_instance is not None:
- my_experts = expert_instance
- else:
- assert expert_cls is not None, "Expert class can't be None when experts instance is not given"
- my_experts = Experts(expert_cls, num_experts, **expert_args)
-
- self.moe_layer = MoeLayer(
- dim_model=dim_model, num_experts=num_experts, router=self.moe_router, experts=my_experts
- )
-
- def forward(self, inputs: torch.Tensor):
- moe_output, l_aux = self.moe_layer(inputs)
-
- if self.use_residual:
- residual_output = self.residual_module(inputs)
- combine_coef = self.residual_combine(inputs)
- combine_coef = F.softmax(combine_coef, dim=-1)
- output = moe_output * combine_coef[..., 0:1] + residual_output * combine_coef[..., 1:]
- else:
- output = moe_output
-
- return output, l_aux
diff --git a/colossalai/nn/layer/moe/routers.py b/colossalai/nn/layer/moe/routers.py
deleted file mode 100644
index 7ba83b2787a0..000000000000
--- a/colossalai/nn/layer/moe/routers.py
+++ /dev/null
@@ -1,235 +0,0 @@
-import math
-from abc import ABC
-from typing import Callable, Optional
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.distributed import ProcessGroup
-
-from colossalai.nn.layer.moe._operation import moe_cumsum
-from colossalai.utils import get_current_device
-
-
-class MoeRouter(nn.Module, ABC):
- """Base class for all MoE routers.
- Args:
- k_value (int): The value of top_k.
- capacity_factor_train (float): Capacity factor in routing of training.
- capacity_factor_eval (float): Capacity factor in routing of evaluation.
- min_capacity (int): The minimum number of the capacity of each expert.
- noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
- drop_tks (bool, optional): Whether drops tokens in evaluation
- """
-
- def __init__(
- self,
- k_value: int,
- capacity_factor_train: float,
- capacity_factor_eval: float,
- min_capacity: int,
- noisy_func: Callable = None,
- drop_tks: bool = True,
- ):
- super().__init__()
- self.k_value = k_value
- self.capacity_factor_train = capacity_factor_train
- self.capacity_factor_eval = capacity_factor_eval
- self.min_capacity = min_capacity
- self.noisy_func = noisy_func
- self.drop_tks = drop_tks
- self._routing_loss = None
-
- def get_capacity(self, logits_shape):
- capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
- capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
- capacity += capacity % 2
- capacity = max(capacity, self.min_capacity)
- assert capacity > 0
- return capacity
-
- def set_routing_loss(self, aux_loss: torch.Tensor) -> None:
- assert self._routing_loss is None
- self._routing_loss = aux_loss
-
- def pop_routing_loss(self) -> torch.Tensor:
- assert self._routing_loss is not None
- reservation = self._routing_loss
- self._routing_loss = None
- return reservation
-
-
-class Top1Router(MoeRouter):
- """Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
- for routing usage. More detailed function can be found in the paper about Switch Transformer
- of Google.
- Args:
- capacity_factor_train (float, optional): Capacity factor in routing of training.
- capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
- min_capacity (int, optional): The minimum number of the capacity of each expert.
- select_policy (str, optional): The policy about tokens selection.
- noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
- drop_tks (bool, optional): Whether drops tokens in evaluation
- """
-
- def __init__(
- self,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- select_policy: str = "first",
- noisy_func: Callable = None,
- drop_tks: bool = True,
- ):
- super().__init__(
- k_value=1,
- capacity_factor_train=capacity_factor_train,
- capacity_factor_eval=capacity_factor_eval,
- min_capacity=min_capacity,
- noisy_func=noisy_func,
- drop_tks=drop_tks,
- )
- self.select_policy = select_policy
- assert select_policy in {"first", "random"}
- if select_policy == "random":
- self.uniform = torch.distributions.uniform.Uniform(
- low=torch.tensor(0.0, device=get_current_device()), high=torch.tensor(1.0, device=get_current_device())
- ).rsample
-
- def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
- if self.noisy_func is not None and self.training:
- inputs = self.noisy_func(inputs)
-
- assert inputs.dtype == torch.float
- logits = F.softmax(inputs, dim=-1)
- num_experts = logits.size(-1)
- capacity = self.get_capacity(logits.shape)
-
- top1_idx = torch.argmax(inputs, dim=-1)
- mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
-
- # caculate the auxiliary loss
- me = torch.mean(logits, dim=0)
- ce = torch.mean(mask.float(), dim=0)
- l_aux = num_experts * torch.sum(me * ce)
- self.set_routing_loss(l_aux)
-
- if not self.training and not self.drop_tks:
- max_num = torch.max(torch.sum(mask, dim=0))
- dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
- capacity = max_num.item()
-
- if self.select_policy == "random":
- rand_mask = mask * self.uniform(mask.shape)
- _, dispatch_idx = torch.topk(rand_mask, k=capacity, dim=0)
- mask = mask * torch.zeros_like(mask).scatter_(0, dispatch_idx, 1)
- ranks = moe_cumsum(mask)
- elif self.select_policy == "first":
- ranks = moe_cumsum(mask)
- mask = mask * torch.lt(ranks, capacity)
- else:
- raise NotImplementedError("Not support such select policy yet.")
-
- ranks = torch.sum(mask * ranks, dim=-1)
-
- if use_kernel:
- mask = torch.sum(mask, dim=-1)
- mask = torch.stack([mask], dim=0).to(torch.int32)
- dest_idx = torch.stack([top1_idx * capacity + ranks], dim=0).to(torch.int32)
- return logits, mask, dest_idx, num_experts * capacity
- else:
- ranks = F.one_hot(ranks, num_classes=capacity)
- weight = mask * logits.type_as(inputs)
- combine_weights = weight.unsqueeze(2) * ranks.unsqueeze(1)
- sec_mask = combine_weights.bool()
- return combine_weights, sec_mask
-
-
-class Top2Router(MoeRouter):
- """Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
- for routing usage. More detailed function can be found in the paper about ViT-MoE.
- Args:
- capacity_factor_train (float, optional): Capacity factor in routing of training.
- capacity_factor_eval (float, optional): Capacity factor in routing of evaluation.
- min_capacity (int, optional): The minimum number of the capacity of each expert
- noisy_func (:class:`typing.Callable`, optional): Noisy function used in logits.
- drop_tks (bool, optional): Whether drops tokens in evaluation.
- """
-
- def __init__(
- self,
- capacity_factor_train: float = 1.25,
- capacity_factor_eval: float = 2.0,
- min_capacity: int = 4,
- noisy_func: Callable = None,
- drop_tks: bool = True,
- ):
- super().__init__(
- k_value=2,
- capacity_factor_train=capacity_factor_train,
- capacity_factor_eval=capacity_factor_eval,
- min_capacity=min_capacity,
- noisy_func=noisy_func,
- drop_tks=drop_tks,
- )
-
- def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None):
- # inputs: [s, h]
- if self.noisy_func is not None and self.training:
- inputs = self.noisy_func(inputs)
-
- assert inputs.dtype == torch.float
- logits = F.softmax(inputs, dim=-1) # logits: [s, e]
- num_experts = logits.size(-1)
- capacity = self.get_capacity(logits.shape)
-
- top1_idx = torch.argmax(logits, dim=-1)
- mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
- logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
- top2_idx = torch.argmax(logits_except1, dim=-1)
- mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
-
- cmask = mask1 + mask2 # loss: [s, e]
-
- # caculate the auxiliary loss
- me = torch.mean(logits, dim=0)
- ce = torch.mean(cmask.float(), dim=0)
- l_aux = num_experts * torch.sum(me * ce) / 2.0 # div 2 to normalize it to 1
- self.set_routing_loss(l_aux)
-
- if not self.training and not self.drop_tks:
- max_num = torch.max(torch.sum(cmask, dim=0))
- dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
- capacity = max_num.item()
-
- rank1 = moe_cumsum(mask1) # rank1: [s, e]
- rank2 = moe_cumsum(mask2)
- rank2 += torch.sum(mask1, dim=-2, keepdim=True)
-
- mask1 *= torch.lt(rank1, capacity)
- mask2 *= torch.lt(rank2, capacity)
-
- rank1 = torch.sum(mask1 * rank1, dim=-1)
- rank2 = torch.sum(mask2 * rank2, dim=-1)
-
- if use_kernel:
- mask1 = torch.sum(mask1, dim=-1)
- mask2 = torch.sum(mask2, dim=-1)
-
- mask = torch.stack([mask1, mask2], dim=0).to(torch.int32)
- dest_idx = torch.stack([top1_idx * capacity + rank1, top2_idx * capacity + rank2], dim=0).to(torch.int32)
-
- return logits, mask, dest_idx, num_experts * capacity
- else:
- weight1 = mask1 * logits.type_as(inputs)
- weight2 = mask2 * logits.type_as(inputs)
- rank1_sc = F.one_hot(rank1, num_classes=capacity)
- rank2_sc = F.one_hot(rank2, num_classes=capacity)
-
- cb_weight1 = weight1.unsqueeze(2) * rank1_sc.unsqueeze(1)
- cb_weight2 = weight2.unsqueeze(2) * rank2_sc.unsqueeze(1)
- cb_weight = cb_weight1 + cb_weight2
- sec_mask = cb_weight.bool()
-
- return cb_weight, sec_mask
diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py
deleted file mode 100644
index 4f31dd5579dc..000000000000
--- a/colossalai/nn/layer/moe/utils.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.utils import get_current_device
-
-from .experts import FFNExperts, TPExperts
-
-
-class ForceFP32Parameter(torch.nn.Parameter):
- def half(self, memory_format=None):
- return self.data.clone()
-
-
-class NormalNoiseGenerator:
- """Generates a random noisy mask for logits tensor.
-
- All noise is generated from a normal distribution :math:`(0, 1 / E^2)`, where
- `E = the number of experts`.
-
- Args:
- num_experts (int): The number of experts.
- """
-
- def __init__(self, num_experts: int):
- self.normal = torch.distributions.normal.Normal(
- loc=torch.tensor(0.0, device=get_current_device()),
- scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
- ).rsample
-
- def __call__(self, inputs: torch.Tensor):
- noisy = self.normal(inputs.shape)
- return inputs + noisy
-
-
-class UniformNoiseGenerator:
- """Generates a random noisy mask for logits tensor.
- copied from mesh tensorflow:
- Multiply values by a random number between :math:`1-epsilon` and :math:`1+epsilon`.
- Makes models more resilient to rounding errors introduced by bfloat16.
- This seems particularly important for logits.
-
- Args:
- eps (float, optional): Epsilon in generator, defaults 1e-2.
- """
-
- def __init__(self, eps: float = 1e-2):
- self.uniform = torch.distributions.uniform.Uniform(
- low=torch.tensor(1.0 - eps, device=get_current_device()),
- high=torch.tensor(1.0 + eps, device=get_current_device()),
- ).rsample
-
- def __call__(self, inputs: torch.Tensor):
- noisy = self.uniform(inputs.shape)
- return inputs * noisy
-
-
-def autocast_softmax(logit: torch.Tensor, dim: int):
- if logit.dtype != torch.float32:
- logit = logit.float()
- return F.softmax(logit, dim=dim)
-
-
-def build_ffn_experts(num_experts: int, d_model: int, d_ff: int, activation=None, drop_rate: float = 0):
- mep_size = MOE_CONTEXT.max_ep_size
- if num_experts % mep_size == 0 or mep_size % num_experts == 0:
- return FFNExperts(num_experts, d_model, d_ff, activation, drop_rate)
- elif d_ff % mep_size == 0:
- return TPExperts(num_experts, d_model, d_ff, activation, drop_rate)
- else:
- raise NotImplementedError(f"Can not build {num_experts} experts in {mep_size} GPUS.")
diff --git a/colossalai/nn/loss/__init__.py b/colossalai/nn/loss/__init__.py
index 7c6fb099d272..e69de29bb2d1 100644
--- a/colossalai/nn/loss/__init__.py
+++ b/colossalai/nn/loss/__init__.py
@@ -1 +0,0 @@
-# from .loss_moe import MoeCrossEntropyLoss, MoeLoss
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index c3c0180e8516..7d53a1dd6834 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -1,9 +1,10 @@
import math
+import platform
from typing import Optional
import torch
-from colossalai.kernel.op_builder import CPUAdamBuilder
+from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
from .nvme_optimizer import NVMeOptimizer
@@ -77,7 +78,7 @@ def __init__(
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
- cpu_adam = CPUAdamBuilder().load()
+ cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
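The dispatch pattern introduced in this hunk, isolated as a sketch (the builder names come from the import above; the lazy-import structure is this example's own assumption):

```python
import platform

def load_cpu_adam_kernel():
    # Pick the kernel builder by CPU architecture, as in CPUAdam.__init__ above.
    if platform.machine() == "aarch64":
        from colossalai.kernel.op_builder import ArmCPUAdamBuilder as Builder
    else:
        from colossalai.kernel.op_builder import CPUAdamBuilder as Builder
    return Builder().load()
```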
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index c7a309b872ce..d34fd601ab25 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -84,9 +84,10 @@ def __init__(
nvme_offload_fraction,
nvme_offload_dir,
)
- fused_optim = FusedOptimBuilder().load()
- self.gpu_adam_op = fused_optim.multi_tensor_adam
- self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+ if torch.cuda.is_available():
+ fused_optim = FusedOptimBuilder().load()
+ self.gpu_adam_op = fused_optim.multi_tensor_adam
+ self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@torch.no_grad()
def step(self, closure=None, div_scale: float = -1):
@@ -118,11 +119,11 @@ def step(self, closure=None, div_scale: float = -1):
group_step = state["step"]
beta1, beta2 = group["betas"]
- if target_device.type == "cpu":
- assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
- assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
+ if target_device.type == "cpu" or target_device.type == "npu":
+ assert state["exp_avg"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
+ assert state["exp_avg_sq"].device.type in ("cpu", "npu"), "exp_avg should stay on cpu"
self._pre_update(p, "exp_avg", "exp_avg_sq")
- if p.grad.dtype is torch.bfloat16:
+ if p.grad.dtype is torch.bfloat16 or p.grad.device.type == "npu":
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1 ** state["step"]
bias_correction2 = 1 - beta2 ** state["step"]
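For reference, the bias-corrected update that this software fallback performs looks roughly like the sketch below; it mirrors standard Adam (weight decay omitted), not the exact fused kernel:

```python
import torch

def adam_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
              beta1=0.9, beta2=0.999, eps=1e-8):
    # Standard Adam with explicit bias correction (weight decay omitted).
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)

p = torch.zeros(4, dtype=torch.bfloat16)
m, v = torch.zeros_like(p), torch.zeros_like(p)
adam_step(p, torch.ones_like(p), m, v, step=1)
print(p)  # ~ -lr on the first step
```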
diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py
index f822c1819adc..6e49fa36bb83 100644
--- a/colossalai/pipeline/p2p.py
+++ b/colossalai/pipeline/p2p.py
@@ -5,9 +5,12 @@
import pickle
import re
from typing import Any, List, Optional, Union
+from collections import namedtuple
import torch
import torch.distributed as dist
+from dataclasses import dataclass
+from enum import Enum
from packaging.version import Version
from torch.distributed import ProcessGroup
from torch.distributed import distributed_c10d as c10d
@@ -45,6 +48,21 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
return unpickle
+def check_for_nccl_backend(group):
+ pg = group or c10d._get_default_group()
+ # Gate PG wrapper check on Gloo availability.
+ if c10d._GLOO_AVAILABLE:
+ # It is not expected for PG to be wrapped many times, but support it just
+ # in case
+ while isinstance(pg, c10d._ProcessGroupWrapper):
+ pg = pg.wrapped_pg
+
+ return (
+ c10d.is_nccl_available() and
+ pg.name() == c10d.Backend.NCCL
+ )
+
+
def _broadcast_object_list(
object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
):
@@ -65,7 +83,7 @@ def _broadcast_object_list(
c10d._warn_not_in_group("broadcast_object_list")
return
- is_nccl_backend = c10d._check_for_nccl_backend(group)
+ is_nccl_backend = check_for_nccl_backend(group)
current_device = None
if device is not None:
@@ -113,7 +131,7 @@ def _broadcast_object_list(
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
- obj_view = object_tensor[offset : offset + obj_size]
+ obj_view = object_tensor[offset: offset + obj_size]
obj_view = obj_view.type(torch.uint8)
if obj_view.device != torch.device("cpu"):
obj_view = obj_view.cpu()
@@ -131,6 +149,258 @@ def _broadcast_object_list(
object_list[i] = unpickle_object
+def check_device(group):
+ is_nccl_backend = check_for_nccl_backend(group)
+    current_device = torch.device("cpu")
+ if is_nccl_backend:
+ current_device = torch.device("cuda", torch.cuda.current_device())
+ return current_device, is_nccl_backend
+
+
+TensorMetadata = namedtuple('TensorMetadata', ['key', 'shape', 'dtype', 'requires_grad'])
+
+
+class P2PDataType(Enum):
+ serialization = 0
+ tensor = 1
+ list = 2
+ dict = 3
+
+
+@dataclass
+class P2PMetadata:
+ data_type: P2PDataType
+ content: Union[List[TensorMetadata], TensorMetadata, Any]
+
+
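For illustration (not part of the patch), this is how `_communicate` below describes a dict of tensors before sending; the metadata travels first so the receiver can allocate buffers of the right shape and dtype:

```python
import torch
from colossalai.pipeline.p2p import P2PDataType, P2PMetadata, TensorMetadata

payload = {"hidden_states": torch.randn(2, 16), "attention_mask": torch.ones(2, 16)}
content = [TensorMetadata(k, v.shape, v.dtype, v.requires_grad) for k, v in payload.items()]
meta = P2PMetadata(P2PDataType.dict, content)
# The receiver rebuilds empty buffers from `content` (see create_recv_buffer below)
# and fills them with batched irecv ops.
```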
+def filling_ops_queue(obj, comm_op, comm_rank, ops_queue, group):
+ if isinstance(obj, torch.Tensor):
+ obj = obj.contiguous()
+ op_to_add = dist.P2POp(comm_op, obj, comm_rank, group)
+ ops_queue.append(op_to_add)
+ else:
+ for tensor_to_comm in obj:
+ tensor_to_comm = tensor_to_comm.contiguous()
+ op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank, group)
+ ops_queue.append(op_to_add)
+
+
+def create_recv_buffer(p2p_metadata: P2PMetadata, current_device):
+ if p2p_metadata.data_type == P2PDataType.tensor:
+ metadata = p2p_metadata.content
+ tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+ return tensor_recv
+ elif p2p_metadata.data_type in (P2PDataType.list, P2PDataType.dict):
+ buffer_recv = []
+ for metadata in p2p_metadata.content:
+ tensor_recv = torch.empty(metadata.shape, requires_grad=metadata.requires_grad, device=current_device, dtype=metadata.dtype)
+ buffer_recv.append(tensor_recv)
+ return buffer_recv
+ else:
+ raise ValueError(f"Unknown data_type: {p2p_metadata.data_type}")
+
+
+def _batch_send_recv_tensor(send_tensor_list, recv_tensor_metadata, send_dst, recv_src, send_group, recv_group, current_device):
+ buffer_recv = None
+ if recv_tensor_metadata is not None:
+ buffer_recv = create_recv_buffer(recv_tensor_metadata, current_device)
+
+ ops = []
+
+ if send_dst is not None:
+ filling_ops_queue(send_tensor_list, dist.isend, send_dst, ops, send_group)
+
+ if recv_src is not None:
+ assert buffer_recv is not None
+ filling_ops_queue(buffer_recv, dist.irecv, recv_src, ops, recv_group)
+
+ if len(ops) > 0:
+ reqs = dist.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+
+    # PyTorch's documentation suggests this synchronization is unnecessary after
+    # batch_isend_irecv, but Megatron-LM does synchronize at this point, so we
+    # keep it to be safe:
+    # https://github.com/microsoft/Megatron-DeepSpeed/blob/ef13d099c2a1609225a4ce4c1a1753cc76dd90a1/megatron/p2p_communication.py#L111-L112
+    torch.cuda.synchronize()
+
+ return buffer_recv
+
+
+def _send_recv_serialization_object(
+    object: Any,
+    send_dst: Optional[int],
+    recv_src: Optional[int],
+    send_group: Optional[ProcessGroup],
+    recv_group: Optional[ProcessGroup],
+    current_device,
+    is_nccl_backend,
+):
+ ops = []
+ send_object_tensor = None
+ if object is not None and send_dst is not None:
+ if Version(torch.__version__) >= Version("1.13.0"):
+ send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object, device=current_device)
+ else:
+ send_object_tensor, send_object_size_tensor = c10d._object_to_tensor(object)
+
+ if is_nccl_backend:
+ send_object_size_tensor = send_object_size_tensor.to(current_device)
+ send_object_tensor = send_object_tensor.to(current_device)
+
+ filling_ops_queue(send_object_size_tensor, dist.isend, send_dst, ops, send_group)
+
+ recv_object_size_tensor = None
+ if recv_src is not None:
+ recv_object_size_tensor = torch.empty(1, dtype=torch.long)
+ if is_nccl_backend:
+ recv_object_size_tensor = recv_object_size_tensor.to(current_device)
+ filling_ops_queue(recv_object_size_tensor, dist.irecv, recv_src, ops, recv_group)
+
+ if len(ops) > 0:
+ reqs = dist.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+
+        # Synchronize after the batched p2p ops; see the comment in
+        # `_batch_send_recv_tensor`.
+        torch.cuda.synchronize()
+
+ ops = []
+
+ if send_dst is not None and send_object_tensor is not None:
+ filling_ops_queue(send_object_tensor, dist.isend, send_dst, ops, send_group)
+
+ recv_object_tensor = None
+ if recv_src is not None and recv_object_size_tensor is not None:
+ recv_object_tensor = torch.empty(recv_object_size_tensor.item(), dtype=torch.uint8)
+ if is_nccl_backend:
+ recv_object_tensor = recv_object_tensor.to(current_device)
+ filling_ops_queue(recv_object_tensor, dist.irecv, recv_src, ops, recv_group)
+
+ if len(ops) > 0:
+ reqs = dist.batch_isend_irecv(ops)
+ for req in reqs:
+ req.wait()
+
+        # Synchronize after the batched p2p ops; see the comment in
+        # `_batch_send_recv_tensor`.
+        torch.cuda.synchronize()
+
+ if recv_object_tensor is not None and recv_object_size_tensor is not None:
+ recv_object_tensor = recv_object_tensor.type(torch.uint8)
+ if recv_object_tensor.device != torch.device("cpu"):
+ recv_object_tensor = recv_object_tensor.cpu()
+
+ unpickle_object = _cuda_safe_tensor_to_object(
+ recv_object_tensor, recv_object_size_tensor.item())
+
+ if (
+ isinstance(unpickle_object, torch.Tensor)
+ and unpickle_object.device.index != torch.cuda.current_device()
+ ):
+ unpickle_object = unpickle_object.cuda()
+
+ return unpickle_object
+
+
+def _check_if_fast_send_available(object):
+ if type(object) is torch.Tensor:
+ return True
+    elif type(object) is list:
+        return all(type(v) is torch.Tensor for v in object)
+    elif type(object) is dict:
+        return all(type(k) is str and type(v) is torch.Tensor for k, v in object.items())
+ return False
+
+
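The exact-type checks deliberately exclude subclasses such as `nn.Parameter`; only the shapes below take the fast metadata-plus-raw-tensor path (an illustrative check, not part of the patch):

```python
import torch

assert _check_if_fast_send_available(torch.zeros(2))                    # bare tensor
assert _check_if_fast_send_available([torch.zeros(2), torch.zeros(3)])  # list of tensors
assert _check_if_fast_send_available({"a": torch.zeros(2)})             # str -> tensor dict
assert not _check_if_fast_send_available({"a": [torch.zeros(2)]})       # nested -> pickled path
assert not _check_if_fast_send_available(torch.nn.Parameter(torch.zeros(2)))  # subclass excluded
```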
+def _communicate(
+ object,
+ send_dst: Optional[int],
+ recv_src: Optional[int],
+ send_group: Optional[ProcessGroup] = None,
+ recv_group: Optional[ProcessGroup] = None,
+) -> Any:
+ if c10d._rank_not_in_group(send_group) or c10d._rank_not_in_group(recv_group):
+ c10d._warn_not_in_group("_communicate")
+ return
+
+ current_send_device, is_send_nccl_backend = check_device(send_group)
+ current_recv_device, is_recv_nccl_backend = check_device(recv_group)
+
+ is_nccl_backend = is_send_nccl_backend and is_recv_nccl_backend
+
+ assert current_send_device == current_recv_device
+ current_device = current_send_device
+
+ assert (send_dst is not None) or (recv_src is not None)
+
+ can_fast_send = False
+ send_metadata = None
+ if send_dst is not None:
+ can_fast_send = _check_if_fast_send_available(object) and is_nccl_backend
+ if not can_fast_send:
+ send_metadata = P2PMetadata(P2PDataType.serialization, object)
+ else:
+ if type(object) is torch.Tensor:
+ data_type = P2PDataType.tensor
+ content = TensorMetadata(None, object.shape, object.dtype, object.requires_grad)
+ elif type(object) is list:
+ data_type = P2PDataType.list
+ content = []
+ for v in object:
+ content.append(TensorMetadata(None, v.shape, v.dtype, v.requires_grad))
+ elif type(object) is dict:
+ data_type = P2PDataType.dict
+ content = []
+ for k, v in object.items():
+ content.append(TensorMetadata(k, v.shape, v.dtype, v.requires_grad))
+ else:
+ raise ValueError('Cannot send object of type {}'.format(type(object)))
+ send_metadata = P2PMetadata(data_type, content)
+
+ recv_metadata = _send_recv_serialization_object(send_metadata, send_dst, recv_src, send_group, recv_group, current_device, is_nccl_backend)
+ if recv_metadata is not None:
+ assert type(recv_metadata) is P2PMetadata
+ if recv_metadata.data_type == P2PDataType.serialization:
+ return recv_metadata.content
+ if not can_fast_send and send_dst is not None:
+ return
+
+ send_tensor_list = None
+ if type(object) is torch.Tensor:
+ send_tensor_list = object
+ elif type(object) is list:
+ send_tensor_list = object
+ elif type(object) is dict:
+ send_tensor_list = list(object.values())
+
+ recv_buffer = _batch_send_recv_tensor(send_tensor_list, recv_metadata, send_dst, recv_src, send_group, recv_group, current_device)
+
+ if recv_metadata is not None:
+ assert recv_buffer is not None
+ if recv_metadata.data_type in [P2PDataType.tensor, P2PDataType.list]:
+ return recv_buffer
+ elif recv_metadata.data_type == P2PDataType.dict:
+            return {m.key: v for m, v in zip(recv_metadata.content, recv_buffer)}
+ else:
+ raise ValueError('Unknown data type {}'.format(recv_metadata.data_type))
+
+
def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
"""send anything to dst rank
@@ -141,8 +411,7 @@ def _send_object(object: Any, src: int, dst: int, group: ProcessGroup) -> None:
Returns:
None
"""
- # then broadcast safely
- _broadcast_object_list([object], src, group)
+ _communicate(object, send_dst=dst, recv_src=None, send_group=group)
def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
@@ -154,10 +423,7 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
Returns:
Any: Object received from src.
"""
- object_list = [None]
- _broadcast_object_list(object_list, src, group)
-
- return object_list[0]
+ return _communicate(None, send_dst=None, recv_src=src, recv_group=group)
def _p2p_comm(
@@ -302,6 +568,64 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
+ def send_forward_recv_backward(self, input_object: Any, next_rank: int = None) -> Any:
+ """Sends the gradient tensor to and copy the gradient tensor from the next stage in pipeline
+
+ Args:
+ input_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the sender and recipient of the tensor
+ """
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+
+ cur_rank = self.stage_manager.get_rank()
+ group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+ return _communicate(
+ input_object, next_rank, next_rank,
+ send_group=group, recv_group=group,
+ )
+
+ def send_backward_recv_forward(self, input_object: Any, prev_rank: int = None) -> Any:
+ """Sends the gradient tensor to and copy the gradient tensor from the previous stage in pipeline
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the sender and recipient of the tensor
+ """
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+
+ cur_rank = self.stage_manager.get_rank()
+ group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+ return _communicate(
+ input_object, prev_rank, prev_rank,
+ send_group=group, recv_group=group,
+ )
+
+ def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+ """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the sender of the tensor
+ next_rank (int, optional): The rank of the recipient of the tensor
+ """
+ if prev_rank is None:
+ prev_rank = self.stage_manager.get_prev_rank()
+ if next_rank is None:
+ next_rank = self.stage_manager.get_next_rank()
+
+ cur_rank = self.stage_manager.get_rank()
+ recv_group = self.stage_manager.get_p2p_process_group(prev_rank, cur_rank)
+ send_group = self.stage_manager.get_p2p_process_group(cur_rank, next_rank)
+ return _communicate(
+ input_object,
+ send_dst=next_rank,
+ recv_src=prev_rank,
+ send_group=send_group,
+ recv_group=recv_group,
+ )
+
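The three fused methods above let a stage overlap its send and its receive in a single `_communicate` call instead of two blocking rounds. The sketch below shows only the stage-position dispatch a schedule might build on top of them (mirroring `send_forward_recv_forward` in `one_f_one_b.py` further down); the comm object is a stub that records calls, so the snippet runs without `torch.distributed`.

```python
# Stage-position dispatch over the fused p2p calls. _StubComm stands in for
# a PipelineP2PCommunication-like object so this runs standalone.
class _StubComm:
    def __init__(self):
        self.calls = []
    def __getattr__(self, name):
        def _record(*args, **kwargs):
            self.calls.append(name)
        return _record

def exchange(comm, stage: int, num_stages: int, obj) -> None:
    if stage == 0:                       # first stage: nothing to receive
        comm.send_forward(obj)
    elif stage == num_stages - 1:        # last stage: nothing to send on
        comm.recv_forward()
    else:                                # interior stages fuse both directions
        comm.send_forward_recv_forward(obj)

comm = _StubComm()
for stage in range(4):
    exchange(comm, stage, 4, object())
# ['send_forward', 'send_forward_recv_forward', 'send_forward_recv_forward', 'recv_forward']
print(comm.calls)
```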
def p2p_communicate(
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
) -> None:
diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py
index 1f4bbe9f8dad..72480526bd5c 100644
--- a/colossalai/pipeline/schedule/generate.py
+++ b/colossalai/pipeline/schedule/generate.py
@@ -7,10 +7,10 @@
from torch.nn import Module
from torch.utils._pytree import tree_map
-from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
+from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
@@ -69,6 +69,8 @@ def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None)
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
+ if self.stage_manager.num_stages == 1:
+ self.microbatch_size = self.batch_size
self.microbatch_offset = 0
assert (
self.batch_size % self.microbatch_size == 0
@@ -93,9 +95,7 @@ def _prepare_inputs_for_interval_stage(self):
Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
"""
- model_inputs = (
- {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
- )
+ model_inputs = {"infer_state": self.mb_manager.cur_descrption.infer_state}
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
@@ -108,9 +108,8 @@ def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
"""
new_mask = self.mb_manager.cur_descrption.attn_mask
- past_key_values = self.mb_manager.cur_descrption.kv_cache
- return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
+ return dict(input_ids=new_token, attention_mask=new_mask)
def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
last_hidden_state = hidden_state[:, -1]
@@ -128,27 +127,38 @@ def _recv_pre_stage(self) -> Any:
return self.comm.p2p_recv()
return self.comm.recv_forward()
+ def _init_infer_state_action(self) -> None:
+ """
+ This action is only for non-first stages: load the batch and init the infer_state.
+ 1. Load the micro_batch 2. Use the current micro_batch to init the current infer_state
+ """
+ inputs_dict = self.load_micro_batch()
+ self.mb_manager.add_descrption(inputs_dict)
+
def _load_stage_action(self, model: Module) -> None:
"""
- In this action, 1.load micro_batch 2.do the forward 3.step to update
+ This action is only for the first stage: load, init and run the forward pass.
+ 1. Load the micro_batch 2. Do the forward pass 3. Step to update
"""
inputs_dict = self.load_micro_batch()
+ self.mb_manager.add_descrption(inputs_dict)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
- output_dict = model_forward(model, inputs_dict, None)
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
+ output_dict = model_forward(model, inputs_dict, interval_inputs)
- self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module):
"""
- In this action, 1.do the forward with hidden_states to generate new tokens 2.step to update
+ This action is only for the first stage.
+ 1. Do the forward pass with hidden_states to generate new tokens 2. Step to update
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
- hidden_states = {"hidden_states": hidden_states}
- logits = model_forward(model, None, hidden_states)
+ interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
+ logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
@@ -157,7 +167,7 @@ def _gen_token_action(self, model: Module):
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
- self.mb_manager.step(None, None, new_token)
+ self.mb_manager.step(new_token)
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None
@@ -168,19 +178,17 @@ def _head_encoding_action(self, model: Module):
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When first stage in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
- output_dict = model_forward(model, inputs_dict, None)
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
+ output_dict = model_forward(model, inputs_dict, interval_inputs)
- self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
- inputs_dict = self._prepare_inputs_for_interval_stage()
- hidden_states = {"hidden_states": hidden_states}
- output_dict = model_forward(model, inputs_dict, hidden_states)
+ interval_inputs = {"hidden_states": hidden_states, "infer_state": self.mb_manager.cur_infer_state}
+ output_dict = model_forward(model, None, interval_inputs)
- self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
@@ -218,17 +226,80 @@ def _gen_action(self, model: Module):
actions.append(partial(self._gen_token_action, model))
# other stage
else:
+ if self.mb_manager.cur_state is Status.PREFILL:
+ actions.append(partial(self._init_infer_state_action))
actions.append(partial(self._comm_action, True))
actions.append(partial(self._body_encoding_action, model))
return actions
+ def _gen_one_stage_action(self, model: Module):
+ """
+ Generate the sequence of actions for the current state; the caller then runs them one by one.
+
+ Args:
+ model (Module): Model to be run.
+
+ Returns:
+ List[Callable]: A list of actions; each action is a callable function and will be called in order.
+ """
+ actions = []
+
+ if self.mb_manager.cur_state is Status.PREFILL:
+ actions.append(partial(self._load_stage_action, model))
+ elif self.mb_manager.cur_state is Status.GENERATE:
+ actions.append(partial(self._gen_token_action, model))
+ actions.append(partial(self._head_encoding_action, model))
+ elif self.mb_manager.cur_state is Status.COOLDOWN:
+ actions.append(partial(self._gen_token_action, model))
+
+ return actions
+
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
- if self.stage_manager.num_stages == 2:
+ if self.stage_manager.num_stages == 1:
+ return self.generate_step_one_stage(model, data_iter)
+ elif self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
+ @torch.no_grad()
+ def generate_step_one_stage(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
+ """
+ Forward one step of the pipeline, when pipeline size is 1.
+
+ Args:
+ model (Module): Model to be run.
+ data_iter (Iterable): Data iterator.
+
+ Returns:
+ Union[torch.Tensor, dict]: The generated new tokens for each sequence, together with per-round timestamps when verbose is enabled.
+ """
+ output_sequence = []
+ self.load_batch(data_iter)
+ model.eval()
+ self.comm_dtype = model.dtype
+
+ whole_timestamp = []
+
+ # run by round
+ for _ in range(self.round):
+ self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] if self.verbose else None
+ self.action_interval_buffer.clear()
+ while self.mb_manager.is_micro_batch_done() is False:
+ actions = self._gen_one_stage_action(model)
+ for action in actions:
+ action()
+ self.mb_manager.next()
+ # All microbatches in the current round are DONE
+ output_sequence.extend(self.mb_manager.export_new_tokens())
+
+ self.mb_manager.clear()
+ if self.verbose:
+ whole_timestamp.extend(self.timestamps)
+
+ return output_sequence, whole_timestamp
+
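As a rough mental model, `generate_step_one_stage` drives a per-microbatch state machine: PREFILL consumes the prompt, GENERATE appends tokens, and the manager flips to done once the budget is exhausted. The toy below mimics only that control flow; `MicroBatchManager`'s real bookkeeping (infer_state, KV caches, the COOLDOWN phase) is deliberately elided.

```python
# Toy state machine mirroring the PREFILL -> GENERATE -> DONE loop of
# generate_step_one_stage. ToyManager is a stand-in for MicroBatchManager.
from enum import Enum, auto

class Status(Enum):
    PREFILL = auto()
    GENERATE = auto()
    DONE = auto()

class ToyManager:
    def __init__(self, max_new_tokens: int):
        self.state = Status.PREFILL
        self.tokens = []
        self.max_new_tokens = max_new_tokens

    def step(self, token: int) -> None:
        self.tokens.append(token)
        # Move to GENERATE after the first token, DONE once the budget is spent.
        self.state = Status.DONE if len(self.tokens) >= self.max_new_tokens else Status.GENERATE

    def is_done(self) -> bool:
        return self.state is Status.DONE

mb = ToyManager(max_new_tokens=3)
while not mb.is_done():
    if mb.state is Status.PREFILL:
        token = 101                  # pretend forward pass over the prompt
    else:
        token = mb.tokens[-1] + 1    # pretend single decode step
    mb.step(token)
print(mb.tokens)  # [101, 102, 103]
```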
@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
@@ -308,8 +379,9 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
- output_dict = model_forward(model, inputs_dict, None)
- self.mb_manager.step(inputs_dict, output_dict, None)
+ self.mb_manager.add_descrption(inputs_dict)
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
+ output_dict = model_forward(model, inputs_dict, interval_inputs)
# In GENERATE phase
else:
# Get hidden_states from previous stage
@@ -319,7 +391,11 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
- logits = model_forward(model, None, hidden_states)
+ interval_inputs = {
+ "hidden_states": hidden_states["hidden_states"],
+ "infer_state": self.mb_manager.cur_infer_state,
+ }
+ logits = model_forward(model, None, interval_inputs)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
@@ -327,17 +403,24 @@ def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[t
"logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
- self.mb_manager.step(None, None, new_token)
+ self.mb_manager.step(new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
- output_dict = model_forward(model, inputs_dict, None)
- self.mb_manager.step(inputs_dict, output_dict, None)
+ interval_inputs = {"infer_state": self.mb_manager.cur_infer_state}
+ output_dict = model_forward(model, inputs_dict, interval_inputs)
else:
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
- inputs_dict = self._prepare_inputs_for_interval_stage()
- output_dict = model_forward(model, inputs_dict, hidden_states)
- self.mb_manager.step(inputs_dict, output_dict, None)
+ # inputs_dict = self._prepare_inputs_for_interval_stage()
+ inputs_dict = None
+ if self.mb_manager.cur_state is Status.PREFILL:
+ inputs_dict = self.load_micro_batch()
+ self.mb_manager.add_descrption(inputs_dict)
+ interval_inputs = {
+ "hidden_states": hidden_states["hidden_states"],
+ "infer_state": self.mb_manager.cur_infer_state,
+ }
+ output_dict = model_forward(model, inputs_dict, interval_inputs)
# Current microbatch is not DONE, send hidden_state to next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
diff --git a/colossalai/pipeline/schedule/interleaved_pp.py b/colossalai/pipeline/schedule/interleaved_pp.py
index 780437155c61..cbf6dd80f3e0 100644
--- a/colossalai/pipeline/schedule/interleaved_pp.py
+++ b/colossalai/pipeline/schedule/interleaved_pp.py
@@ -9,7 +9,7 @@
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py
index 4eaf135fd5db..fd918cf1921c 100644
--- a/colossalai/pipeline/schedule/one_f_one_b.py
+++ b/colossalai/pipeline/schedule/one_f_one_b.py
@@ -9,7 +9,7 @@
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from ._utils import (
detach,
@@ -127,6 +127,17 @@ def send_forward(self, output_object: Any, next_rank: int = None) -> None:
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
+ def send_forward_recv_backward(self, output_object: Any, next_rank: int = None) -> Any:
+ """Sends the input tensor to the next stage and copy the gradient tensor from the next stage in pipeline.
+ For 1F1B.
+
+ Args:
+ output_object (Any): Object to be sent.
+ next_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.stage_manager.is_last_stage():
+ return self.comm.send_forward_recv_backward(output_object, next_rank)
+
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
@@ -138,6 +149,33 @@ def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
+ def send_backward_recv_forward(self, output_object: Any, prev_rank: int = None) -> Any:
+ """Sends the gradient tensor to the previous stage and copy the input tensor from the previous stage in pipeline.
+ For 1F1B.
+
+ Args:
+ output_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the recipient of the tensor.
+ """
+ if not self.stage_manager.is_first_stage():
+ return self.comm.send_backward_recv_forward(output_object, prev_rank)
+
+ def send_forward_recv_forward(self, input_object: Any, prev_rank: int = None, next_rank: int = None) -> Any:
+ """Sends the input tensor to the next stage and copy the input tensor from the previous stage in pipeline.
+ For 1F1B.
+
+ Args:
+ input_object (Any): Object to be sent.
+ prev_rank (int, optional): The rank of the previous stage, from which the tensor is received.
+ next_rank (int, optional): The rank of the next stage, to which the tensor is sent.
+ """
+ if self.stage_manager.is_first_stage():
+ return self.comm.send_forward(input_object, next_rank)
+ elif self.stage_manager.is_last_stage():
+ return self.comm.recv_forward(prev_rank)
+ else:
+ return self.comm.send_forward_recv_forward(input_object, prev_rank, next_rank)
+
def forward_step(
self,
model: Module,
@@ -291,7 +329,6 @@ def forward_backward_step(
if not last_iteration:
input_obj = self.recv_forward()
-
else:
# TODO adjust here
self.send_forward(output_obj)
diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md
index 63b28701e879..cf06eecd3c0c 100644
--- a/colossalai/shardformer/README.md
+++ b/colossalai/shardformer/README.md
@@ -81,7 +81,7 @@ Following are the description `ShardConfig`'s arguments:
- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `inference_only`: Whether only doing forward passing. Defaults to False.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
### Write your own policy
@@ -185,8 +185,8 @@ class ShardConfig:
# Some possible future config fields
tensor_parallel_mode: Choice['1d', '2d', '2.5d', '3d'] # support different tensor parallel mode
- inference_only: bool # only inject inference-suitable sharding policy
use_flash_attention: bool # whether to use flash attention to speed up attention
+ extra_kwargs: Dict[str, Any] # extra kwargs for the shardformer
```
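A hedged sketch of what the `extra_kwargs` escape hatch looks like in practice. The dataclass below is illustrative only; apart from `extra_kwargs` itself, the field names are assumptions rather than the exact `ShardConfig` definition.

```python
# Illustrative only: a ShardConfig-like dataclass showing how extra_kwargs
# replaces dedicated flags such as the removed inference_only. Fields other
# than extra_kwargs are assumptions, not the real colossalai definition.
from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class MyShardConfig:
    enable_fused_normalization: bool = False
    enable_flash_attention: bool = False
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)

cfg = MyShardConfig(
    enable_flash_attention=True,
    extra_kwargs={"inference_only": True, "quant": "int8"},
)
print(cfg.extra_kwargs.get("inference_only", False))  # True
```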
### Policy
@@ -235,6 +235,14 @@ class SubModuleReplacementDescription:
class Policy(ABC):
+ r"""
+ The base class for all the policies. For each different model, it should have a different policy class,
+ like BertPolicy for Bert Model or OPTPolicy for OPT model.
+
+ Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
+ built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
+ If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
+ """
def __init__(self):
self.model = None
@@ -245,6 +253,16 @@ class Policy(ABC):
"""
self.model = model
+ def set_shard_config(self, shard_config: ShardConfig) -> None:
+ r"""
+ Set shard config as an attribute of the Policy object.
+ Args:
+ shard_config (:class:`ShardConfig`): The shard config to be used
+ """
+ self.shard_config = shard_config
+
+ self.config_sanity_check()
+
@abstractmethod
def preprocess(self) -> nn.Module:
"""
diff --git a/colossalai/shardformer/layer/__init__.py b/colossalai/shardformer/layer/__init__.py
index a134a2cbd21c..56e8b08c4e4a 100644
--- a/colossalai/shardformer/layer/__init__.py
+++ b/colossalai/shardformer/layer/__init__.py
@@ -2,7 +2,7 @@
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm
+from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
@@ -16,6 +16,9 @@
"DropoutForParallelInput",
"DropoutForReplicatedInput",
"cross_entropy_1d",
+ "BaseLayerNorm",
+ "LayerNorm",
+ "RMSNorm",
"FusedLayerNorm",
"FusedRMSNorm",
"FusedLinear1D_Col",
diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 5ec48096183b..0d8c3d453ce1 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
- ctx.save_for_backward(input_, weight)
+ ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -62,13 +62,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
if bias is not None:
output = output + bias
+
return output
@staticmethod
def backward(ctx, grad_output):
- input, weight = ctx.saved_tensors
+ input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
+ # Apply a no-op view to weight and bias so they are hooked into Gemini's '__torch_function__'.
+ weight = weight.view(weight.shape)
+ if use_bias:
+ bias = bias.view(bias.shape)
+
total_input = input
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
@@ -100,7 +105,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
- ctx.save_for_backward(input_, weight)
+ ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_allreduce = async_grad_allreduce
@@ -109,13 +114,18 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
output = F.linear(input_, weight, bias)
else:
output = F.linear(input_, weight)
+
return output
@staticmethod
def backward(ctx, grad_output):
- input, weight = ctx.saved_tensors
+ input, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
+ # Apply a no-op view to bias so it is hooked into Gemini's '__torch_function__'.
+ if use_bias:
+ bias = bias.view(bias.shape)
+
total_input = input
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
@@ -152,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
- ctx.save_for_backward(input_, weight)
+ ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -170,12 +180,16 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,
@staticmethod
def backward(ctx, grad_output):
- input_, weight = ctx.saved_tensors
+ input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
+ # Apply a no-op view to bias so it is hooked into Gemini's '__torch_function__'. Used in FusedLayerNorm.
+ if use_bias:
+ bias = bias.view(bias.shape)
+
if not overlap:
input_parallel = _gather(input_, dim, process_group)
@@ -289,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
- ctx.save_for_backward(input_, weight)
+ ctx.save_for_backward(input_, weight, bias)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -306,12 +320,17 @@ def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter,
@staticmethod
def backward(ctx, grad_output):
- input_, weight = ctx.saved_tensors
+ input_, weight, bias = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
+ # Apply a no-op view to weight and bias so they are hooked into Gemini's '__torch_function__'. Used in FusedLayerNorm.
+ weight = weight.view(weight.shape)
+ if use_bias:
+ bias = bias.view(bias.shape)
+
if not overlap:
input_parallel = _gather(input_, dim, process_group)
@@ -454,6 +473,29 @@ def forward(ctx, input_, dim, process_group):
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
+
+
+class HookParameter(torch.autograd.Function):
+ """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ ctx.save_for_backward(weight, bias)
+ output = input
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ weight, bias = ctx.saved_tensors
+ if weight is not None:
+ weight = weight.view(weight.shape)
+ if bias is not None:
+ bias = bias.view(bias.shape)
+ return grad_output, None, None
+
+
+def hook_paramter_in_backward(input, weight=None, bias=None):
+ return HookParameter.apply(input, weight, bias)
+
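Functionally, `HookParameter` is an identity: the forward pass returns its input untouched, and the backward pass merely performs no-op `view` calls on the saved weight and bias so that Gemini's `__torch_function__` sees them. The quick check below verifies that contract (output equals input, gradients still flow) without apex or Gemini installed; the class is re-declared locally so the snippet runs standalone.

```python
# Sanity check of HookParameter's contract: identity forward, pass-through
# gradient, and harmless no-op views on the saved parameters in backward.
import torch

class HookParameter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(weight, bias)  # None entries are allowed
        return input

    @staticmethod
    def backward(ctx, grad_output):
        weight, bias = ctx.saved_tensors
        if weight is not None:
            weight.view(weight.shape)  # no-op view: exists only to trigger hooks
        if bias is not None:
            bias.view(bias.shape)
        return grad_output, None, None

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
y = HookParameter.apply(x, w, None)
assert torch.equal(x, y)          # identity in the forward pass
y.sum().backward()
print(x.grad)                     # ones(3): gradient passes straight through
```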
def _reduce(input_, process_group):
diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py
index 62163cb009aa..d081b204093b 100644
--- a/colossalai/shardformer/layer/embedding.py
+++ b/colossalai/shardformer/layer/embedding.py
@@ -309,7 +309,8 @@ def forward(self, input_: Tensor) -> Tensor:
)
# Mask the output embedding.
- output_parallel[input_mask, :] = 0.0
+ embedding_output = output_parallel.clone()
+ embedding_output[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
- output = reduce_forward(output_parallel, self.process_group)
+ output = reduce_forward(embedding_output, self.process_group)
return output
diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py
index cf2003877d3c..9e638622348e 100644
--- a/colossalai/shardformer/layer/linear.py
+++ b/colossalai/shardformer/layer/linear.py
@@ -149,7 +149,6 @@ def from_native_module(
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
-
# ensure only one process group is passed
if isinstance(process_group, (list, tuple)):
assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 19b973be8679..8387bb5e365e 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -1,11 +1,48 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+import warnings
+from abc import ABC, abstractmethod
import torch.nn as nn
from colossalai.lazy import LazyInitContext
-__all__ = ["FusedLayerNorm", "FusedRMSNorm"]
+from ._operation import hook_paramter_in_backward
+from .utils import SeqParallelUtils
+
+__all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
+
+try:
+ from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+
+ EnableFastLayerNorm = True
+except ImportError:
+ EnableFastLayerNorm = False
+
+try:
+ from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+ from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+
+ class FusedLayerNormWithHook(ApexFusedLayerNorm):
+ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+ super().__init__(normalized_shape, eps, elementwise_affine)
+
+ def forward(self, input):
+ output = super().forward(input)
+ output = hook_paramter_in_backward(output, self.weight, self.bias)
+ return output
+
+ class FusedRMSNormWithHook(ApexFusedRMSNorm):
+ def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+ super().__init__(normalized_shape, eps, elementwise_affine)
+
+ def forward(self, input):
+ output = super().forward(input)
+ output = hook_paramter_in_backward(output, self.weight)
+ return output
+
+except ImportError:
+ warnings.warn("Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel")
FAST_LAYERNORM_SUPPORTED_SIZE = [
1024,
@@ -34,8 +71,115 @@
65536,
]
+if EnableFastLayerNorm:
+
+ class FastLayerNormWithHook(FastLayerNorm):
+ def __init__(self, hidden_size, eps=0.00001):
+ super().__init__(hidden_size, eps)
+
+ def forward(self, input):
+ output = super().forward(input)
+ output = hook_paramter_in_backward(output, self.weight, self.bias)
+ return output
+
+
+class BaseLayerNorm(ABC):
+ @abstractmethod
+ def from_native_module(module: nn.Module, sp_partial_derived: bool = False):
+ """
+ Convert a native PyTorch layer normalization module to a specific layer normalization module,
+ and optionally mark parameters for gradient aggregation.
+
+ Args:
+ module (nn.Module): The native PyTorch layer normalization module to be converted.
+ sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+ Returns:
+ nn.Module: The specific layer normalization module.
+
+ Raises:
+ AssertionError: If the provided module is not an instance of the supported layer normalization type.
+ """
+
+
+class RMSNorm(BaseLayerNorm):
+ r"""
+ This is a wrapper around the RMSNorm. It is meant to be used only with the from_native_module interface.
+ """
+
+ def __init__(self) -> None:
+ raise NotImplementedError(
+ "FusedLayerNorm is not implemented as a physical class. "
+ "It is meant to be used only with the from_native_module interface to convert a native RMSNorm module to colossalai layer norm module."
+ )
+
+ @staticmethod
+ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+ """
+ Convert a native RMSNorm module to colossalai layer norm module,
+ and optionally mark parameters for gradient aggregation.
+
+ Args:
+ module (nn.Module): The native RMSNorm module to be converted.
+ sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+ Returns:
+ nn.Module: The RMSNorm module.
+ """
+
+ LazyInitContext.materialize(module)
+
+ if sp_partial_derived:
+ # Since gradients are computed using only a subset of the data,
+ # aggregation of these gradients is necessary during backpropagation.
+ # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+ SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+
+ return module
+
-class FusedLayerNorm:
+class LayerNorm(BaseLayerNorm):
+ r"""
+ This is a wrapper around the torch.nn.LayerNorm. It is meant to be used only with the from_native_module interface.
+ """
+
+ def __init__(self) -> None:
+ raise NotImplementedError(
+ "LayerNorm is not implemented as a physical class. "
+ "It is meant to be used only with the from_native_module interface to convert a native pytorch layer norm module to colossalai layer norm module."
+ )
+
+ @staticmethod
+ def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+ r"""
+ Convert a native pytorch layer norm module to a colossalai layer norm module,
+ and optionally mark parameters for gradient aggregation.
+
+ Args:
+ module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+ sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+ Returns:
+ nn.Module: The LayerNorm module.
+
+ Raises:
+ AssertionError: If the provided module is not an instance of nn.LayerNorm.
+ """
+ assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+
+ LazyInitContext.materialize(module)
+
+ if sp_partial_derived:
+ # Since gradients are computed using only a subset of the data,
+ # aggregation of these gradients is necessary during backpropagation.
+ # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+ SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
+ SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+
+ return module
+
+
+class FusedLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the apex fused layernorm implementation. It is meant to be used only with the from_native_module interface.
"""
@@ -43,21 +187,25 @@ class FusedLayerNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedLayerNorm is not implemented as a physical class. "
- "It is meant to be used only with the from_native_module interface to wrap the fused layernorm implementation provided by apex."
+ "It is meant to be used only with the from_native_module interface convert a native pytorch layer norm module to FusedLayerNorm module provided by apex."
)
@staticmethod
- def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
+ def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
r"""
- Convert a native pytorch layer norm module to colossalai layer norm module
+ Convert a native pytorch layer norm module to the FusedLayerNorm module provided by apex,
+ and optionally mark parameters for gradient aggregation.
+
+ Args:
+ module (nn.LayerNorm): The native PyTorch LayerNorm module to be converted.
+ sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+ Returns:
+ nn.Module: Union[FastLayerNorm, FusedLayerNorm].
+
+ Raises:
+ AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
- # check if apex is installed
- try:
- pass
- except ImportError:
- raise ImportError(
- "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
- )
LazyInitContext.materialize(module)
# get the attributes of the module
@@ -71,24 +219,31 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
if use_fast_ln:
- try:
- from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
- except ImportError:
+ if EnableFastLayerNorm:
+ ApexFusedLayerNorm = FastLayerNormWithHook
+ else:
# fall back to the normal fused layernorm if the fast one is not built
- from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+ ApexFusedLayerNorm = FusedLayerNormWithHook
else:
- from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+ ApexFusedLayerNorm = FusedLayerNormWithHook
layernorm = (
ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
)
-
layernorm.weight = module.weight
layernorm.bias = module.bias
+
+ if sp_partial_derived:
+ # Since gradients are computed using only a subset of the data,
+ # aggregation of these gradients is necessary during backpropagation.
+ # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+ SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
+ SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+
return layernorm
-class FusedRMSNorm:
+class FusedRMSNorm(BaseLayerNorm):
"""
This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
"""
@@ -96,13 +251,24 @@ class FusedRMSNorm:
def __init__(self) -> None:
raise NotImplementedError(
"FusedRMSNorm is not implemented as a physical class. "
- "It is meant to be used only with the from_native_module interface to wrap the fused rms norm implementation provided by apex."
+ "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
)
@staticmethod
- def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
+ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
+ r"""
+ Convert a native RMSNorm module to the FusedRMSNorm module provided by apex,
+ and optionally mark parameters for gradient aggregation.
+
+ Args:
+ module (nn.Module): The native RMSNorm module to be converted.
+ sp_partial_derived (bool): Whether this module's gradients are partially derived in sequence parallelism.
+
+ Returns:
+ nn.Module: FusedRMSNorm module.
+ """
try:
- from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+ pass
except ImportError:
raise ImportError(
"Please install apex from source (https://github.com/NVIDIA/apex) to use the fused RMS normalization kernel"
@@ -120,8 +286,16 @@ def from_native_module(module: nn.Module, *args, **kwargs) -> nn.Module:
eps = module.eps
elementwise_affine = module.elementwise_affine
- rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+ rmsnorm = FusedRMSNormWithHook(
+ normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+ )
rmsnorm.weight = module.weight
+ if sp_partial_derived:
+ # Since gradients are computed using only a subset of the data,
+ # aggregation of these gradients is necessary during backpropagation.
+ # Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
+ SeqParallelUtils.marked_as_sp_partial_derived_param(rmsnorm.weight)
+
return rmsnorm
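All four wrappers share one calling convention: never instantiate the wrapper, always go through `from_native_module`. The sketch below replays that convention for the plain-LayerNorm path, which needs no apex; the wrapper class is a local stand-in, and the `partial_derived` marker attribute follows `SeqParallelUtils.marked_as_sp_partial_derived_param` in `utils.py` below.

```python
# Sketch of the from_native_module convention on the non-fused LayerNorm
# path. LayerNormWrapper is a stand-in, not the colossalai class itself.
import torch.nn as nn

class LayerNormWrapper:
    def __init__(self):
        raise NotImplementedError("Use from_native_module instead.")

    @staticmethod
    def from_native_module(module: nn.LayerNorm, sp_partial_derived: bool = False) -> nn.Module:
        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
        if sp_partial_derived:
            # flag params whose grads must be all-reduced across the TP group
            setattr(module.weight, "partial_derived", True)
            setattr(module.bias, "partial_derived", True)
        return module

norm = LayerNormWrapper.from_native_module(nn.LayerNorm(16), sp_partial_derived=True)
print(getattr(norm.weight, "partial_derived", False))  # True
```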
diff --git a/colossalai/shardformer/layer/utils.py b/colossalai/shardformer/layer/utils.py
index c3d8501cdeae..4b6343adcd3b 100644
--- a/colossalai/shardformer/layer/utils.py
+++ b/colossalai/shardformer/layer/utils.py
@@ -1,8 +1,83 @@
from contextlib import contextmanager
+from typing import List
import torch
import torch.distributed as dist
-from torch.distributed import ProcessGroup
+from torch import nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+from torch.distributed import ProcessGroup, get_world_size
+from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+
+
+class SeqParallelUtils:
+ @staticmethod
+ def marked_as_sp_partial_derived_param(param):
+ """
+ Mark a parameter as partially derived in sequence parallelism.
+
+ Args:
+ param: The parameter to mark as partially derived.
+ """
+ setattr(param, "partial_derived", True)
+
+ @staticmethod
+ def is_sp_partial_derived_param(param):
+ """
+ Check if a parameter is marked as partially derived in sequence parallelism.
+
+ Args:
+ param: The parameter to check.
+
+ Returns:
+ bool: True if the parameter is marked as partially derived, False otherwise.
+ """
+ return getattr(param, "partial_derived", False)
+
+ @staticmethod
+ def allreduce_partial_data_grad(tp_group: ProcessGroup, model: nn.Module = None, grads: List[torch.Tensor] = None):
+ """
+ Allreduce partial derived gradients across the specified process group.
+
+ This function performs gradient synchronization for parameters that are marked as partially derived in sequence parallelism.
+
+ Args:
+ tp_group (ProcessGroup): The process group for gradient synchronization.
+ model (nn.Module): The model from which gradients will be synchronized.
+ grads (List[torch.Tensor]): The list of gradients to be synchronized.
+
+ Raises:
+ AssertionError: If both `model` and `grads` are provided or neither is provided.
+ """
+ # Ensure that exactly one of `model` and `grads` is provided for gradient synchronization.
+ assert (model is not None) ^ (grads is not None), "Exactly one of model and grads must be provided."
+
+ # Get the size of the process group, which determines whether synchronization is needed.
+ tp_size = get_world_size(tp_group) if tp_group is not None else 1
+
+ if tp_size == 1:
+ # If the process group size is 1, no synchronization is required.
+ return
+
+ if model is not None:
+ # If `model` is provided, extract partial derived gradients from the model's parameters.
+ grads = []
+ for p in model.parameters():
+ if p.grad is not None and SeqParallelUtils.is_sp_partial_derived_param(p):
+ grads.append(p.grad.data)
+
+ # Flatten and reduce the gradients using the specified process group.
+ coalesced = _flatten_dense_tensors(grads)
+ dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+
+ # Unflatten the synchronized gradients and update the model's gradients.
+ for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+ buf.copy_(synced)
+ else:
+ # If `grads` are provided explicitly, synchronize those gradients directly.
+ coalesced = _flatten_dense_tensors(grads)
+ dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)
+ for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+ buf.copy_(synced)
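End to end, the mark-then-allreduce flow looks like the single-process sketch below. With `tp_group=None` the early return above would fire, so the flatten/all-reduce/unflatten round-trip is written out directly here to make the data flow visible; the commented `dist.all_reduce` line marks where the real collective would run when `tp_size > 1`.

```python
# Single-process sketch of mark -> collect -> coalesce -> scatter-back, the
# same round-trip allreduce_partial_data_grad performs across a real TP group.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def mark(p):
    p.partial_derived = True

def is_marked(p) -> bool:
    return getattr(p, "partial_derived", False)

model = torch.nn.LayerNorm(4)
mark(model.weight)                         # only the weight is partial-derived here
model(torch.randn(2, 4)).sum().backward()

grads = [p.grad for p in model.parameters() if p.grad is not None and is_marked(p)]
coalesced = _flatten_dense_tensors(grads)  # one flat buffer -> one collective
# dist.all_reduce(coalesced, op=dist.ReduceOp.SUM, group=tp_group)  # with tp_size > 1
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)                      # write the reduced values back in place
print(len(grads), coalesced.shape)         # 1 torch.Size([4])
```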
class Randomizer:
@@ -30,14 +105,14 @@ class Randomizer:
def __init__(self, seed: int):
self.seed = seed
- # Handle CUDA rng state
+ # Handle device rng state
# 1. get the current rng state
# 2. set the seed and store the rng state
# 3. recover the original rng state
- cuda_original_rng_state = torch.cuda.get_rng_state()
- torch.cuda.manual_seed(seed)
- self.cuda_rng_state = torch.cuda.get_rng_state()
- torch.cuda.set_rng_state(cuda_original_rng_state)
+ device_original_rng_state = get_rng_state()
+ manual_seed(seed)
+ self.device_rng_state = get_rng_state()
+ set_rng_state(device_original_rng_state)
# to the same for cpu rng state
cpu_original_rng_state = torch.get_rng_state()
@@ -45,11 +120,11 @@ def __init__(self, seed: int):
self.cpu_rng_state = torch.get_rng_state()
torch.set_rng_state(cpu_original_rng_state)
- def _set_cuda_rng_state(self, rng_state):
- torch.cuda.set_rng_state(rng_state)
+ def _set_device_rng_state(self, rng_state):
+ set_rng_state(rng_state)
- def _get_cuda_rng_state(self):
- current_state = torch.cuda.get_rng_state()
+ def _get_device_rng_state(self):
+ current_state = get_rng_state()
return current_state
def _set_cpu_rng_state(self, rng_state):
@@ -70,16 +145,16 @@ def fork_rng(self, enable_cpu: bool = False):
>>> input = super().forward(input)
"""
try:
- current_cuda_rng_state = self._get_cuda_rng_state()
- self._set_cuda_rng_state(self.cuda_rng_state)
+ current_device_rng_state = self._get_device_rng_state()
+ self._set_device_rng_state(self.device_rng_state)
if enable_cpu:
current_cpu_rng_state = self._get_cpu_rng_state()
self._set_cpu_rng_state(self.cpu_rng_state)
yield
finally:
- self.cuda_rng_state = self._get_cuda_rng_state()
- self._set_cuda_rng_state(current_cuda_rng_state)
+ self.device_rng_state = self._get_device_rng_state()
+ self._set_device_rng_state(current_device_rng_state)
if enable_cpu:
self.cpu_rng_state = self._get_cpu_rng_state()
@@ -134,7 +209,7 @@ def is_randomizer_index_synchronized(process_group: ProcessGroup = None):
index = Randomizer.index()
if dist.is_initialized():
# convert the index to tensor
- index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+ index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
@@ -156,7 +231,7 @@ def synchronize_index(process_group: ProcessGroup = None):
if dist.is_initialized():
# convert the index to tensor
- index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+ index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
# all gather the index
gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 1bf87e80a461..cd8a023306dc 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -719,7 +719,7 @@ def forward(
):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
- batch_size, tgt_len, _ = query_layer.size()
+ batch_size, tgt_len, _, _ = query_layer.size()
_, kv_length, _, _ = key_layer.size()
@@ -755,6 +755,7 @@ def forward(
attention_numerical_mask = torch.masked_fill(
attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
)
+ attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
context_layer = me_attention(
query_layer,
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 8934068d609c..c8a311df7c6d 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -51,7 +51,8 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
attn_mask_type = AttnMaskType.causal
else:
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- attn_mask_type = AttnMaskType.paddedcausal
+ if not torch.all(flash_attention_mask):
+ attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(
embed_dim=self.hidden_size_per_partition,
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/__init__.py b/colossalai/shardformer/modeling/chatglm2_6b/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
index fdd49ecfeae5..71aa2296eb4c 100644
--- a/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
+++ b/colossalai/shardformer/modeling/chatglm2_6b/modeling_chatglm.py
@@ -400,7 +400,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
)
self.core_attention = CoreAttention(config, self.layer_number)
-
# Output.
self.dense = nn.Linear(
self.projection_size,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 21f06393071d..8f456353742c 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -771,11 +771,12 @@ def forward(
attn_mask_type = AttnMaskType.causal
flash_attention_mask = None
if attention_mask != None:
- if attn_mask_type == AttnMaskType.causal:
- attn_mask_type == AttnMaskType.paddedcausal
- else:
- attn_mask_type = AttnMaskType.padding
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
+ if not torch.all(flash_attention_mask):
+ if attn_mask_type == AttnMaskType.causal:
+ attn_mask_type = AttnMaskType.paddedcausal
+ else:
+ attn_mask_type = AttnMaskType.padding
scale = value.size(-1) ** -0.5
if self.scale_attn_by_inverse_layer_idx:
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 4b6c8342534a..616c9220f4ab 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -1,5 +1,5 @@
import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -13,6 +13,11 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
+try:
+ from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+ LATEST_VERSION = True
+except ImportError:
+ LATEST_VERSION = False
class LlamaPipelineForwards:
"""
@@ -97,9 +102,14 @@ def llama_model_forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
- attention_mask = self._prepare_decoder_attention_mask(
- attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
- )
+ if LATEST_VERSION:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
+ else:
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+ )
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -403,8 +413,6 @@ def get_llama_flash_attention_forward():
warnings.warn("using llamav1, llamav1 hasn't repeat_kv function")
llama_version = 1
- from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
-
def forward(
self: LlamaAttention,
hidden_states: torch.Tensor,
@@ -413,6 +421,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
@@ -454,7 +463,8 @@ def forward(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- attn_mask_type = AttnMaskType.paddedcausal
+ if not torch.all(flash_attention_mask):
+ attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
attn_output = attention(
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index e0978d38e110..71f2ca3353bc 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -581,7 +581,8 @@ def forward(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
- attn_mask_type = AttnMaskType.paddedcausal
+ if not torch.all(flash_attention_mask):
+ attn_mask_type = AttnMaskType.paddedcausal
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
diff --git a/colossalai/shardformer/modeling/whisper.py b/colossalai/shardformer/modeling/whisper.py
index ef59dbcee680..9827d4801f8d 100644
--- a/colossalai/shardformer/modeling/whisper.py
+++ b/colossalai/shardformer/modeling/whisper.py
@@ -106,7 +106,10 @@ def forward(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool).contiguous())
- attn_type = AttnMaskType.paddedcausal
+ if not torch.all(flash_attention_mask):
+ attn_type = AttnMaskType.paddedcausal
+ else:
+ attn_type = AttnMaskType.causal
attention = ColoAttention(
embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=self.dropout, scale=self.scaling
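The same guard now appears in chatglm2, gpt2, llama, opt, and whisper: the mask is inverted so that `True` means "attendable", and only when some position is actually padded does the kernel downgrade from pure `causal` to `paddedcausal`. A compact sketch of that decision, assuming the HF additive-mask convention (0 where attendable, a large negative where padded); the enum is a stand-in for `colossalai.kernel.cuda_native.AttnMaskType`.

```python
# The recurring guard added across the modeling files: only switch from
# causal to paddedcausal when at least one position is actually padded.
from enum import Enum, auto
import torch

class AttnMaskType(Enum):
    causal = auto()
    paddedcausal = auto()

def pick_mask_type(additive_mask: torch.Tensor) -> AttnMaskType:
    # Casting to bool marks padded slots True; '~' flips it so True = attendable.
    flash_attention_mask = ~(additive_mask.to(torch.bool))
    if not torch.all(flash_attention_mask):   # some slot is padded
        return AttnMaskType.paddedcausal
    return AttnMaskType.causal

no_pad = torch.zeros(2, 8)                    # all zeros: nothing masked
padded = no_pad.clone()
padded[:, 6:] = torch.finfo(torch.float32).min
print(pick_mask_type(no_pad), pick_mask_type(padded))
# AttnMaskType.causal AttnMaskType.paddedcausal
```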
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index f3587de15f86..b01896e48021 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -1,6 +1,5 @@
import importlib
from dataclasses import dataclass
-from typing import Optional
import torch.nn as nn
@@ -149,39 +148,12 @@ class PolicyLocation:
),
}
-_INFER_POLICY_LIST = {
- # LlaMa
- "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
- file_name="llama", class_name="LlamaModelInferPolicy"
- ),
- "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
- file_name="llama", class_name="LlamaModelInferPolicy"
- ),
- # Bloom
- "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
- file_name="bloom", class_name="BloomModelInferPolicy"
- ),
- "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
- file_name="bloom", class_name="BloomModelInferPolicy"
- ),
- # ChatGLM2
- "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
- file_name="chatglm2", class_name="ChatGLM2InferPolicy"
- ),
- "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
- file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
- ),
-}
-
-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
"""
Dynamically import a Policy class based on the policy location.
"""
- if inference_only:
- module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
- else:
- module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+ module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
module = importlib.import_module(module_name)
return getattr(module, policy_location.class_name)
@@ -197,7 +169,7 @@ def _fullname(obj):
return module + "." + klass.__qualname__
-def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
r"""
Return the auto policy for the model
@@ -208,15 +180,12 @@ def get_autopolicy(model: nn.Module, inference_only: Optional[bool] = False) ->
:class:`Policy`: The auto policy for the model
"""
full_name = _fullname(model)
- if inference_only:
- policy_location = _INFER_POLICY_LIST.get(full_name, None)
- else:
- policy_location = _POLICY_LIST.get(full_name, None)
+ policy_location = _POLICY_LIST.get(full_name, None)
if policy_location is None:
raise NotImplementedError(
- f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+ f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
- policy = import_policy(policy_location, inference_only)
+ policy = import_policy(policy_location)
return policy()
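With `_INFER_POLICY_LIST` removed, policy lookup collapses to a single dict keyed by the model's fully qualified class name. A toy replay of that lookup pattern (the registry below is a stand-in for `_POLICY_LIST`, and the value is just a string instead of a `PolicyLocation`):

```python
# Toy replay of get_autopolicy's single-table lookup: fully qualified class
# name in, policy location out. _TOY_POLICY_LIST stands in for _POLICY_LIST.
import torch.nn as nn

def _fullname(obj) -> str:
    klass = obj.__class__
    return f"{klass.__module__}.{klass.__qualname__}"

_TOY_POLICY_LIST = {
    "torch.nn.modules.linear.Linear": "LinearPolicy",
}

def get_autopolicy(model: nn.Module) -> str:
    location = _TOY_POLICY_LIST.get(_fullname(model), None)
    if location is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented."
        )
    return location

print(get_autopolicy(nn.Linear(2, 2)))  # LinearPolicy
```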
diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index eb03500531bc..003c9322aea9 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -11,6 +11,7 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
+from ..layer.normalization import BaseLayerNorm
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
@@ -29,7 +30,7 @@ class SubModuleReplacementDescription:
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
- target_module: ParallelModule
+ target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
ignore_if_not_exist: bool = False
@@ -77,7 +78,6 @@ def __init__(self) -> None:
def set_model(self, model: nn.Module) -> None:
r"""
Set model as an attribute of the Policy object so that we can access the model's attributes.
-
Args:
model (:class:`nn.Module`): The model to be sharded
"""
@@ -86,11 +86,11 @@ def set_model(self, model: nn.Module) -> None:
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
-
Args:
shard_config (:class:`ShardConfig`): The shard config to be used
"""
self.shard_config = shard_config
+
self.config_sanity_check()
@property
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 14146de158ae..c31327a6ccaf 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -60,6 +60,12 @@ def module_policy(self):
)
policy = {}
+
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
+
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@@ -141,33 +147,34 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- # Handle bert layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="attention.output.LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="output.LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=BertLayer,
- )
- # handle embedding layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=BertEmbeddings,
- )
+ # Handle bert layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="attention.output.LayerNorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ SubModuleReplacementDescription(
+ suffix="output.LayerNorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ ],
+ policy=policy,
+ target_key=BertLayer,
+ )
+ # handle embedding layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="LayerNorm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=BertEmbeddings,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -288,9 +295,6 @@ def get_held_layers(self) -> List[Module]:
# BertModel
class BertModelPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
@@ -313,9 +317,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForPreTraining
class BertForPreTrainingPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@@ -355,9 +356,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@@ -396,9 +394,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
@@ -437,9 +432,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification
@@ -484,9 +476,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification
@@ -531,9 +520,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertForNextSentencePrediction
@@ -564,9 +550,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice
@@ -610,9 +593,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
class BertForQuestionAnsweringPolicy(BertPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.bert.modeling_bert import BertForQuestionAnswering
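The pattern introduced in this file recurs across the policies below: instead of gating the LayerNorm replacement on `enable_fused_normalization`, a `norm_cls` is chosen once up front and the submodule is always replaced, with `sp_partial_derived` forwarded so sequence-parallel layers know their inputs are partially derived. A stand-in sketch of that selection (local classes, not the col_nn implementations):

```python
# Stand-in sketch of the unconditional norm replacement pattern.
from dataclasses import dataclass


class LayerNorm:  # stand-in for col_nn.LayerNorm
    pass


class FusedLayerNorm:  # stand-in for col_nn.FusedLayerNorm
    pass


@dataclass
class ShardConfig:  # only the two flags this pattern reads
    enable_fused_normalization: bool = False
    enable_sequence_parallelism: bool = False


def norm_replacement(shard_config: ShardConfig) -> dict:
    # choose the class once, then always emit the replacement
    norm_cls = FusedLayerNorm if shard_config.enable_fused_normalization else LayerNorm
    return {
        "suffix": "output.LayerNorm",
        "target_module": norm_cls,
        "kwargs": {"sp_partial_derived": shard_config.enable_sequence_parallelism},
    }


print(norm_replacement(ShardConfig()))
print(norm_replacement(ShardConfig(enable_fused_normalization=True)))
```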
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 997643d1a911..9be2a1e78073 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -43,6 +43,11 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
+
if self.shard_config.enable_tensor_parallelism:
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
@@ -214,94 +219,93 @@ def module_policy(self):
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- # Handle Blip2EncoderLayer layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm1",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layer_norm2",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=Blip2EncoderLayer,
- )
+ # Handle Blip2EncoderLayer layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=Blip2EncoderLayer,
+ )
- # handle Blip2VisionModel layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="post_layernorm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=Blip2VisionModel,
- )
+ # handle Blip2VisionModel layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="post_layernorm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=Blip2VisionModel,
+ )
- # handle Blip2VisionModel layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layernorm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=Blip2QFormerModel,
- )
+ # handle Blip2VisionModel layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layernorm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=Blip2QFormerModel,
+ )
- # handle Blip2QFormerLayer layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="attention.output.LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="crossattention.output.LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="output_query.LayerNorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=Blip2QFormerLayer,
- )
+ # handle Blip2QFormerLayer layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="attention.output.LayerNorm",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="crossattention.output.LayerNorm",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="output_query.LayerNorm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=Blip2QFormerLayer,
+ )
- # handle OPTForCausalLM layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="model.decoder.final_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=OPTForCausalLM,
- )
+ # handle OPTForCausalLM layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="model.decoder.final_layer_norm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=OPTForCausalLM,
+ )
- # handle OPTDecoderLayer layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="self_attn_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="final_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=OPTDecoderLayer,
- )
+ # handle OPTDecoderLayer layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=OPTDecoderLayer,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 13b9dd31345d..c8687a1ac60e 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -42,6 +42,10 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@@ -97,38 +101,39 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- # handle bloom model
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="ln_f",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="word_embeddings_layernorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=BloomModel,
- )
-
- # handle bloom block
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="input_layernorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="post_attention_layernorm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=BloomBlock,
- )
+ # handle bloom model
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="ln_f",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="word_embeddings_layernorm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=BloomModel,
+ )
+
+ # handle bloom block
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ ],
+ policy=policy,
+ target_key=BloomBlock,
+ )
if use_sequence_parallel:
self.append_or_create_method_replacement(
@@ -225,9 +230,6 @@ def get_held_layers(self) -> List[Module]:
class BloomModelPolicy(BloomPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
from transformers.models.bloom.modeling_bloom import BloomModel
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 3c27c848e738..d1ad9f91478b 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -45,6 +45,16 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ if self.model.config.rmsnorm:
+ norm_cls = col_nn.FusedRMSNorm
+ else:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ if self.model.config.rmsnorm:
+ norm_cls = col_nn.RMSNorm
+ else:
+ norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@@ -94,54 +104,35 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
),
],
)
-
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- if not self.model.config.rmsnorm:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
- SubModuleReplacementDescription(
- suffix="post_attention_layernorm", target_module=col_nn.FusedLayerNorm
- ),
- ],
- policy=policy,
- target_key=GLMBlock,
- )
-
- if self.model.config.post_layer_norm:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="encoder.final_layernorm", target_module=col_nn.FusedLayerNorm
- )
- ],
- policy=policy,
- target_key=ChatGLMModel,
- )
-
- else:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
- SubModuleReplacementDescription(
- suffix="post_attention_layernorm", target_module=col_nn.FusedRMSNorm
- ),
- ],
- policy=policy,
- target_key=GLMBlock,
- )
-
- if self.model.config.post_layer_norm:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="encoder.final_layernorm", target_module=col_nn.FusedRMSNorm
- )
- ],
- policy=policy,
- target_key=ChatGLMModel,
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ ],
+ policy=policy,
+ target_key=GLMBlock,
+ )
+
+ if self.model.config.post_layer_norm:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="encoder.final_layernorm",
+ target_module=norm_cls,
)
+ ],
+ policy=policy,
+ target_key=ChatGLMModel,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -224,9 +215,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
class ChatGLMModelPolicy(ChatGLMPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
pass
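For ChatGLM the selection is two-dimensional: the model config picks RMSNorm versus LayerNorm, and the shard config picks fused versus plain, collapsing the previously nested branches into a single `norm_cls`. A compact sketch of the 2x2 choice (string names stand in for the col_nn classes):

```python
# 2x2 norm-class selection collapsed from the nested branches above.
def pick_norm_cls(enable_fused_normalization: bool, rmsnorm: bool) -> str:
    if enable_fused_normalization:
        return "FusedRMSNorm" if rmsnorm else "FusedLayerNorm"
    return "RMSNorm" if rmsnorm else "LayerNorm"


for fused in (False, True):
    for rms in (False, True):
        print(f"fused={fused}, rmsnorm={rms} -> {pick_norm_cls(fused, rms)}")
```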
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 6f46bfc7ef9f..022e6ff5b32c 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -39,6 +39,11 @@ def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
+
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
@@ -102,33 +107,37 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="ln_f",
- target_module=col_nn.FusedLayerNorm,
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="ln_f",
+ target_module=norm_cls,
+ ),
+ policy=policy,
+ target_key=GPT2Model,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="ln_1",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
),
- policy=policy,
- target_key=GPT2Model,
- )
-
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="ln_1",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="ln_2",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="ln_cross_attn", target_module=col_nn.FusedLayerNorm, ignore_if_not_exist=True
- ),
- ],
- policy=policy,
- target_key=GPT2Block,
- )
+ SubModuleReplacementDescription(
+ suffix="ln_2",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ SubModuleReplacementDescription(
+ suffix="ln_cross_attn",
+ target_module=norm_cls,
+ ignore_if_not_exist=True,
+ kwargs={"sp_partial_derived": use_sequence_parallel},
+ ),
+ ],
+ policy=policy,
+ target_key=GPT2Block,
+ )
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
@@ -192,9 +201,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
# GPT2Model
class GPT2ModelPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@@ -216,9 +222,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
@@ -263,9 +266,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel
@@ -317,9 +317,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForQuestionAnswering
@@ -347,9 +344,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForTokenClassification
@@ -387,9 +381,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2ForSequenceClassification
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 099995acb440..915f07d31da1 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -6,7 +6,7 @@
from torch import Tensor
from torch.nn import Module
-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from ..modeling.llama import LlamaPipelineForwards, get_llama_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -35,6 +35,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = FusedRMSNorm
+ else:
+ norm_cls = RMSNorm
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -93,31 +98,31 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="input_layernorm",
- target_module=FusedRMSNorm,
- ),
- SubModuleReplacementDescription(
- suffix="post_attention_layernorm",
- target_module=FusedRMSNorm,
- ),
- ],
- policy=policy,
- target_key=LlamaDecoderLayer,
- )
-
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="norm",
- target_module=FusedRMSNorm,
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=norm_cls,
),
- policy=policy,
- target_key=LlamaModel,
- )
-
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=LlamaDecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=norm_cls,
+ ),
+ policy=policy,
+ target_key=LlamaModel,
+ )
+
+ # use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
@@ -174,9 +179,6 @@ def get_held_layers(self) -> List[Module]:
class LlamaModelPolicy(LlamaPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 5739d21a3903..0b5c767e1d91 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -5,7 +5,7 @@
import torch.nn as nn
from torch import Tensor, nn
-from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -42,6 +42,12 @@ def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
+
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = FusedLayerNorm
+ else:
+ norm_cls = LayerNorm
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -94,26 +100,25 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+ ),
+ policy=policy,
+ target_key=OPTDecoder,
+ )
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
),
- policy=policy,
- target_key=OPTDecoder,
- )
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="self_attn_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
- ),
- SubModuleReplacementDescription(
- suffix="final_layer_norm", target_module=FusedLayerNorm, ignore_if_not_exist=True
- ),
- ],
- policy=policy,
- target_key=OPTDecoderLayer,
- )
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm", target_module=norm_cls, ignore_if_not_exist=True
+ ),
+ ],
+ policy=policy,
+ target_key=OPTDecoderLayer,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -183,9 +188,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
class OPTModelPolicy(OPTPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTModel
@@ -253,9 +255,6 @@ def postprocess(self):
class OPTForSequenceClassificationPolicy(OPTPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForSequenceClassification
@@ -281,9 +280,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
class OPTForQuestionAnsweringPolicy(OPTPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForQuestionAnswering
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index 58a8500e3863..498e62164b09 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -24,6 +24,11 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
+
if self.shard_config.enable_tensor_parallelism:
policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={
@@ -151,58 +156,57 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- # Handle SamVisionLayer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm1",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layer_norm2",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=SamVisionLayer,
- )
+ # Handle SamVisionLayer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=SamVisionLayer,
+ )
- # Handle SamTwoWayAttentionBlock
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm1",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layer_norm2",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layer_norm3",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="layer_norm4",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=SamTwoWayAttentionBlock,
- )
+ # Handle SamTwoWayAttentionBlock
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm1",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm2",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm3",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="layer_norm4",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=SamTwoWayAttentionBlock,
+ )
- # Handle SamTwoWayTransformer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm_final_attn",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=SamTwoWayTransformer,
- )
+ # Handle SamTwoWayTransformer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm_final_attn",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=SamTwoWayTransformer,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 74cc7337e9f1..4d906e3f4c04 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -11,6 +11,7 @@
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
+ RMSNorm,
VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@@ -58,6 +59,11 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = FusedRMSNorm
+ else:
+ norm_cls = RMSNorm
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
@@ -169,38 +175,29 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="layer_norm",
- target_module=FusedRMSNorm,
- ),
- policy=policy,
- target_key=T5LayerFF,
- )
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(
- suffix="layer_norm",
- target_module=FusedRMSNorm,
- ),
- policy=policy,
- target_key=T5LayerFF,
- )
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
- policy=policy,
- target_key=T5LayerSelfAttention,
- )
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(suffix="layer_norm", target_module=FusedRMSNorm),
- policy=policy,
- target_key=T5LayerCrossAttention,
- )
- self.append_or_create_submodule_replacement(
- description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=FusedRMSNorm),
- policy=policy,
- target_key=T5Stack,
- )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=norm_cls,
+ ),
+ policy=policy,
+ target_key=T5LayerFF,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+ policy=policy,
+ target_key=T5LayerSelfAttention,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
+ policy=policy,
+ target_key=T5LayerCrossAttention,
+ )
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=norm_cls),
+ policy=policy,
+ target_key=T5Stack,
+ )
# use flash attention
if self.shard_config.enable_flash_attention:
@@ -363,9 +360,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
class T5ModelPolicy(T5BasePolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers import T5Model
@@ -402,9 +396,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
class T5ForConditionalGenerationPolicy(T5BasePolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers import T5ForConditionalGeneration
@@ -466,9 +457,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
class T5EncoderPolicy(T5BasePolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers import T5EncoderModel
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index 270cdce9b091..6ef0e3b34b2b 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -159,9 +159,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable,
# ViTModel
class ViTModelPolicy(ViTPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTModel
@@ -227,9 +224,6 @@ def get_held_layers(self) -> List[nn.Module]:
# ViTForMaskedImageModeling
class ViTForMaskedImageModelingPolicy(ViTPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index d9af2461cdb8..3ce198e9eedc 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -52,6 +52,11 @@ def module_policy(self):
policy = {}
+ if self.shard_config.enable_fused_normalization:
+ norm_cls = col_nn.FusedLayerNorm
+ else:
+ norm_cls = col_nn.LayerNorm
+
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
@@ -161,62 +166,61 @@ def module_policy(self):
)
# optimization configuration
- if self.shard_config.enable_fused_normalization:
- # Handle encoder layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="self_attn_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="final_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=WhisperEncoderLayer,
- )
+ # Handle encoder layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=WhisperEncoderLayer,
+ )
- # Handle decoder layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="self_attn_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- SubModuleReplacementDescription(
- suffix="final_layer_norm",
- target_module=col_nn.FusedLayerNorm,
- ),
- ],
- policy=policy,
- target_key=WhisperDecoderLayer,
- )
+ # Handle decoder layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="self_attn_layer_norm",
+ target_module=norm_cls,
+ ),
+ SubModuleReplacementDescription(
+ suffix="final_layer_norm",
+ target_module=norm_cls,
+ ),
+ ],
+ policy=policy,
+ target_key=WhisperDecoderLayer,
+ )
- # handle encoder layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=WhisperEncoder,
- )
+ # handle encoder layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperEncoder,
+ )
- # handle decoder layer
- self.append_or_create_submodule_replacement(
- description=[
- SubModuleReplacementDescription(
- suffix="layer_norm",
- target_module=col_nn.FusedLayerNorm,
- )
- ],
- policy=policy,
- target_key=WhisperDecoder,
- )
+ # handle decoder layer
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="layer_norm",
+ target_module=norm_cls,
+ )
+ ],
+ policy=policy,
+ target_key=WhisperDecoder,
+ )
# enable flash attention
if self.shard_config.enable_flash_attention:
@@ -416,9 +420,6 @@ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, poli
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers import WhisperModel
@@ -441,9 +442,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def module_policy(self):
from transformers import WhisperForConditionalGeneration
@@ -502,9 +500,6 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
- def __init__(self) -> None:
- super().__init__()
-
def preprocess(self):
return self.model
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index a285874d218b..a00506a1a068 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,5 @@
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -24,7 +24,6 @@ class ShardConfig:
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
    enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlaps the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
    enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
- inference_only (bool): Whether only doing forward passing. Defaults to False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -33,10 +32,9 @@ class ShardConfig:
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_all_optimization: bool = False
- inference_only: bool = False
- inference_gptq: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
+ extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
@@ -76,4 +74,4 @@ def _infer(self):
"""
Set default params for inference.
"""
- assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
+ # assert self.pipeline_stage_manager is None, "pipeline parallelism is not supported in inference for now"
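With `inference_only` and `inference_gptq` removed from the dataclass, per-feature flags are expected to travel through the new `extra_kwargs` dict instead. A hypothetical migration sketch; the key names used here are illustrative assumptions, not confirmed ShardConfig keys:

```python
# Hypothetical sketch: reading an ad-hoc flag from extra_kwargs. The key
# name "inference_only" is an assumption for illustration only.
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ShardConfig:  # reduced stand-in for colossalai's ShardConfig
    enable_fused_normalization: bool = False
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)


cfg = ShardConfig(extra_kwargs={"inference_only": True})
if cfg.extra_kwargs.get("inference_only", False):
    print("forward-only sharding requested")
```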
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index 1bed850c6581..fc2f92778583 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -27,8 +27,8 @@ class ModelSharder(object):
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
- self.policy = get_autopolicy(self.model, shard_config.inference_only) if policy is None else policy
self.shard_config = shard_config
+ self.policy = get_autopolicy(self.model) if policy is None else policy
def shard(self) -> List[Dict[int, Tensor]]:
r"""
@@ -180,7 +180,6 @@ def _replace_sub_module(
assert target_module is not None, "target_module should not be None"
native_sub_module = getattr_(org_layer, suffix, ignore=True)
-
# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue
@@ -196,7 +195,7 @@ def _replace_sub_module(
try:
replace_layer = target_module.from_native_module(
- native_sub_module, self.shard_config.tensor_parallel_process_group, **kwargs
+ native_sub_module, process_group=self.shard_config.tensor_parallel_process_group, **kwargs
)
except Exception as e:
raise RuntimeError(
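Passing `process_group` by keyword keeps `_replace_sub_module` compatible with `from_native_module` implementations whose positional parameters differ, which matters now that non-parallel norm wrappers are legal targets. A stand-in sketch of why the keyword form is safer:

```python
# Stand-in sketch: two from_native_module signatures with different
# positional layouts; the keyword call site works for both.
class ParallelLinearStandIn:
    @staticmethod
    def from_native_module(module, process_group=None, **kwargs):
        return f"linear({module}, pg={process_group})"


class NormStandIn:
    @staticmethod
    def from_native_module(module, sp_partial_derived=False, process_group=None, **kwargs):
        return f"norm({module}, sp={sp_partial_derived})"


for target in (ParallelLinearStandIn, NormStandIn):
    # keyword form: the target's argument order no longer matters
    print(target.from_native_module("ln_f", process_group=None))
```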
diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py
index fad5101d380c..6f8097735d57 100644
--- a/colossalai/tensor/d_tensor/__init__.py
+++ b/colossalai/tensor/d_tensor/__init__.py
@@ -2,7 +2,9 @@
compute_global_numel,
customized_distributed_tensor_to_param,
distribute_tensor,
+ init_as_dtensor,
distribute_tensor_with_customization,
+ init_tensor_as_customization_distributed,
get_device_mesh,
get_global_shape,
get_layout,
@@ -23,6 +25,7 @@
__all__ = [
"is_distributed_tensor",
"distribute_tensor",
+ "init_as_dtensor",
"to_global",
"is_sharded",
"shard_rowwise",
@@ -36,6 +39,7 @@
"get_layout",
"is_customized_distributed_tensor",
"distribute_tensor_with_customization",
+ "init_tensor_as_customization_distributed",
"to_global_for_customized_distributed_tensor",
"customized_distributed_tensor_to_param",
"Layout",
diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py
index 178bac428ea9..74a785f2dcd4 100644
--- a/colossalai/tensor/d_tensor/api.py
+++ b/colossalai/tensor/d_tensor/api.py
@@ -128,6 +128,17 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
return sharded_tensor
+def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
+ assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
+ dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+
+    # attach layout metadata only; the tensor is assumed to already hold its local shard
+ tensor.dist_layout = dist_layout
+
+ # hack some tensor methods
+ _hijack_detach_and_clone(tensor)
+
+ return tensor
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
"""
@@ -420,6 +431,54 @@ def gather_fn(tensor):
return sharded_tensor
+def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn: callable, gather_fn: callable):
+    """
+    Mark the given tensor as a customized distributed tensor with the given shard_fn and gather_fn; no data is moved, the tensor is assumed to already hold its local shard.
+
+ Example:
+
+ ```python
+ # define shard and gather functions
+ def shard_fn(tensor):
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ return tensor.chunk(world_size, dim=0)[rank]
+
+ def gather_fn(tensor):
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+ torch.distributed.all_gather(shard_list, tensor)
+ return torch.cat(shard_list, dim=0)
+
+ # create a distributed tensor
+ tensor = torch.rand(4, 4)
+ dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn)
+ ```
+
+ Args:
+ tensor (torch.Tensor): The tensor to be distributed.
+ shard_fn (callable): The function to shard the tensor.
+ gather_fn (callable): The function to gather the tensor.
+
+ Returns:
+ torch.Tensor: The distributed tensor.
+ """
+ assert callable(shard_fn), "The shard_fn must be callable."
+ assert callable(gather_fn), "The gather_fn must be callable."
+ assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
+
+
+ # set the shard_fn and gather_fn as attributes of the distributed tensor
+ tensor.shard_fn = shard_fn
+ tensor.gather_fn = gather_fn
+
+    # hijack detach and clone so copies of the tensor keep its customized shard/gather behavior
+ _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)
+
+ return tensor
+
+
def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
"""
Gather the given tensor to the global tensor.
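Unlike `distribute_tensor`, the new `init_as_dtensor` performs no sharding: it only attaches a `Layout` to a tensor whose storage is assumed to already be the local shard. A standalone sketch with a stand-in Layout class:

```python
# Metadata-only initialization sketch; Layout is a local stand-in for
# the colossalai.tensor.d_tensor Layout class.
import torch


class Layout:
    def __init__(self, device_mesh, sharding_spec, global_shape):
        self.device_mesh = device_mesh
        self.sharding_spec = sharding_spec
        self.global_shape = global_shape


def init_as_dtensor(tensor, device_mesh, sharding_spec, global_shape):
    assert not hasattr(tensor, "dist_layout"), "already a distributed tensor"
    # attach layout metadata; no communication or chunking happens here
    tensor.dist_layout = Layout(device_mesh, sharding_spec, global_shape)
    return tensor


local_shard = torch.zeros(2, 4)  # this rank's piece of a (4, 4) tensor
t = init_as_dtensor(local_shard, device_mesh=None, sharding_spec=None, global_shape=torch.Size((4, 4)))
print(t.shape, "->", t.dist_layout.global_shape)
```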
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index e031e0472b0b..abe4a86d8198 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -4,6 +4,7 @@
from typing import Dict, List, Tuple
import torch
+import torch.distributed as dist
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
@@ -438,11 +439,58 @@ def layout_converting(
MAX_TRANSFORM_STEPS = 20
total_steps = 0
transform_path = []
- comm_action_sequence = []
+ comm_action_sequence: List[CommSpec] = []
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
if spec_pairs in self.cached_solution:
- return self.cached_solution[spec_pairs]
+ # Solution Cache hit
+
+ def _group_alive_check(cached_comm_action_sequence):
+ r"""
+            Check if the process groups required for sharding have been deleted by the torch.distributed.destroy_process_group method.
+ If not deleted, return True; otherwise, return False.
+
+ Args:
+ cached_comm_action_sequence (List[CommSpec]): A list of communication specifications representing actions.
+
+ Returns:
+ bool: True if all process groups are still registered, False if at least one has been deleted.
+
+ Raises:
+ RuntimeError: If there is an error while checking the status of a process group.
+ """
+
+ # Collect all process groups used in communication actions from the cached sequence
+ used_process_groups = [
+ pg for comm_spec in cached_comm_action_sequence for pg in comm_spec.process_group_dict.values()
+ ]
+
+ # Check if each process group is still alive
+ for process_group in used_process_groups:
+ try:
+ dist.get_rank(process_group)
+ except RuntimeError as e:
+ # If the group is not registered, it means it has been deleted
+ if str(e) == (
+ f"Group {process_group} is not registered, please create group with torch.distributed.new_group API"
+ ):
+ return False
+ elif str(e) == "The given group does not exist":
+ return False
+ else:
+ # Re-raise the exception if it's not related to group deletion
+ raise e
+ # All process groups are alive
+ return True
+
+ cached_transform_path, cached_comm_action_sequence = self.cached_solution[spec_pairs]
+
+ if _group_alive_check(cached_comm_action_sequence):
+ # If all process groups have not been deleted, the cache is valid
+ return cached_transform_path, cached_comm_action_sequence
+ else:
+ # If at least one process group has been deleted, the cache is invalid, so delete it
+ del self.cached_solution[spec_pairs]
# We do nothing if the sharding spec is all the same.
if source_spec.spec_diff(target_spec) == 0:
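The cache-hit path now probes every process group referenced by the cached communication actions: `dist.get_rank` raises a `RuntimeError` for a destroyed or unregistered group, which is treated as a cache invalidation rather than a failure. A simplified standalone sketch of the probe (the real code string-matches the error message so unrelated RuntimeErrors are re-raised):

```python
# Simplified probe: treat "group gone" as a cache miss, not an error.
import torch.distributed as dist


def group_is_alive(process_group) -> bool:
    try:
        dist.get_rank(process_group)
        return True
    except RuntimeError:
        # destroyed or never-registered group; cached comm actions that
        # reference it cannot be replayed (the real code re-raises
        # RuntimeErrors unrelated to group deletion)
        return False
```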
diff --git a/colossalai/tensor/moe_tensor/__init__.py b/colossalai/tensor/moe_tensor/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/colossalai/tensor/moe_tensor/api.py b/colossalai/tensor/moe_tensor/api.py
new file mode 100644
index 000000000000..1e4486101dd3
--- /dev/null
+++ b/colossalai/tensor/moe_tensor/api.py
@@ -0,0 +1,152 @@
+from typing import List
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from .moe_info import MoeParallelInfo
+
+
+def is_moe_tensor(tensor: torch.Tensor) -> bool:
+ """
+ Check whether the given tensor is a moe tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ bool: Whether the given tensor is a moe tensor.
+ """
+ return hasattr(tensor, "moe_info")
+
+
+def set_moe_tensor_info(tensor: torch.Tensor, moe_info: MoeParallelInfo) -> None:
+ """
+ Set moe info for the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be set.
+        moe_info (MoeParallelInfo): The moe parallel info to be set.
+
+ """
+ tensor.__setattr__("moe_info", moe_info)
+
+
+def get_moe_info(ep_size: int, dp_size: int, pp_size: int, ep_inside: bool) -> MoeParallelInfo:
+ """
+    Build a MoeParallelInfo from the given parallel sizes.
+
+    Args:
+        ep_size (int): The expert parallel size.
+        dp_size (int): The data parallel size.
+        pp_size (int): The pipeline parallel size.
+        ep_inside (bool): Use ep inside dp if True, dp inside ep if False.
+
+    Returns:
+        MoeParallelInfo: The moe parallel info.
+ """
+ return MoeParallelInfo(ep_inside, ep_size, dp_size, pp_size)
+
+
+def get_ep_group(tensor: torch.Tensor) -> ProcessGroup:
+ """
+ Get the expert parallel group of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ torch.distributed.ProcessGroup: The expert parallel group of the given tensor.
+ """
+ return tensor.moe_info.ep_group
+
+
+def get_ep_size(tensor: torch.Tensor) -> int:
+ """
+ Get the expert parallel size of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ int: The expert parallel size of the given tensor.
+ """
+ return tensor.moe_info.ep_size
+
+
+def get_dp_size(tensor: torch.Tensor) -> int:
+ """
+ Get the data parallel size of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ int: The data parallel size of the given tensor.
+ """
+ return tensor.moe_info.dp_size
+
+
+def get_dp_group(tensor: torch.Tensor) -> ProcessGroup:
+ """
+ Get the data parallel group of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ torch.distributed.ProcessGroup: The data parallel group of the given tensor.
+ """
+ return tensor.moe_info.dp_group
+
+
+def get_ep_rank(tensor: torch.Tensor) -> int:
+ """
+ Get the expert parallel rank of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ int: The expert parallel rank of the given tensor.
+ """
+ return dist.get_rank(get_ep_group(tensor))
+
+
+def get_dp_rank(tensor: torch.Tensor) -> int:
+ """
+ Get the data parallel rank of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+ int: The data parallel rank of the given tensor.
+ """
+ return dist.get_rank(get_dp_group(tensor))
+
+
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
+ """
+ Get the expert parallel group ranks of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+        List[int]: The expert parallel group ranks of the given tensor.
+ """
+ return tensor.moe_info.ep_group_ranks
+
+
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
+ """
+ Get the data parallel group ranks of the given tensor.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be checked.
+
+ Returns:
+        List[int]: The data parallel group ranks of the given tensor.
+ """
+ return tensor.moe_info.dp_group_ranks
diff --git a/colossalai/tensor/moe_tensor/moe_info.py b/colossalai/tensor/moe_tensor/moe_info.py
new file mode 100644
index 000000000000..5097ac1044e7
--- /dev/null
+++ b/colossalai/tensor/moe_tensor/moe_info.py
@@ -0,0 +1,28 @@
+from colossalai.cluster import ProcessGroupMesh
+
+
+class MoeParallelInfo:
+ """Moe parallelism information, storing parallel sizes and groups."""
+
+ def __init__(self, ep_inside: bool, ep_size: int, dp_size: int, pp_size: int = 1):
+ """
+        Initialize MoeParallelInfo with ep_inside, ep_size, dp_size and pp_size.
+
+        Args:
+            ep_inside (bool): Use ep inside dp if True, dp inside ep if False.
+            ep_size (int): expert parallel size
+            dp_size (int): data parallel (zero) size
+            pp_size (int, optional): pipeline parallel size. Defaults to 1.
+        """
+ self.pp_size, self.dp_size, self.ep_size = pp_size, dp_size, ep_size
+ if ep_inside:
+ self.pp_axis, self.dp_axis, self.ep_axis = 0, 1, 2
+ self.pg = ProcessGroupMesh(self.pp_size, self.dp_size, self.ep_size)
+ else:
+ self.pp_axis, self.ep_axis, self.dp_axis = 0, 1, 2
+ self.pg = ProcessGroupMesh(self.pp_size, self.ep_size, self.dp_size)
+
+ self.ep_group = self.pg.get_group_along_axis(self.ep_axis)
+ self.ep_group_ranks = self.pg.get_ranks_in_group(self.ep_group)
+ self.dp_group = self.pg.get_group_along_axis(self.dp_axis)
+ self.dp_group_ranks = self.pg.get_ranks_in_group(self.dp_group)
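`ep_inside` controls which axis varies fastest in the mesh: with `ep_inside=True` the mesh is (pp, dp, ep) and expert-parallel ranks are contiguous; with `ep_inside=False` it is (pp, ep, dp) and data-parallel ranks are. A plain-Python worked example for 2 x 2 x 2 = 8 ranks, standing in for ProcessGroupMesh's grouping:

```python
# Worked example of the MoeParallelInfo axis ordering, 8 ranks total.
import itertools

pp_size, dp_size, ep_size = 2, 2, 2


def rank_map(axes):
    # flatten mesh coordinates to ranks, last axis varying fastest
    coords = itertools.product(*(range(s) for s in axes))
    return {c: r for r, c in enumerate(coords)}


# ep_inside=True -> mesh is (pp, dp, ep): ep is the innermost axis
rank_of = rank_map((pp_size, dp_size, ep_size))
ep_group = [rank_of[(0, 0, e)] for e in range(ep_size)]  # [0, 1] contiguous
dp_group = [rank_of[(0, d, 0)] for d in range(dp_size)]  # [0, 2] strided

# ep_inside=False -> mesh is (pp, ep, dp): dp is the innermost axis
rank_of = rank_map((pp_size, ep_size, dp_size))
dp_inner = [rank_of[(0, 0, d)] for d in range(dp_size)]  # [0, 1] contiguous
ep_outer = [rank_of[(0, e, 0)] for e in range(ep_size)]  # [0, 2] strided
print(ep_group, dp_group, dp_inner, ep_outer)
```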
diff --git a/colossalai/testing/utils.py b/colossalai/testing/utils.py
index 839e7aab3567..7cd24b0adc60 100644
--- a/colossalai/testing/utils.py
+++ b/colossalai/testing/utils.py
@@ -9,6 +9,7 @@
import torch
import torch.multiprocessing as mp
from packaging import version
+from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -198,7 +199,7 @@ def test_something():
def _wrap_func(f):
def _execute_by_gpu_num(*args, **kwargs):
- num_avail_gpu = torch.cuda.device_count()
+ num_avail_gpu = device_count()
if num_avail_gpu >= min_gpus:
f(*args, **kwargs)
@@ -262,11 +263,11 @@ def test_something():
def _wrap_func(f):
def _clear_cache(*args, **kwargs):
- torch.cuda.empty_cache()
- torch.cuda.reset_peak_memory_stats()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_max_memory_cached()
- torch.cuda.synchronize()
+ empty_cache()
+ reset_peak_memory_stats()
+ reset_max_memory_allocated()
+ reset_max_memory_cached()
+ synchronize()
gc.collect()
f(*args, **kwargs)
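With these wrappers in place, the cache-clearing test decorator becomes device-agnostic. A simplified reproduction of the decorator this hunk rewrites, using the new `colossalai.utils.device` helpers introduced later in this patch:

```python
# Simplified sketch of the device-agnostic clear_cache_before_run decorator.
import gc
from functools import wraps

from colossalai.utils.device import (
    empty_cache,
    reset_max_memory_allocated,
    reset_max_memory_cached,
    reset_peak_memory_stats,
    synchronize,
)


def clear_cache_before_run():
    def _wrap_func(f):
        @wraps(f)
        def _clear_cache(*args, **kwargs):
            # same sequence as the old torch.cuda version, but routed
            # through the dispatcher so it also works on NPU
            empty_cache()
            reset_peak_memory_stats()
            reset_max_memory_allocated()
            reset_max_memory_cached()
            synchronize()
            gc.collect()
            return f(*args, **kwargs)
        return _clear_cache
    return _wrap_func
```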
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 3ec39b949a23..0246a35e2a1b 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -7,7 +7,7 @@
is_ddp_ignored,
set_seed,
)
-from .cuda import empty_cache, get_current_device, set_device, set_to_cuda, synchronize
+from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
from .multi_tensor_apply import multi_tensor_applier
from .tensor_detector import TensorDetector
from .timer import MultiTimer, Timer
@@ -29,4 +29,5 @@
"set_seed",
"is_ddp_ignored",
"set_device",
+ "IS_NPU_AVAILABLE",
]
diff --git a/colossalai/utils/cuda.py b/colossalai/utils/cuda.py
deleted file mode 100644
index 6bfb08d1f04a..000000000000
--- a/colossalai/utils/cuda.py
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-
-
-def set_to_cuda(models):
- """Send model to gpu.
-
- :param models: nn.module or a list of module
- """
- if isinstance(models, list) and len(models) > 1:
- ret = []
- for model in models:
- ret.append(model.to(get_current_device()))
- return ret
- elif isinstance(models, list):
- return models[0].to(get_current_device())
- else:
- return models.to(get_current_device())
-
-
-def get_current_device() -> torch.device:
- """
- Returns currently selected device (gpu/cpu).
- If cuda available, return gpu, otherwise return cpu.
- """
- if torch.cuda.is_available():
- return torch.device(f"cuda:{torch.cuda.current_device()}")
- else:
- return torch.device("cpu")
-
-
-def synchronize():
- """Similar to cuda.synchronize().
- Waits for all kernels in all streams on a CUDA device to complete.
- """
- if torch.cuda.is_available():
- torch.cuda.synchronize()
-
-
-def empty_cache():
- """Similar to cuda.empty_cache()
- Releases all unoccupied cached memory currently held by the caching allocator.
- """
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
-
-
-def set_device(index: Optional[int] = None) -> None:
- if index is None:
- index = dist.get_rank() % torch.cuda.device_count()
- torch.cuda.set_device(index)
diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py
new file mode 100644
index 000000000000..c70dbdaa5ee1
--- /dev/null
+++ b/colossalai/utils/device.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+IS_NPU_AVAILABLE: bool = False
+try:
+ import torch_npu # noqa
+
+ IS_NPU_AVAILABLE = torch.npu.is_available()
+except ImportError:
+ pass
+
+
+def set_to_cuda(models):
+ """Send model to gpu.
+
+ :param models: nn.module or a list of module
+ """
+ if isinstance(models, list) and len(models) > 1:
+ ret = []
+ for model in models:
+ ret.append(model.to(get_current_device()))
+ return ret
+ elif isinstance(models, list):
+ return models[0].to(get_current_device())
+ else:
+ return models.to(get_current_device())
+
+
+def get_current_device() -> torch.device:
+ """
+    Returns the currently selected device (gpu/npu/cpu).
+    Return a CUDA device if available, otherwise an NPU device if available, otherwise the CPU.
+ """
+ if torch.cuda.is_available():
+ return torch.device(f"cuda:{torch.cuda.current_device()}")
+ elif IS_NPU_AVAILABLE:
+ return torch.device(f"npu:{torch.npu.current_device()}")
+ else:
+ return torch.device("cpu")
+
+
+def _dispatch_device_func(fn_name: str, *args, **kwargs):
+ if torch.cuda.is_available():
+ return getattr(torch.cuda, fn_name)(*args, **kwargs)
+ elif IS_NPU_AVAILABLE:
+ return getattr(torch.npu, fn_name)(*args, **kwargs)
+ else:
+ raise RuntimeError("No device available")
+
+
+# device semantics
+
+
+def can_device_access_peer(device, peer_device) -> bool:
+ return _dispatch_device_func("can_device_access_peer", device, peer_device)
+
+
+def current_device() -> int:
+ return _dispatch_device_func("current_device")
+
+
+def current_stream(device=None):
+ return _dispatch_device_func("current_stream", device)
+
+
+def default_stream(device=None):
+ return _dispatch_device_func("default_stream", device)
+
+
+def device_count() -> int:
+ return _dispatch_device_func("device_count")
+
+
+def get_device_capability(device=None) -> Tuple[int, int]:
+ return _dispatch_device_func("get_device_capability", device)
+
+
+def get_device_name(device=None) -> str:
+ return _dispatch_device_func("get_device_name", device)
+
+
+def get_device_properties(device):
+ return _dispatch_device_func("get_device_properties", device)
+
+
+def set_device(index: Optional[int] = None) -> None:
+ if index is None:
+ index = dist.get_rank() % device_count()
+ _dispatch_device_func("set_device", index)
+
+
+def set_stream(stream_):
+ return _dispatch_device_func("set_stream", stream_)
+
+
+def stream(stream_):
+ return _dispatch_device_func("stream", stream_)
+
+
+def synchronize():
+ return _dispatch_device_func("synchronize")
+
+
+def utilization(device=None) -> int:
+ return _dispatch_device_func("utilization", device)
+
+
+# random number generator
+
+
+def get_rng_state(device="cuda") -> torch.Tensor:
+ return _dispatch_device_func("get_rng_state", device)
+
+
+def get_rng_state_all() -> List[torch.Tensor]:
+ return _dispatch_device_func("get_rng_state_all")
+
+
+def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
+ return _dispatch_device_func("set_rng_state", new_state, device)
+
+
+def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
+ return _dispatch_device_func("set_rng_state_all", new_states)
+
+
+def manual_seed(seed: int) -> None:
+ return _dispatch_device_func("manual_seed", seed)
+
+
+def manual_seed_all(seed: int) -> None:
+ return _dispatch_device_func("manual_seed_all", seed)
+
+
+def seed() -> None:
+ return _dispatch_device_func("seed")
+
+
+def seed_all() -> None:
+ return _dispatch_device_func("seed_all")
+
+
+def initial_seed() -> int:
+ return _dispatch_device_func("initial_seed")
+
+
+# streams and events
+
+
+def Stream(device=None, priority=0, **kwargs):
+ return _dispatch_device_func("Stream", device, priority, **kwargs)
+
+
+def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
+ return _dispatch_device_func("Event", enable_timing, blocking, interprocess)
+
+
+# memory management
+
+
+def empty_cache() -> None:
+ return _dispatch_device_func("empty_cache")
+
+
+def memory_stats(device=None) -> Dict[str, Any]:
+ return _dispatch_device_func("memory_stats", device)
+
+
+def memory_summary(device=None, abbreviated=False) -> str:
+ return _dispatch_device_func("memory_summary", device, abbreviated)
+
+
+def memory_snapshot():
+ return _dispatch_device_func("memory_snapshot")
+
+
+def memory_allocated(device=None) -> int:
+ return _dispatch_device_func("memory_allocated", device)
+
+
+def max_memory_allocated(device=None) -> int:
+ return _dispatch_device_func("max_memory_allocated", device)
+
+
+def reset_max_memory_allocated(device=None) -> None:
+ return _dispatch_device_func("reset_max_memory_allocated", device)
+
+
+def reset_max_memory_cached(device=None) -> None:
+ return _dispatch_device_func("reset_max_memory_cached", device)
+
+
+def memory_reserved(device=None) -> int:
+ return _dispatch_device_func("memory_reserved", device)
+
+
+def max_memory_reserved(device=None) -> int:
+ return _dispatch_device_func("max_memory_reserved", device)
+
+
+def set_per_process_memory_fraction(fraction: float, device=None) -> None:
+ return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)
+
+
+def reset_peak_memory_stats(device=None) -> None:
+ return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+
+
+def autocast() -> Callable:
+ if torch.cuda.is_available():
+ return torch.cuda.amp.autocast()
+ elif IS_NPU_AVAILABLE:
+ return torch.npu.amp.autocast()
+ else:
+ raise RuntimeError("No device available")
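+
+
+# Usage sketch (illustrative): the wrappers above let callers stay
+# device-agnostic, e.g.
+#   x = torch.randn(8, 8, device=get_current_device())
+#   with autocast():
+#       y = x @ x
+#   synchronize()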
diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py
deleted file mode 100644
index 1b75448bdd3c..000000000000
--- a/colossalai/utils/moe.py
+++ /dev/null
@@ -1,53 +0,0 @@
-from typing import Dict, List
-
-import torch.distributed as dist
-import torch.nn as nn
-
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.context import ParallelMode
-from colossalai.legacy.core import global_context as gpc
-from colossalai.legacy.utils import is_using_ddp
-
-
-def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
- """Returns a parameter dictionary, the key of which is the expert parallel
- size of every parameter. Since the parameters in data parallelism is replicated
- in each GPU, we set their ep_size to 1.
-
- Args:
- model (:class:`torch.nn.Module`): A pyTorch `nn.Module` from which we get dict.
- """
- epsize_param_dict = dict()
- for param in model.parameters():
- if not hasattr(param, "moe_info"):
- ep_size = 1 # set ep_size to 1 for dp parameters
- else:
- ep_size = param.moe_info.ep_size
- if ep_size not in epsize_param_dict:
- epsize_param_dict[ep_size] = []
- epsize_param_dict[ep_size].append(param)
-
- return epsize_param_dict
-
-
-def sync_moe_model_param(model: nn.Module):
- """Make sure model parameters are consistent in MoE parallel context.
-
- Args:
- model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
- """
- if is_using_ddp():
- param_dict = get_moe_epsize_param_dict(model)
-
- # synchronize the parameters whose dp_group is the whole world
- if 1 in param_dict:
- src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
- for param in param_dict[1]:
- dist.broadcast(param, src=src_rank, group=gpc.get_group(ParallelMode.DATA))
-
- for ep_size in param_dict:
- # When ep_size = world_size, communication is not needed
- if ep_size != 1 and ep_size != MOE_CONTEXT.world_size:
- src_rank = dist.get_rank(MOE_CONTEXT.parallel_info_dict[ep_size].ep_group)
- for param in param_dict[ep_size]:
- dist.broadcast(param, src=src_rank, group=param.moe_info.dp_group)
diff --git a/colossalai/utils/timer.py b/colossalai/utils/timer.py
index 2f61817f0461..8ab6b46f28b6 100644
--- a/colossalai/utils/timer.py
+++ b/colossalai/utils/timer.py
@@ -3,7 +3,7 @@
import time
from typing import Tuple
-from .cuda import synchronize
+from .device import synchronize
class Timer:
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index d3309fc5364f..defc6c4cb150 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -7,6 +7,7 @@
from torch.distributed import ProcessGroup
from colossalai.utils import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE
class TensorState(Enum):
@@ -61,12 +62,13 @@ class Chunk:
def __init__(
self,
chunk_size: int,
- process_group: ProcessGroup,
+ zero_group: ProcessGroup,
dtype: torch.dtype,
init_device: Optional[torch.device] = None,
cpu_shard_init: bool = False,
keep_gathered: bool = False,
pin_memory: bool = False,
+ extra_dp_group: ProcessGroup = None,
) -> None:
"""
Chunk: A container owning a piece of contiguous memory space for tensors
@@ -76,7 +78,7 @@ def __init__(
Args:
chunk_size (int): the number of elements in the chunk
- process_group (ProcessGroup): the process group of this chunk
+ zero_group (ProcessGroup): the process group of this chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, During the chunk construction process, where the tensor is stored.
The default value is None, which is the current GPU
@@ -90,9 +92,11 @@ def __init__(
self.chunk_size = chunk_size
self.utilized_size = 0
- self.torch_pg = process_group
+ self.torch_pg = zero_group
self.pg_size = dist.get_world_size(self.torch_pg)
self.pg_rank = dist.get_rank(self.torch_pg)
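+ # the effective data-parallel world size is pg_size * extra_dp_size:
+ # each zero shard is further replicated across the extra dp group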
+ self.extra_dp_group = extra_dp_group
+ self.extra_dp_size = dist.get_world_size(self.extra_dp_group) if self.extra_dp_group is not None else 1
# the chunk size should be divisible by the dp degree
if not keep_gathered:
@@ -169,7 +173,7 @@ def memory_usage(self) -> Dict[str, int]:
if self.chunk_temp is not None:
# this chunk is not closed
- if self.chunk_temp.device.type == "cuda":
+ if self.chunk_temp.device.type in ("cuda", "npu"):
cuda_memory += self.chunk_mem
else:
cpu_memory += self.chunk_mem
@@ -188,10 +192,8 @@ def device_type(self) -> str:
if self.chunk_temp is not None:
return self.chunk_temp.device.type
else:
- if self.is_gathered:
- return "cuda"
- elif self.cuda_shard is not None:
- return "cuda"
+ if self.is_gathered or self.cuda_shard is not None:
+ return "npu" if IS_NPU_AVAILABLE else "cuda"
else:
return "cpu"
@@ -326,12 +328,12 @@ def shard_move(self, device: torch.device, force_copy: bool = False):
# when the current chunk is not synchronized with the optimizer
# just use another way for the movement
if not self.optim_sync_flag:
- assert device.type == "cuda", "each chunk should first be moved to CUDA"
+ assert device.type in ("cuda", "npu"), "each chunk should first be moved to CUDA or NPU"
self.__paired_shard_move()
self.optim_sync_flag = True
return
- if device.type == "cuda":
+ if device.type == "cuda" or device.type == "npu":
assert device == get_current_device(), "can't move chunk to another device"
if self.cuda_shard:
@@ -384,14 +386,20 @@ def reduce(self):
# just move cuda_global_chunk to cuda_shard
# the communication is not necessary
self.__scatter()
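+ # if an extra dp group exists, the local shard is additionally all-reduced
+ # across its replicas; the sum is normalized by extra_dp_size in
+ # GeminiDDP.grad_handle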
+ if self.extra_dp_group is not None:
+ dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
elif self.keep_gathered:
# we use all-reduce here
dist.all_reduce(self.cuda_global_chunk, group=self.torch_pg)
+ if self.extra_dp_group is not None:
+ dist.all_reduce(self.cuda_global_chunk, group=self.extra_dp_group)
else:
self.cuda_shard = torch.empty(self.shard_size, dtype=self.dtype, device=get_current_device())
input_list = list(torch.chunk(self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.reduce_scatter(self.cuda_shard, input_list, group=self.torch_pg)
+ if self.extra_dp_group is not None:
+ dist.all_reduce(self.cuda_shard, group=self.extra_dp_group)
free_storage(self.cuda_global_chunk)
self.is_gathered = False
@@ -475,7 +483,7 @@ def optim_update(self) -> None:
assert friend_chunk.is_gathered is True
self.cuda_global_chunk.copy_(friend_chunk.cuda_global_chunk)
self.optim_sync_flag = True
- elif friend_chunk.device_type == "cuda" and self.device_type == "cuda":
+ elif friend_chunk.device_type in ("cuda", "npu") and self.device_type in ("cuda", "npu"):
self.cuda_shard.copy_(friend_chunk.cuda_shard)
self.optim_sync_flag = True
self.cpu_vis_flag = False
@@ -608,10 +616,11 @@ def init_grad_chunk(self) -> "Chunk":
# grad chunk is not initialized
grad_chunk = Chunk(
chunk_size=self.chunk_size,
- process_group=self.torch_pg,
+ zero_group=self.torch_pg,
dtype=self.dtype,
keep_gathered=self.keep_gathered,
pin_memory=self.pin_memory,
+ extra_dp_group=self.extra_dp_group,
)
grad_chunk.num_tensors = self.num_tensors
grad_chunk.utilized_size = self.utilized_size
@@ -637,6 +646,7 @@ def init_grad_chunk(self) -> "Chunk":
# grad chunk is initialized, just reallocate cuda global chunk
self.grad_chunk.cuda_shard = None
self.grad_chunk.is_gathered = True
+ self.grad_chunk.l2_norm = None
alloc_storage(self.grad_chunk.cuda_global_chunk)
return self.grad_chunk
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index d3c512fe978d..5f4f37c267aa 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -38,7 +38,8 @@ def register_tensor(
tensor: torch.Tensor,
group_type: str,
config_key: int,
- process_group: ProcessGroup,
+ zero_group: ProcessGroup,
+ extra_dp_group: ProcessGroup = None,
cpu_offload: bool = False,
pin_memory: bool = False,
) -> None:
@@ -76,15 +77,16 @@ def register_tensor(
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
- dp_size = dist.get_world_size(process_group)
+ dp_size = dist.get_world_size(zero_group)
chunk_size = chunk_size + (-chunk_size % dp_size)
chunk = Chunk(
chunk_size=chunk_size,
- process_group=process_group,
+ zero_group=zero_group,
dtype=tensor.dtype,
cpu_shard_init=cpu_offload,
pin_memory=pin_memory,
+ extra_dp_group=extra_dp_group,
**chunk_kwargs,
)
@@ -204,7 +206,10 @@ def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
"""
assert tensor not in self.tensor_chunk_map
- self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
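+ # NPU memory is booked under the "cuda" key so total_mem keeps a single
+ # accelerator bucket regardless of backend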
+ device_type = tensor.device.type
+ if device_type == "npu":
+ device_type = "cuda"
+ self.total_mem[device_type] += tensor.numel() * tensor.element_size()
def __repr__(self) -> str:
msg = [
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index df7e1163c3d9..5217b8036bcd 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -10,11 +10,22 @@
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.d_tensor import (
+ distribute_tensor,
+ distribute_tensor_with_customization,
+ get_device_mesh,
+ get_global_shape,
+ get_sharding_spec,
+ init_as_dtensor,
+ init_tensor_as_customization_distributed,
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
+)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
@@ -73,9 +84,10 @@ def __init__(
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16,
- process_group: Optional[ProcessGroup] = None,
+ zero_group: Optional[ProcessGroup] = None,
memstats: Optional[MemStats] = None, # genimi memory stats
master_weights: bool = True,
+ extra_dp_group: Optional[ProcessGroup] = None,
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
@@ -92,7 +104,7 @@ def __init__(
search_range_m=search_range_m,
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
- process_group=process_group,
+ process_group=zero_group,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
@@ -115,7 +127,8 @@ def __init__(
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
- self.dp_process_group = process_group or _get_default_group()
+ self.zero_group = zero_group or _get_default_group()
+ self.extra_dp_group = extra_dp_group
self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights
@@ -147,7 +160,7 @@ def __init__(
self._init_chunks(
param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
- cpu_offload=self.gemini_manager.policy_name != "cuda",
+ cpu_offload=not (self.gemini_manager.policy_name == "static" and offload_param_frac == 0),
pin_memory=pin_memory,
)
super().__init__(module)
@@ -318,9 +331,7 @@ def backward(self, loss: torch.Tensor):
self._post_backward()
def backward_by_grad(self, tensor, grad):
- with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
- torch.autograd.backward(tensor, grad)
- self._post_backward()
+ raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shouldn't be called in Gemini.")
def grad_handle(self, p, grad):
setattr(p, "_gemini_reduced", True)
@@ -343,6 +354,7 @@ def grad_handle(self, p, grad):
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
+ chunk.grad_chunk.l2_norm = None
# hold -> compute -> hold after bwd
grad_chunk.tensor_trans_state(p, TensorState.COMPUTE)
@@ -365,8 +377,12 @@ def grad_handle(self, p, grad):
self.chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+ if self.extra_dp_group is not None:
+ grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
+ if self.extra_dp_group is not None:
+ grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
@@ -430,7 +446,19 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
- record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu()
+ record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).to(tensor.device)
+ if is_distributed_tensor(tensor):
+ global_shape = get_global_shape(tensor)
+ device_mesh = get_device_mesh(tensor)
+ shard_spec = get_sharding_spec(tensor)
+ record_tensor = init_as_dtensor(
+ record_tensor, device_mesh=device_mesh, sharding_spec=shard_spec, global_shape=global_shape
+ )
+ elif is_customized_distributed_tensor(tensor):
+ init_tensor_as_customization_distributed(
+ record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
+ )
+ record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
assert tensor not in chunk_to_save_data
chunk_to_save_data[tensor] = record_tensor
@@ -605,10 +633,26 @@ def _load_from_state_dict(
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
- def load(param_name, dest_tensor, copy_func):
+ def load(
+ param_name,
+ dest_tensor,
+ copy_func,
+ source_device_mesh=None,
+ source_sharding_spec=None,
+ shard_fn=None,
+ gather_fn=None,
+ ):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
+
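+ # checkpoints hold full (gathered) tensors; re-shard them to this
+ # rank's local slice before copying into the chunk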
+ if source_device_mesh is not None and source_sharding_spec is not None:
+ input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
+ elif shard_fn is not None and gather_fn is not None:
+ input_param = distribute_tensor_with_customization(
+ input_param, shard_fn=shard_fn, gather_fn=gather_fn
+ )
+
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
@@ -652,9 +696,26 @@ def load_parameter(chunk_slice, data):
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
for tensor, tensor_info in chunk.tensors_info.items():
+ source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
+ if is_distributed_tensor(tensor):
+ # shard the input param
+ source_device_mesh = get_device_mesh(tensor)
+ source_sharding_spec = get_sharding_spec(tensor)
+ elif is_customized_distributed_tensor(tensor):
+ shard_fn = tensor.shard_fn
+ gather_fn = tensor.gather_fn
+
parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
- load(parameter_name, tensor, partial(load_parameter, parameter_slice))
+ load(
+ parameter_name,
+ tensor,
+ partial(load_parameter, parameter_slice),
+ source_device_mesh,
+ source_sharding_spec,
+ shard_fn,
+ gather_fn,
+ )
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -694,7 +755,7 @@ def load_parameter(chunk_slice, data):
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
- dp_world_size = dist.get_world_size(self.dp_process_group)
+ zero_world_size = dist.get_world_size(self.zero_group)
for p in param_order.generate():
self._preprocess_param(p)
assert type(p) is ColoParameter
@@ -714,8 +775,9 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
self.chunk_manager.register_tensor(
tensor=p,
group_type="fp16_param",
- config_key=dp_world_size,
- process_group=self.dp_process_group,
+ config_key=zero_world_size,
+ zero_group=self.zero_group,
+ extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -723,12 +785,14 @@ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pi
if self.master_weights:
# create a fp32 parameter
- fp32_p = p.data.float()
+ fp32_p = p.clone()
+ fp32_p.data = fp32_p.data.float()
self.chunk_manager.register_tensor(
tensor=fp32_p,
group_type="fp32_param",
- config_key=dp_world_size,
- process_group=self.dp_process_group,
+ config_key=zero_world_size,
+ zero_group=self.zero_group,
+ extra_dp_group=self.extra_dp_group,
cpu_offload=cpu_offload,
pin_memory=pin_memory,
)
@@ -751,7 +815,7 @@ def _cast_buffers(self):
for buffer in self.module.buffers():
if isinstance(buffer, LazyTensor):
buffer.materialize()
- buffer.data = buffer.cuda()
+ buffer.data = buffer.to(get_current_device())
if torch.is_floating_point(buffer):
buffer.data = buffer.to(self.mixed_precision)
diff --git a/colossalai/zero/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py
index f7ff3f6cdd86..150932e3d8d9 100644
--- a/colossalai/zero/gemini/gemini_mgr.py
+++ b/colossalai/zero/gemini/gemini_mgr.py
@@ -17,9 +17,7 @@ class GeminiManager:
https://arxiv.org/abs/2108.05818
Args:
- placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
- If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
- If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
+ placement_policy (str): Which device to place *held* tensors. It can be 'static' or 'auto'.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
chunk_manager (ChunkManager): A ``ChunkManager`` instance.
@@ -121,7 +119,7 @@ def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk,
start = time()
cuda_demand = 0
for chunk in chunks:
- if chunk.device_type == "cuda":
+ if chunk.device_type == "cuda" or chunk.device_type == "npu":
if chunk.is_gathered:
pass
else:
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 0d0298e067f3..8f828bd6cf20 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -7,14 +7,25 @@
import torch
import torch.distributed as dist
from packaging.version import Version
+from torch.distributed import ProcessGroup
from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
-from colossalai.checkpoint_io.utils import StateDictSharder
+from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
+from colossalai.tensor.d_tensor import (
+ distribute_tensor,
+ distribute_tensor_with_customization,
+ get_device_mesh,
+ get_sharding_spec,
+ init_as_dtensor,
+ init_tensor_as_customization_distributed,
+ is_customized_distributed_tensor,
+ is_distributed_tensor,
+)
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@@ -93,6 +104,8 @@ def __init__(
max_scale: float = 2**32,
max_norm: float = 0.0,
norm_type: float = 2.0,
+ tp_group: ProcessGroup = None,
+ optimizer_params_info=None,
verbose: bool = False,
**defaults: Any,
):
@@ -109,6 +122,10 @@ def __init__(
self.chunk16_set: Set[Chunk] = set()
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
+ self.tp_group = tp_group
+ self.optimizer_params_info = optimizer_params_info
+ self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+ self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
self.verbose = verbose
self.param_groups_backup = list()
@@ -293,7 +310,7 @@ def _maybe_move_fp32_params(self):
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
- if chunk32.device_type == "cuda":
+ if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
continue
if fp32_params_used_cuda_margin_mem + chunk32.payload_mem < fp32_params_available_cuda_margin_mem:
@@ -307,7 +324,7 @@ def _maybe_move_fp32_params(self):
for fake_param in group["params"]:
chunk16 = self.param_to_chunk16[fake_param]
chunk32 = chunk16.paired_chunk
- if chunk32.device_type == "cuda":
+ if chunk32.device_type == "cuda" or chunk32.device_type == "npu":
state = self.optim.state[fake_param]
for k, v in state.items():
if isinstance(v, torch.Tensor):
@@ -406,8 +423,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
- process_group = chunk.torch_pg
- rank = dist.get_rank(process_group)
+ zero_group = chunk.torch_pg
+ rank = dist.get_rank(zero_group)
master_rank = 0
collected_states = {}
@@ -415,9 +432,9 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
- gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+ gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
- dist.all_gather_object(gathered_state_names, local_state_names)
+ dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
state_names = None
for names in gathered_state_names:
if names is not None:
@@ -436,6 +453,13 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
# Every rank is collector when only_rank_0 is False.
is_collector = (rank == master_rank) or (not only_rank_0)
+ # get tensor parallelism information
+ is_dtensor = is_distributed_tensor(param)
+ is_customized_distributed = is_customized_distributed_tensor(param)
+ shard_spec = get_sharding_spec(param) if is_dtensor else None
+ device_mesh = get_device_mesh(param) if is_dtensor else None
+ global_shape = self.optimizer_params_info["id2shape"][param_id]
+
# If the chunk is kept gathered,
# the parameteres are treated the same as that of those in strict DDP during training.
# So states can be directly fetched from current device.
@@ -451,7 +475,22 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
).cpu()
else:
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
- collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+ if is_dtensor:
+ state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
+ state_tensor = init_as_dtensor(
+ state_tensor,
+ device_mesh=device_mesh,
+ sharding_spec=shard_spec,
+ global_shape=global_shape,
+ )
+ elif is_customized_distributed:
+ state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
+ init_tensor_as_customization_distributed(
+ state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+ )
+ state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
+
+ collected_states[state_name] = state_tensor.reshape(global_shape)
return collected_states
# Check whether the param with given id is managed by current process.
@@ -473,10 +512,10 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
_, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering.
- gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
+ gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
- dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
+ dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
if is_collector:
for state_shard in gathered_state_shards:
@@ -494,6 +533,17 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
for state_name, state_tensor in collected_states.items():
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+ if is_dtensor:
+ state_tensor = state_tensor.to(param.device)
+ state_tensor = init_as_dtensor(
+ state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
+ )
+ elif is_customized_distributed:
+ state_tensor = state_tensor.to(param.device)
+ init_tensor_as_customization_distributed(
+ state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
+ )
+ state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
return collected_states
@@ -501,7 +551,7 @@ def pack_optimizer_states_to_tensor(
self,
param_id: int,
state_names: list,
- device: torch.device = torch.device("cuda"),
+ device: torch.device = get_current_device(),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -658,6 +708,14 @@ def cast(param, state_range, value, key=None):
ret_val = torch.zeros(
state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
)
+
+ if is_dtensor:
+ value = torch.reshape(value, global_shape)
+ value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
+ elif is_customized_distributed:
+ value = torch.reshape(value, global_shape)
+ value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn)
+
ret_val.copy_(value.flatten()[state_start:state_end])
return ret_val
@@ -668,6 +726,15 @@ def cast(param, state_range, value, key=None):
# Copy states assigned to param (and cast tensors to appropriate types).
updated_states = dict()
+
+ # get tensor parallelism information
+ real_param = self.id_to_real_params[param_id]
+ is_dtensor = is_distributed_tensor(real_param)
+ is_customized_distributed = is_customized_distributed_tensor(real_param)
+ shard_spec = get_sharding_spec(real_param) if is_dtensor else None
+ device_mesh = get_device_mesh(real_param) if is_dtensor else None
+ global_shape = self.optimizer_params_info["id2shape"][param_id]
+
for k, v in saved_states.items():
updated_states[k] = cast(fake_param, state_range, v, k)
del v # clean loaded states
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index e6974a6760ce..c1b35ee17f91 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -8,9 +8,11 @@
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor, inf
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
+import colossalai.utils.device as device_utils
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
@@ -18,9 +20,10 @@
)
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
# from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import IS_NPU_AVAILABLE, get_current_device
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -75,6 +78,7 @@ def __init__(
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
forced_dtype: Optional[torch.dtype] = None,
+ moe_extra_dp_process_group: Optional[ProcessGroup] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -95,6 +99,16 @@ def __init__(
self._local_rank = dist.get_rank(group=self.dp_pg)
self._world_size = dist.get_world_size(group=self.dp_pg)
+ # extra dp
+ # This group is used to sync MoE params: dp_world_size = moe_duplicates * extra_dp_size.
+ # Non-MoE params are synced by the global dp pg, while MoE params are synced by the extra dp pg.
+ # MoE param grads are split like non-MoE params by the global dp pg, and the grads are merged in step().
+ # MoE working and master params are split by the extra dp pg.
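+ # e.g. (illustrative numbers) with a dp world size of 8 and extra_dp_size = 2,
+ # a MoE param grad is all-reduced within its 2-rank extra dp group, while
+ # non-MoE grads are all-reduced across all 8 ranks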
+ self.moe_extra_dp_pg = moe_extra_dp_process_group
+ if self.moe_extra_dp_pg is not None:
+ self.moe_extra_dp_pg_size = dist.get_world_size(group=self.moe_extra_dp_pg)
+ self.moe_extra_dp_pg_rank = dist.get_rank(group=self.moe_extra_dp_pg)
+
# working and master params for mixed precision training
self._working_param_groups = dict()
self._master_param_groups_of_current_rank = dict()
@@ -126,6 +140,12 @@ def __init__(
self._grad_store = GradientStore(self.dp_pg, partition_grad=partition_grad)
self._bucket_store = BucketStore(self.dp_pg)
+ # MoE params should not be stored in working_groups,
+ # because they follow a different parallel strategy;
+ # so we store them separately in param_groups
+ # instead of working_groups
+ moe_params = list()
+
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
@@ -133,6 +153,11 @@ def __init__(
group_params = list()
for param in param_group["params"]:
if param.requires_grad:
+ if self.moe_extra_dp_pg is None:
+ # without an extra dp group, collect MoE params separately and skip them here
+ if is_moe_tensor(param):
+ moe_params.append(param)
+ continue
group_params.append(param)
# add the working params to working_param_groups for bookkeeping
@@ -146,10 +171,19 @@ def __init__(
# managed by this data parallel rank
param_group["params"] = master_param_current_rank
+ # if there are MoE params, store them in an additional group in the optimizer
+ if len(moe_params) > 0:
+ param_group = dict()
+ for key, value in self.optim.param_groups[0].items():
+ if key != "params":
+ param_group[key] = value
+ param_group["params"] = moe_params
+ self.optim.param_groups.append(param_group)
+
# intialize communication stream for
# communication-compuation overlapping
if self._overlap_communication:
- self._comm_stream = torch.cuda.Stream()
+ self._comm_stream = device_utils.Stream()
# reduction hook is only used if overlapping communication
# or stage 2 is used
@@ -183,7 +217,7 @@ def num_param_groups(self):
return len(self._working_param_groups)
def _sanity_checks(self):
- assert torch.cuda.is_available(), "CUDA is required"
+ assert torch.cuda.is_available() or IS_NPU_AVAILABLE, "CUDA or NPU is required"
for param_group in self.optim.param_groups:
group_params = param_group["params"]
for param in group_params:
@@ -208,13 +242,20 @@ def _create_master_param_current_rank(self, param_list):
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
- splited_params = padding_param.split(padding_param.numel() // self._world_size)
+
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(param):
+ splited_params = padding_param.split(padding_param.numel() // self.moe_extra_dp_pg_size)
+ splited_params = splited_params[self.moe_extra_dp_pg_rank]
+ else:
+ splited_params = padding_param.split(padding_param.numel() // self._world_size)
+ splited_params = splited_params[self._local_rank]
# use fp32 when master_weights is True
if self._master_weights is True:
- splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+ splited_param_current_rank = splited_params.detach().float().to(device)
else:
- splited_param_current_rank = splited_params[self._local_rank]
+ splited_param_current_rank = splited_params
+
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
@@ -247,8 +288,43 @@ def _run_reduction(self):
if self._bucket_store.num_elements_in_bucket() > 0:
self._bucket_store.build_grad_in_bucket()
- flat_grads = self._bucket_store.get_flatten_grad()
- flat_grads /= self._world_size
+ if self.moe_extra_dp_pg is None:
+ flat_grads = self._bucket_store.get_flatten_grad()
+ flat_grads /= self._world_size
+ else:
+ # record which params in the bucket are MoE params
+ moe_list = []
+ for param in self._bucket_store._param_list:
+ moe_list.append(is_moe_tensor(param))
+
+ # divide them into different groups
+ moe_grad_list = []
+ non_moe_grad_list = []
+ for grad_list in self._bucket_store._grad_in_bucket.values():
+ non_moe_cur_grad = []
+ moe_cur_grad = []
+ for i in range(len(grad_list)):
+ if moe_list[i]:
+ moe_cur_grad.append(grad_list[i])
+ else:
+ non_moe_cur_grad.append(grad_list[i])
+ if len(moe_cur_grad) > 0:
+ moe_grad_list.append(moe_cur_grad)
+ if len(non_moe_cur_grad) > 0:
+ non_moe_grad_list.append(non_moe_cur_grad)
+
+ if len(non_moe_grad_list) > 0:
+ non_moe_flat_grads = []
+ for grad_list in non_moe_grad_list:
+ non_moe_flat_grads.append(_flatten_dense_tensors(grad_list))
+ non_moe_flat_grads = _flatten_dense_tensors(non_moe_flat_grads)
+ non_moe_flat_grads /= self._world_size
+
+ if len(moe_grad_list) > 0:
+ moe_flat_grads = []
+ for grad_list in moe_grad_list:
+ moe_flat_grads.append(_flatten_dense_tensors(grad_list))
+ moe_flat_grads = _flatten_dense_tensors(moe_flat_grads)
# ready to add other tensors to bucket
self._bucket_store.reset_num_elements_in_bucket()
@@ -256,58 +332,123 @@ def _run_reduction(self):
if self._overlap_communication:
stream = self._comm_stream
# in case of the memory being reused in the default stream
- flat_grads.record_stream(stream)
+ if self.moe_extra_dp_pg is None:
+ flat_grads.record_stream(stream)
+ else:
+ if len(non_moe_grad_list) > 0:
+ non_moe_flat_grads.record_stream(stream)
+ if len(moe_grad_list) > 0:
+ moe_flat_grads.record_stream(stream)
# waiting for ops in the default stream finishing
- stream.wait_stream(torch.cuda.current_stream())
+ stream.wait_stream(device_utils.current_stream())
else:
- stream = torch.cuda.current_stream()
+ stream = device_utils.current_stream()
- with torch.cuda.stream(stream):
+ with device_utils.stream(stream):
group_id = self._bucket_store.current_group_id
- grad_dtype = flat_grads.dtype
- if self._communication_dtype is not None:
- flat_grads = flat_grads.to(self._communication_dtype)
+ if self.moe_extra_dp_pg is None:
+ grad_dtype = flat_grads.dtype
+ if self._communication_dtype is not None:
+ flat_grads = flat_grads.to(self._communication_dtype)
if not self._partition_grads:
- dist.all_reduce(flat_grads, group=self.dp_pg)
- if flat_grads.dtype != grad_dtype:
- flat_grads = flat_grads.to(grad_dtype)
-
- flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
- grad_in_bucket = self._bucket_store.get_grad()
-
- for rank, grad_list in grad_in_bucket.items():
- sync_tensor(flat_grads_per_rank[rank], grad_list)
- for grad in grad_list:
- param_id = self._bucket_store.get_param_id_of_grad(grad)
- if (
- len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id))
- < self._world_size
- ):
- self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
- else:
- self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
+ if self.moe_extra_dp_pg is None:
+ dist.all_reduce(flat_grads, group=self.dp_pg)
+ if flat_grads.dtype != grad_dtype:
+ flat_grads = flat_grads.to(grad_dtype)
+
+ flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size)
+ grad_in_bucket = self._bucket_store.get_grad()
+ self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id)
+
+ # sync with the extra (MoE) dp group present
+ else:
+ # sync non-MoE param grads in the global dp group
+ if len(non_moe_grad_list) > 0:
+ dist.all_reduce(non_moe_flat_grads, group=self.dp_pg)
+ flat_grads_per_rank = non_moe_flat_grads.split(
+ non_moe_flat_grads.numel() // self._world_size
+ )
+ self._update_unpartitoned_grad(non_moe_grad_list, flat_grads_per_rank, group_id)
+
+ # sync MoE param grads only within the extra dp group
+ if len(moe_grad_list) > 0:
+ dist.all_reduce(moe_flat_grads, group=self.moe_extra_dp_pg)
+ flat_grads_per_rank = moe_flat_grads.split(moe_flat_grads.numel() // self._world_size)
+ self._update_unpartitoned_grad(moe_grad_list, flat_grads_per_rank, group_id)
else:
- flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
- recieved_grad = torch.zeros_like(flat_grads_list[0])
- dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
-
- if recieved_grad.dtype != grad_dtype:
- recieved_grad = recieved_grad.to(grad_dtype)
-
- grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
- sync_tensor(recieved_grad, grad_in_bucket_current_rank)
- for grad in grad_in_bucket_current_rank:
- param_id = self._bucket_store.get_param_id_of_grad(grad)
- if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
- self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
- else:
- self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
+ if self.moe_extra_dp_pg is None:
+ flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
+ recieved_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+
+ if recieved_grad.dtype != grad_dtype:
+ recieved_grad = recieved_grad.to(grad_dtype)
+
+ grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+ self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1)
+ else:
+ # categorize moe and non moe param
+ grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank]
+ moe_grad_in_bucket_current_rank = []
+ non_moe_grad_in_bucket_current_rank = []
+ for idx, grad in enumerate(grad_in_bucket_current_rank):
+ if moe_list[idx]:
+ moe_grad_in_bucket_current_rank.append(grad)
+ else:
+ non_moe_grad_in_bucket_current_rank.append(grad)
+
+ if len(non_moe_grad_list) > 0:
+ flat_grads_list = list(
+ non_moe_flat_grads.split(len(non_moe_flat_grads) // self._world_size)
+ )
+ recieved_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.dp_pg)
+ self._update_partitoned_grad(
+ non_moe_grad_in_bucket_current_rank, recieved_grad, group_id, 1
+ )
+
+ if len(moe_grad_list) > 0:
+ flat_grads_list = list(
+ moe_flat_grads.split(len(moe_flat_grads) // self.moe_extra_dp_pg_size)
+ )
+ recieved_grad = torch.zeros_like(flat_grads_list[0])
+ dist.reduce_scatter(recieved_grad, flat_grads_list, group=self.moe_extra_dp_pg)
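+ # each extra-dp rank owns world_size // extra_dp_size consecutive moe
+ # partitions: unflatten each slice and store it via _add_grad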
+ param_slice = self._world_size // self.moe_extra_dp_pg_size
+ recieved_grad = list(recieved_grad.split(len(recieved_grad) // param_slice))
+ for split_recieved_grad in recieved_grad:
+ split_recieved_grad = _unflatten_dense_tensors(
+ split_recieved_grad, moe_grad_in_bucket_current_rank
+ )
+ for real_grad, grad in zip(split_recieved_grad, moe_grad_in_bucket_current_rank):
+ param_id = self._bucket_store.get_param_id_of_grad(grad)
+ self._add_grad(real_grad, param_slice, group_id, param_id)
self._bucket_store.reset()
+ def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None:
+ for rank, grad_list in enumerate(origin_grad_list):
+ sync_tensor(flat_grad_list[rank], grad_list)
+ for grad in grad_list:
+ param_id = self._bucket_store.get_param_id_of_grad(grad)
+ self._add_grad(grad, self._world_size, group_id, param_id, rank)
+
+ def _update_partitoned_grad(
+ self, origin_grad_list: List, flat_grad: torch.Tensor, group_id: int, partition_num: int
+ ) -> None:
+ sync_tensor(flat_grad, origin_grad_list)
+ for grad in origin_grad_list:
+ param_id = self._bucket_store.get_param_id_of_grad(grad)
+ self._add_grad(grad, partition_num, group_id, param_id)
+
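+ # a partition is appended the first time it is seen; afterwards the new grad
+ # is accumulated into the stored partition selected by `rank`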
+ def _add_grad(self, grad: torch.Tensor, partition_num: int, group_id: int, param_id: int, rank: int = 0) -> None:
+ if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num:
+ self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
+ else:
+ self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
+
def _add_to_bucket(self, param, group_id):
param_size = param.numel()
@@ -345,7 +486,7 @@ def backward(self, loss, retain_graph=False):
# clear reduced grads
if self._overlap_communication:
- torch.cuda.synchronize()
+ device_utils.synchronize()
self.zero_grad()
@@ -364,7 +505,7 @@ def backward_by_grad(self, tensor, grad):
# clear reduced grads
if self._overlap_communication:
- torch.cuda.synchronize()
+ device_utils.synchronize()
self.zero_grad()
@@ -424,13 +565,23 @@ def step(self, closure=None):
# else the splited grad should be attached to the splited param
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
- real_working_params[group_id].append(working_param)
+ # moe hybrid zero: moe grads are sharded over the extra dp group
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
+ real_working_params[group_id].append(working_param)
+ if self._partition_grads:
+ grad = grads
+ else:
+ param_slice = self._world_size // self.moe_extra_dp_pg_size
+ grad = grads[
+ self.moe_extra_dp_pg_rank * param_slice : (self.moe_extra_dp_pg_rank + 1) * param_slice
+ ]
+ grad = flatten(grad)
+ else:
+ real_working_params[group_id].append(working_param)
+ grad = grads[grad_index]
# no need to copy fp32 grad if master_weights is False
- grad = (
- grads[grad_index].to(splited_param.dtype).to(splited_param.device)
- if self._master_weights
- else grads[grad_index]
- )
+ if self._master_weights:
+ grad = grad.to(splited_param.dtype).to(splited_param.device)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
@@ -449,24 +600,46 @@ def step(self, closure=None):
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
self._unscale_and_clip_grads(grad_partition_groups, global_norm)
+ # TODO: we should store master param for ep
+ if len(self.param_groups) > len(self._working_param_groups):
+ for param in self.param_groups[-1]["params"]:
+ param.data = param.data.to(torch.float32)
+ param.grad = param.grad.to(torch.float32)
+
# update the parameters
self.optim.step()
+ # release the moe grads
+ if len(self.param_groups) > len(self._working_param_groups):
+ for param in self.param_groups[-1]["params"]:
+ param.grad = None
+ param.data = param.data.to(self._dtype)
+
# release the grad
grad_partition_groups = []
for group_id in range(self.num_param_groups):
release_param_grad(self._master_param_groups_of_current_rank[group_id])
# update working partition updated by the current rank
- # dtype = real_working_params[0][0].dtype
+ device = get_current_device()
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
- all_splited_param = [
- torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
- ]
- dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
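+ # moe params are gathered over the smaller extra dp group; all other
+ # params are gathered over the full dp group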
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(working_param):
+ all_splited_param = [
+ torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
+ for _ in range(self.moe_extra_dp_pg_size)
+ ]
+ dist.all_gather(
+ all_splited_param, splited_param.to(device).to(self._dtype), group=self.moe_extra_dp_pg
+ )
+ else:
+ all_splited_param = [
+ torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
+ for _ in range(self._world_size)
+ ]
+ dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
@@ -488,8 +661,7 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
norm_type = float(norm_type)
if norm_type == inf:
total_norm = max(grad.data.abs().max() for grad in gradients)
-
- total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+ total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float)
dist.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=self.dp_pg)
total_norm = total_norm_cuda.item()
@@ -500,7 +672,9 @@ def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> flo
total_norm_exponentiated += grad_norm_exponentiated
# Sum across all model parallel GPUs.
- total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+ total_norm_exponentiated_cuda = torch.tensor(
+ [float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float
+ )
torch.distributed.all_reduce(
total_norm_exponentiated_cuda, op=torch.distributed.ReduceOp.SUM, group=self.dp_pg
)
@@ -532,7 +706,7 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
############################
# this method is used to sync gradient manually
- def sync_grad(self):
+ def _sync_grad(self):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
@@ -545,7 +719,7 @@ def _reduce_grad(self, partition_grad):
# if not overlapping communication (no reduction hook is attached) when zero1
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
- self.sync_grad()
+ self._sync_grad()
else:
self._run_reduction()
@@ -591,15 +765,22 @@ def state_dict(self) -> Dict:
Dict: the pytorch form state_dict
"""
zero_state = dict()
+ device = get_current_device()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != "step":
working_param = self._param_store.master_to_working_param[id(param)]
- gather_tensor = [
- torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)
- ]
- dist.all_gather(gather_tensor, v.cuda(), group=self.dp_pg)
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
+ gather_tensor = [
+ torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
+ ]
+ dist.all_gather(gather_tensor, v.to(device), group=self.moe_extra_dp_pg)
+ else:
+ gather_tensor = [
+ torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
+ ]
+ dist.all_gather(gather_tensor, v.to(device), group=self.dp_pg)
param_state = (
torch.stack(gather_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -624,8 +805,12 @@ def load_state_dict(self, state_dict: Dict):
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
- v_list = v.split(v.numel() // self._world_size)
- zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
+ v_list = v.split(v.numel() // self.moe_extra_dp_pg_size)
+ zero_state_dict["state"][param_idx][k] = v_list[self.moe_extra_dp_pg_rank].detach().clone()
+ else:
+ v_list = v.split(v.numel() // self._world_size)
+ zero_state_dict["state"][param_idx][k] = v_list[self._local_rank].detach().clone()
self.optim.load_state_dict(zero_state_dict)
@@ -642,6 +827,7 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
ret_block = dict()
ret_block_size = 0
+ device = get_current_device()
local_states = self.optim.state_dict()["state"]
for param_idx, states in local_states.items():
current_block_size = 0
@@ -656,8 +842,16 @@ def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, i
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != "step":
- state_tensor = [torch.zeros(v.shape, device="cuda", dtype=v.dtype) for _ in range(self._world_size)]
- dist.all_gather(state_tensor, v.cuda(), group=self.dp_pg)
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(v):
+ state_tensor = [
+ torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self.moe_extra_dp_pg_size)
+ ]
+ dist.all_gather(state_tensor, v.to(device), group=self.moe_extra_dp_pg)
+ else:
+ state_tensor = [
+ torch.zeros(v.shape, device=device, dtype=v.dtype) for _ in range(self._world_size)
+ ]
+ dist.all_gather(state_tensor, v.to(device), group=self.dp_pg)
state_tensor = (
torch.stack(state_tensor).view(-1)[: working_param.numel()].reshape_as(working_param).cpu()
)
@@ -688,7 +882,10 @@ def update_master_params(self, model: nn.Module) -> None:
working_param = p.data.view(-1)
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
- master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+ if self.moe_extra_dp_pg is not None and is_moe_tensor(p):
+ master_param.copy_(working_param.chunk(self.moe_extra_dp_pg_size)[self.moe_extra_dp_pg_rank])
+ else:
+ master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
diff --git a/docs/source/en/basics/booster_plugins.md b/docs/source/en/basics/booster_plugins.md
index fa360a4b9213..2426d8c32672 100644
--- a/docs/source/en/basics/booster_plugins.md
+++ b/docs/source/en/basics/booster_plugins.md
@@ -26,7 +26,7 @@ Generally only one plugin is used to train a model. Our recommended use case for
- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b).
- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b).
- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b).
-- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b).
+- [Hybrid Parallel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b).
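+
+For instance, a minimal sketch of driving any of these plugins through the booster (illustrative only: the tiny model stands in for a real one, and `GeminiPlugin` can be swapped for `TorchDDPPlugin`, `LowLevelZeroPlugin` or `HybridParallelPlugin`):
+
+```python
+import torch.nn as nn
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+from colossalai.nn.optimizer import HybridAdam
+
+# launch with e.g.: torchrun --nproc_per_node=1 train.py
+colossalai.launch_from_torch(config={})
+
+model = nn.Linear(8, 8)
+optimizer = HybridAdam(model.parameters(), lr=1e-3)
+criterion = nn.MSELoss()
+
+plugin = GeminiPlugin()
+booster = Booster(plugin=plugin)
+# boost returns the wrapped model/optimizer/criterion plus (optional) dataloader and lr_scheduler
+model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+```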
## Plugins
@@ -58,7 +58,11 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts:
-1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer. More details can be found in chapter [Shardformer Doc](../features/shardformer.md).
+1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel settings. Shardformer also overloads the logic of the model's forward/backward process to ensure that tp/pp work smoothly. In addition, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward methods by Shardformer. More details can be found in the chapter [Shardformer Doc](../features/shardformer.md). The diagram below shows the features supported by Shardformer together with the Hybrid Parallel plugin.
+
+
+
+
2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md).
diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md
index 133fd799280a..160f68767156 100644
--- a/docs/source/en/features/lazy_init.md
+++ b/docs/source/en/features/lazy_init.md
@@ -1,6 +1,6 @@
# Lazy initialization
-Author: [Hongxiu Liu](https://github.com/ver217)
+Author: [Hongxin Liu](https://github.com/ver217)
**Prerequisite:**
- [Train with booster](../basics/booster_api.md)
@@ -73,4 +73,4 @@ And some models are not supported at all which will raise an error. We tested mo
| Blip2Model | transformers |
| Blip2ForConditionalGeneration | transformers |
-
+
diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md
index cb19f9815bf2..31b20335e529 100644
--- a/docs/source/en/features/pipeline_parallel.md
+++ b/docs/source/en/features/pipeline_parallel.md
@@ -167,7 +167,7 @@ plugin = HybridParallelPlugin(tp_size=1,
booster = Booster(plugin=plugin)
```
-Boost these train componts with the booster created.
+Boost these train components with the booster created.
```python
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
optimizer,
diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md
index a6e32d2c05fa..bf7b2b3e4305 100644
--- a/docs/source/en/features/shardformer.md
+++ b/docs/source/en/features/shardformer.md
@@ -20,7 +20,7 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
## Introduction
-When training large transformer models such as LLaMa-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, including tensor parallelism or pipeline parallism, are essential so as to meet the limitation of GPU memory.
+When training large transformer models such as LLaMa-2 70B or OPT 175B, model parallelism methods that divide a huge model into smaller shards, including tensor parallelism or pipeline parallelism, are essential so as to meet the limitation of GPU memory.
However, manually cutting model and rewriting its forward/backword logic could be difficult for users who are not familiar with distributed training.
Meanwhile, the Huggingface transformers library has gradually become users' first choice of model source, and most mainstream large models have been open-sourced in Huggingface transformers model library.
@@ -321,7 +321,7 @@ For example, when training LlaMa-2 with tensor parallel size as 2, the attribute
3. Replacing the `forward` methods implemented by original Huggingface
Transformers libraries with our customized `forward` methods.
-This replacement is essential for pipeline paralellism, where a customiozed function is needed to pass intermediate hidden states between different pipeline stages.
+This replacement is essential for pipeline parallelism, where a customized function is needed to pass intermediate hidden states between different pipeline stages.
Also, optimization methods such as flash attention or sequence parallel can be injected into the `forward` process through our customized `forward` method.
4. Replacing the whole copy of model parameters and optimizer states with incomplete ones controlled by current device (this is why it's called Shardformer).
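+
+For readers who want to drive this sharding flow directly instead of through a plugin, here is a minimal sketch (the `ShardConfig`/`ShardFormer` interface is assumed here; see the API reference for the full argument list):
+
+```python
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+# enable tensor parallelism; optimizations such as flash attention can also be toggled here
+shard_config = ShardConfig(enable_tensor_parallelism=True)
+shard_former = ShardFormer(shard_config=shard_config)
+
+# returns the sharded model plus any parameters shared across shards
+sharded_model, shared_params = shard_former.optimize(model)
+```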
diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md
index 42305182b8b8..62be864884b7 100644
--- a/docs/source/en/features/zero_with_chunk.md
+++ b/docs/source/en/features/zero_with_chunk.md
@@ -1,6 +1,6 @@
# Zero Redundancy Optimizer with chunk-based memory management
-Author: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
+Author: [Hongxin Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
**Prerequisite:**
- [Train with booster](../basics/booster_api.md)
diff --git a/docs/source/zh-Hans/basics/booster_plugins.md b/docs/source/zh-Hans/basics/booster_plugins.md
index 70352a7b9af3..21e135a470c7 100644
--- a/docs/source/zh-Hans/basics/booster_plugins.md
+++ b/docs/source/zh-Hans/basics/booster_plugins.md
@@ -16,7 +16,7 @@
- [Torch FSDP 插件](#torch-fsdp-插件): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。
- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。
- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。
-- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 它为Shardformer,流水线管理器,混合精度运算,TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行,流水线并行以及数据并行(DDP, Zero)间任意组合并行训练策略,同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。
+- [Hybrid Parallel 插件](#hybrid-parallel-插件): 它为Shardformer,流水线管理器,混合精度运算,TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行,流水线并行以及数据并行(DDP, Zero)间任意组合并行训练策略,同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。
更多插件即将推出。
@@ -24,7 +24,7 @@
- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。
- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b)。
- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b),且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b)。
-- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。
+- [Hybrid Parallel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b)。
## 插件
@@ -55,7 +55,11 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分:
-1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。
+1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。下图展示了Shardformer与Hybrid Parallel插件所支持的功能。
+
+<!-- 图示:Shardformer 与 Hybrid Parallel 插件所支持的功能 -->
+
2. 混合精度训练:插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。
diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md
index 80742a56df29..137719c69de2 100644
--- a/docs/source/zh-Hans/features/lazy_init.md
+++ b/docs/source/zh-Hans/features/lazy_init.md
@@ -1,6 +1,6 @@
# 懒惰初始化
-作者: [Hongxiu Liu](https://github.com/ver217)
+作者: [Hongxin Liu](https://github.com/ver217)
**前置教程:**
- [Train with booster](../basics/booster_api.md)
@@ -73,4 +73,4 @@ model, *_ = booster.boost(model)
| Blip2Model | transformers |
| Blip2ForConditionalGeneration | transformers |
-
+
diff --git a/docs/source/zh-Hans/features/zero_with_chunk.md b/docs/source/zh-Hans/features/zero_with_chunk.md
index 61290628588b..c4f21c73c586 100644
--- a/docs/source/zh-Hans/features/zero_with_chunk.md
+++ b/docs/source/zh-Hans/features/zero_with_chunk.md
@@ -1,6 +1,6 @@
# 基于Chunk内存管理的零冗余优化器 (ZeRO)
-作者: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
+作者: [Hongxin Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
**前置教程:**
diff --git a/examples/community/roberta/preprocessing/Makefile b/examples/community/roberta/preprocessing/Makefile
index 82ee4e1c5b31..81478dd49213 100644
--- a/examples/community/roberta/preprocessing/Makefile
+++ b/examples/community/roberta/preprocessing/Makefile
@@ -1,4 +1,4 @@
-CXXFLAGS += -O3 -Wall -shared -std=c++14 -fPIC -fdiagnostics-color
+CXXFLAGS += -O3 -Wall -shared -std=c++17 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = mask
LIBEXT = $(shell python3-config --extension-suffix)
diff --git a/examples/inference/_utils.py b/examples/inference/_utils.py
deleted file mode 100644
index 67d897836113..000000000000
--- a/examples/inference/_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-def print_perf_stats(latency_set, config, bs, warmup=3):
- # trim warmup queries
- latency_set = list(latency_set)
- latency_set = latency_set[warmup:]
- count = len(latency_set)
-
- if count > 0:
- latency_set.sort()
- avg = sum(latency_set) / count
- num_layers = (
- getattr(config, "num_layers") if hasattr(config, "num_layers") else getattr(config, "num_hidden_layers")
- )
- num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
- num_bytes = 2 # float16
-
- print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
- print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
- print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
- print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
diff --git a/examples/inference/bench_bloom.py b/examples/inference/bench_bloom.py
deleted file mode 100644
index 054641f6eebf..000000000000
--- a/examples/inference/bench_bloom.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from _utils import print_perf_stats
-from transformers import BloomForCausalLM, BloomTokenizerFast
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def bench_bloom(args):
- model_path = args.path
- max_batch_size = args.batch_size
- max_input_len = args.input_len
- max_output_len = args.output_len
-
- tokenizer = BloomTokenizerFast.from_pretrained(model_path)
- tokenizer.pad_token = tokenizer.eos_token
- model = BloomForCausalLM.from_pretrained(model_path, pad_token_id=tokenizer.eos_token_id)
- model = model.half()
-
- # init TPInferEngine and shard the original model
- # To benchmark torch original, comment out the line of optimizing model
- shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
-
- # prepare data for generation
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
- input_tokens = {
- "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
- "attention_mask": torch.ones((max_batch_size, max_input_len)),
- }
- for t in input_tokens:
- if torch.is_tensor(input_tokens[t]):
- input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
- print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
-
- iters = 10
- times = []
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
- times.append((end - start) / (out_len - max_input_len))
-
- print_perf_stats(times, model.config, max_batch_size)
-
-
-def check_bloom(rank, world_size, port, args):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- bench_bloom(args)
-
-
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bloom(args):
- spawn(check_bloom, args.tp_size, args=args)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
- parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
- parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
- parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
- parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
-
- args = parser.parse_args()
-
- test_bloom(args)
diff --git a/examples/inference/bench_chatglm2.py b/examples/inference/bench_chatglm2.py
deleted file mode 100644
index f3678d29ff93..000000000000
--- a/examples/inference/bench_chatglm2.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from _utils import print_perf_stats
-from transformers import AutoTokenizer
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def run_chatglm2_test(args):
- chatglm2_model_path = args.path
- max_batch_size = args.batch_size
- max_input_len = args.input_len
- max_output_len = args.output_len
- args.test_mode
-
- print("max_batch_size : " + str(max_batch_size))
-
- tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
- model = ChatGLMForConditionalGeneration.from_pretrained(chatglm2_model_path, pad_token_id=tokenizer.eos_token_id)
- model = model.half()
- model.config
-
- shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
-
- generate_kwargs = dict(max_new_tokens=1, do_sample=False)
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
- "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
- }
-
- iters = 10
- prefill_times = []
-
- warmup = 3
-
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print("generation time {} s".format(str(end - start)))
- print(out_len - max_input_len)
- prefill_times.append((end - start) / (out_len - max_input_len))
-
- prefill_times = prefill_times[warmup:]
- prefill_time_avg = sum(prefill_times) / len(prefill_times)
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
-
- times = []
- decoder_times = []
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print("generation time {} s".format(str(end - start)))
- print(out_len - max_input_len)
- times.append((end - start) / (out_len - max_input_len))
- if args.test_mode == "decoder_test":
- decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1))
-
- times = times[warmup:]
- latency = sum(times) / len(times)
- print("total process latency is : " + str(latency) + " s")
- print("total throughput is : " + str(1 / latency * max_batch_size))
-
- if args.test_mode == "decoder_test":
- decoder_times = decoder_times[warmup:]
- latency = sum(decoder_times) / len(decoder_times)
-
- print("decoder process latency is : " + str(latency) + " s")
- print("decoder throughput is : " + str(1 / latency * max_batch_size))
-
- print_perf_stats(times, model.config, max_batch_size)
-
-
-def check_chatglm2(rank, world_size, port, args):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_chatglm2_test(args)
-
-
-@rerun_if_address_is_in_use()
-def test_chatglm2(args):
- spawn(check_chatglm2, args.tp_size, args=args)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
- parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
- parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
- parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
- parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
- parser.add_argument(
- "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
- )
-
- args = parser.parse_args()
-
- test_chatglm2(args)
diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py
deleted file mode 100644
index f3e742dfbb59..000000000000
--- a/examples/inference/bench_llama.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from _utils import print_perf_stats
-from transformers import LlamaForCausalLM, LlamaTokenizer
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def run_llama_test(args):
- llama_model_path = args.path
- max_batch_size = args.batch_size
- max_input_len = args.input_len
- max_output_len = args.output_len
- args.test_mode
-
- print("max_batch_size : " + str(max_batch_size))
-
- tokenizer = LlamaTokenizer.from_pretrained(llama_model_path)
- tokenizer.pad_token_id = tokenizer.unk_token_id
- model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.eos_token_id)
- model = model.half()
- model.config
-
- shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
-
- generate_kwargs = dict(max_new_tokens=1, do_sample=False)
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
- "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
- }
-
- iters = 10
- prefill_times = []
-
- warmup = 3
-
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print("generation time {} s".format(str(end - start)))
- print(out_len - max_input_len)
- prefill_times.append((end - start) / (out_len - max_input_len))
-
- prefill_times = prefill_times[warmup:]
- prefill_time_avg = sum(prefill_times) / len(prefill_times)
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
-
- times = []
- decoder_times = []
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print("generation time {} s".format(str(end - start)))
- print(out_len - max_input_len)
- times.append((end - start) / (out_len - max_input_len))
- if args.test_mode == "decoder_test":
- decoder_times.append((end - start - prefill_time_avg) / (out_len - max_input_len - 1))
-
- times = times[warmup:]
- latency = sum(times) / len(times)
- print("total process latency is : " + str(latency) + " s")
- print("total throughput is : " + str(1 / latency * max_batch_size))
-
- if args.test_mode == "decoder_test":
- decoder_times = decoder_times[warmup:]
- latency = sum(decoder_times) / len(decoder_times)
-
- print("decoder process latency is : " + str(latency) + " s")
- print("decoder throughput is : " + str(1 / latency * max_batch_size))
-
- print_perf_stats(times, model.config, max_batch_size)
-
-
-def check_llama(rank, world_size, port, args):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_llama_test(args)
-
-
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama(args):
- spawn(check_llama, args.tp_size, args=args)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
- parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
- parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
- parser.add_argument("--input_len", type=int, default=256, help="Maximum input length")
- parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
- parser.add_argument(
- "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"]
- )
-
- args = parser.parse_args()
-
- test_llama(args)
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
new file mode 100644
index 000000000000..9a26098b3847
--- /dev/null
+++ b/examples/inference/benchmark_llama.py
@@ -0,0 +1,168 @@
+import argparse
+import time
+
+import torch
+import torch.distributed as dist
+import transformers
+
+import colossalai
+import colossalai.utils.device as device_utils
+from colossalai.inference import InferenceEngine
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.utils.device import get_current_device
+
+GIGABYTE = 1024**3
+MEGABYTE = 1024 * 1024
+
+CONFIG_MAP = {
+ "toy": transformers.LlamaConfig(num_hidden_layers=4),
+ "llama-7b": transformers.LlamaConfig(
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_attention_heads=32,
+ num_hidden_layers=32,
+ num_key_value_heads=32,
+ max_position_embeddings=2048,
+ ),
+ "llama-13b": transformers.LlamaConfig(
+ hidden_size=5120,
+ intermediate_size=13824,
+ num_attention_heads=40,
+ num_hidden_layers=40,
+ num_key_value_heads=40,
+ max_position_embeddings=2048,
+ ),
+ "llama2-7b": transformers.LlamaConfig(
+ hidden_size=4096,
+ intermediate_size=11008,
+ num_attention_heads=32,
+ num_hidden_layers=32,
+ num_key_value_heads=32,
+ max_position_embeddings=4096,
+ ),
+ "llama2-13b": transformers.LlamaConfig(
+ hidden_size=5120,
+ intermediate_size=13824,
+ num_attention_heads=40,
+ num_hidden_layers=40,
+ num_key_value_heads=40,
+ max_position_embeddings=4096,
+ ),
+}
+
+
+def data_gen(batch_size: int = 4, seq_len: int = 512):
+ input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
+ attention_mask = torch.ones_like(input_ids)
+ data = dict(input_ids=input_ids, attention_mask=attention_mask)
+ return data
+
+
+def print_details_info(outputs, model_config, args, whole_end2end):
+ msg: str = ""
+
+ if dist.get_rank() == 0:
+ msg += "-------Perf Summary-------\n"
+ if args.verbose:
+ timestamps = outputs[1]
+ prefill = []
+ encoder = []
+ end2end = []
+ for timestamp in timestamps:
+ prefill.append(timestamp[1] - timestamp[0])
+ encoder.append(
+ sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
+ )
+ end2end.append(timestamp[-1] - timestamp[0])
+
+ mb_avg_end2end = sum(end2end) / len(end2end)
+ mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
+
+ msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
+ msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
+ msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
+ msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"
+
+ whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
+ num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
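+        # per-layer parameter estimate: ~12 * hidden_size^2 (4h^2 attention + 8h^2 MLP), split across pipeline stages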
+ num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
+ if args.dtype in ["fp16", "bf16"]:
+ num_bytes = 2
+ else:
+ num_bytes = 4
+
+ msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
+ msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
+ msg += f"Throughput: {args.output_len * args.batch_size / whole_end2end:.2f} tokens/s\n"
+ msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
+
+ if torch.cuda.is_available():
+ msg += f"-------Memory Summary Device:{device_utils.current_device()}-------\n"
+ msg += f"Max memory allocated: {device_utils.max_memory_allocated() / GIGABYTE:.2f} GB\n"
+ msg += f"Max memory reserved: {device_utils.max_memory_reserved() / GIGABYTE:.2f} GB\n"
+
+ print(msg)
+
+
+def benchmark_inference(args):
+ config = CONFIG_MAP[args.model]
+ model = transformers.LlamaForCausalLM(config)
+ if dist.get_rank() == 0:
+ print("Model loaded")
+ engine = InferenceEngine(
+ pp_size=args.pp_size,
+ tp_size=args.tp_size,
+ dtype=args.dtype,
+ micro_batch_size=args.mb_size,
+ model=model,
+ verbose=args.verbose,
+ max_batch_size=args.batch_size,
+ max_input_len=args.seq_len,
+ max_output_len=args.output_len,
+ )
+ data = data_gen(args.batch_size, args.seq_len)
+
+ N_WARMUP_STEPS = 2
+
+ for _ in range(N_WARMUP_STEPS):
+ engine.generate(data)
+
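+    # time a single end-to-end generate call; synchronize so queued CUDA work is included in the measurement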
+ torch.cuda.synchronize()
+ whole_end2end = time.time()
+ outputs = engine.generate(data)
+ torch.cuda.synchronize()
+ whole_end2end = time.time() - whole_end2end
+
+ print_details_info(outputs, model.config, args, whole_end2end)
+
+
+def hybrid_inference(rank, world_size, port, args):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ benchmark_inference(args)
+
+
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def benchmark(args):
+ spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-m",
+ "--model",
+ default="toy",
+ help="the size of model",
+ choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
+ )
+ parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
+ parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+ parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
+ parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
+ parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
+ parser.add_argument("--output_len", type=int, default=128, help="Output length")
+ parser.add_argument("--dtype", type=str, default="fp16", help="data type")
+ parser.add_argument("-v", "--verbose", default=False, action="store_true")
+ args = parser.parse_args()
+ benchmark(args)
diff --git a/examples/inference/smoothquant_llama.py b/examples/inference/build_smoothquant_weight.py
similarity index 72%
rename from examples/inference/smoothquant_llama.py
rename to examples/inference/build_smoothquant_weight.py
index ce7a00aa2739..d60ce1c1d618 100644
--- a/examples/inference/smoothquant_llama.py
+++ b/examples/inference/build_smoothquant_weight.py
@@ -29,7 +29,7 @@ def parse_args():
type=str,
help="location of the calibration dataset",
)
- parser.add_argument("--num-samples", type=int, default=512)
+ parser.add_argument("--num-samples", type=int, default=10)
parser.add_argument("--seq-len", type=int, default=512)
args = parser.parse_args()
return args
@@ -41,13 +41,12 @@ def main():
model_path = args.model_name
dataset_path = args.dataset_path
output_path = args.output_path
- num_samples = 10
- seq_len = 512
+ num_samples = args.num_samples
+ seq_len = args.seq_len
model, tokenizer = build_model_and_tokenizer(model_path)
if not os.path.exists(dataset_path):
- print(f"Cannot find the dataset at {args.dataset_path}")
- raise FileNotFoundError
+ raise FileNotFoundError(f"Cannot find the dataset at {args.dataset_path}")
dataset = load_dataset("json", data_files=dataset_path, split="train")
model.quantized(tokenizer, dataset, num_samples=num_samples, seq_len=seq_len)
@@ -55,15 +54,6 @@ def main():
model.save_quantized(output_path, model_basename="llama-7b")
- model = SmoothLlamaForCausalLM.from_quantized(output_path, model_basename="llama-7b")
- model = model.cuda()
-
- generate_kwargs = dict(max_new_tokens=16, do_sample=False, use_cache=True)
- input_tokens = tokenizer(["today is "], return_tensors="pt").to("cuda")
- out = model.generate(**input_tokens, **generate_kwargs)
- text = tokenizer.batch_decode(out)
- print("out is:", text)
-
if __name__ == "__main__":
main()
diff --git a/examples/inference/gptq_bloom.py b/examples/inference/gptq_bloom.py
deleted file mode 100644
index cfa3171374dd..000000000000
--- a/examples/inference/gptq_bloom.py
+++ /dev/null
@@ -1,102 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from _utils import print_perf_stats
-from auto_gptq import AutoGPTQForCausalLM
-from transformers import BloomTokenizerFast
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def bench_bloom(args):
- pretrained_model_dir = args.path
- quantized_model_dir = args.quantized_path
- max_batch_size = args.batch_size
- max_input_len = args.input_len
- max_output_len = args.output_len
-
- tokenizer = BloomTokenizerFast.from_pretrained(pretrained_model_dir)
- tokenizer.pad_token = tokenizer.eos_token
-
- # load quantized model to the first GPU
- model = AutoGPTQForCausalLM.from_quantized(
- quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
- )
-
- model = model.half()
-
- model_config = model.config
- shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True)
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
- "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
- }
-
- # init TPInferEngine and shard the original model
- # To benchmark torch original, comment out the line of optimizing model
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
- )
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
-
- # prepare data for generation
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
- input_tokens = {
- "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
- "attention_mask": torch.ones((max_batch_size, max_input_len)),
- }
- for t in input_tokens:
- if torch.is_tensor(input_tokens[t]):
- input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
- # print(f" input_tokens[{t}].shape: {input_tokens[t].shape}")
-
- iters = 10
- times = []
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
- times.append((end - start) / (out_len - max_input_len))
-
- print_perf_stats(times, model_config, max_batch_size)
-
-
-def check_bloom(rank, world_size, port, args):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- bench_bloom(args)
-
-
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bloom(args):
- spawn(check_bloom, args.tp_size, args=args)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
- parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
- parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
- parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
- parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
- parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
-
- args = parser.parse_args()
-
- test_bloom(args)
diff --git a/examples/inference/gptq_llama.py b/examples/inference/gptq_llama.py
deleted file mode 100644
index 35a6049ad409..000000000000
--- a/examples/inference/gptq_llama.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from _utils import print_perf_stats
-from auto_gptq import AutoGPTQForCausalLM
-from transformers import LlamaTokenizer
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-
-
-def run_llama_test(args):
- pretrained_model_dir = args.path
- quantized_model_dir = args.quantized_path
- max_batch_size = args.batch_size
- max_input_len = args.input_len
- max_output_len = args.output_len
-
- tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
- tokenizer.pad_token_id = tokenizer.eos_token_id
-
- # load quantized model to the first GPU
- model = AutoGPTQForCausalLM.from_quantized(
- quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
- )
-
- model_config = model.config
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
- )
- infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)
-
- generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
- "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
- }
-
- iters = 10
- times = []
-
- for i in range(iters):
- torch.cuda.synchronize()
- start = time.time()
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
- torch.cuda.synchronize()
- end = time.time()
- out_len = outputs.shape[1]
- print(f" iter {i}: out len {str(out_len)}, generation time {str(end - start)} s")
- times.append((end - start) / (out_len - max_input_len))
-
- print_perf_stats(times, model_config, max_batch_size)
-
-
-def check_llama(rank, world_size, port, args):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_llama_test(args)
-
-
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama(args):
- spawn(check_llama, args.tp_size, args=args)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
- parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
- parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
- parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
- parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
- parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")
-
- args = parser.parse_args()
-
- test_llama(args)
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
new file mode 100755
index 000000000000..394222ea62b8
--- /dev/null
+++ b/examples/inference/run_benchmark.sh
@@ -0,0 +1,15 @@
+ROOT=$(realpath $(dirname $0))
+PY_SCRIPT=${ROOT}/benchmark_llama.py
+GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
+
+mkdir -p logs
+
+# benchmark llama2-7b on a single GPU
+for bsz in 16 32 64; do
+ python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
+done
+
+
+for bsz in 4 8 16 32 64; do
+ python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
+done
diff --git a/examples/inference/run_llama_inference.py b/examples/inference/run_llama_inference.py
new file mode 100644
index 000000000000..8f85a936352b
--- /dev/null
+++ b/examples/inference/run_llama_inference.py
@@ -0,0 +1,98 @@
+import argparse
+
+import torch
+import torch.distributed as dist
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
+import colossalai
+from colossalai.inference import InferenceEngine
+from colossalai.testing import spawn
+from colossalai.utils.device import get_current_device
+
+INPUT_TEXTS = [
+ "What is the longest river in the world?",
+ "Explain the difference between process and thread in compouter science.",
+]
+
+
+def run_inference(args):
+ llama_model_path = args.model_path
+ llama_tokenize_path = args.tokenizer_path or args.model_path
+
+ max_input_len = args.max_input_len
+ max_output_len = args.max_output_len
+ max_batch_size = args.batch_size
+ micro_batch_size = args.micro_batch_size
+ tp_size = args.tp_size
+ pp_size = args.pp_size
+ rank = dist.get_rank()
+
+ tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ if args.quant is None:
+ model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
+ elif args.quant == "gptq":
+ from auto_gptq import AutoGPTQForCausalLM
+
+ model = AutoGPTQForCausalLM.from_quantized(
+ llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device()
+ )
+ elif args.quant == "smoothquant":
+ from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM
+
+ model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name)
+ model = model.cuda()
+
+ engine = InferenceEngine(
+ tp_size=tp_size,
+ pp_size=pp_size,
+ model=model,
+ max_input_len=max_input_len,
+ max_output_len=max_output_len,
+ max_batch_size=max_batch_size,
+ micro_batch_size=micro_batch_size,
+ quant=args.quant,
+ dtype=args.dtype,
+ )
+
+ inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
+ inputs = {k: v.to(get_current_device()) for k, v in inputs.items()}
+ outputs = engine.generate(inputs)
+
+ if rank == 0:
+ output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+ for input_text, output_text in zip(INPUT_TEXTS, output_texts):
+ print(f"Input: {input_text}")
+ print(f"Output: {output_text}")
+
+
+def run_tp_pipeline_inference(rank, world_size, port, args):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_inference(args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
+ parser.add_argument("-i", "--input", default="What is the longest river in the world?")
+ parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
+ parser.add_argument(
+ "-q",
+ "--quant",
+ type=str,
+ choices=["gptq", "smoothquant"],
+ default=None,
+ help="quantization type: 'gptq' or 'smoothquant'",
+ )
+ parser.add_argument("--smoothquant_base_name", type=str, default=None, help="soothquant base name")
+ parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
+ parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
+ parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
+ parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
+ parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
+ parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
+ parser.add_argument("--dtype", default="fp16", type=str)
+
+ args = parser.parse_args()
+ spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)
diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py
index ce13ebbf617d..d7a79a0221ca 100644
--- a/examples/language/llama2/benchmark.py
+++ b/examples/language/llama2/benchmark.py
@@ -13,6 +13,7 @@
from transformers.models.llama.modeling_llama import LlamaForCausalLM
import colossalai
+import colossalai.utils.device as device_utils
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchFSDPPlugin
from colossalai.cluster import DistCoordinator
@@ -130,7 +131,7 @@ def empty_init():
tp_size=args.tp,
pp_size=args.pp,
zero_stage=args.zero,
- enable_fused_normalization=True,
+ enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs,
precision="bf16",
)
@@ -140,7 +141,7 @@ def empty_init():
pp_size=args.pp,
zero_stage=args.zero,
cpu_offload=True,
- enable_fused_normalization=True,
+ enable_fused_normalization=torch.cuda.is_available(),
num_microbatches=args.mbs,
initial_scale=2**8,
precision="bf16",
@@ -183,14 +184,20 @@ def empty_init():
model_numel = get_model_numel(model)
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
performance_evaluator = PerformanceEvaluator(
- model_numel, args.grad_checkpoint, args.ignore_steps, dp_world_size=dp_size
+ model_numel,
+ model.config.num_hidden_layers,
+ model.config.hidden_size,
+ model.config.vocab_size,
+ args.grad_checkpoint,
+ args.ignore_steps,
+ dp_world_size=dp_size,
)
optimizer = HybridAdam(model.parameters())
torch.set_default_dtype(torch.bfloat16)
model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)
torch.set_default_dtype(torch.float)
- coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+ coordinator.print_on_master(f"Booster init max CUDA memory: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
coordinator.print_on_master(
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
)
@@ -216,7 +223,7 @@ def empty_init():
performance_evaluator.on_step_end(**batch)
performance_evaluator.on_fit_end()
- coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+ coordinator.print_on_master(f"Max CUDA memory usage: {device_utils.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py
index 33aa1d33e6ba..f7708b1a38ab 100644
--- a/examples/language/llama2/finetune.py
+++ b/examples/language/llama2/finetune.py
@@ -58,6 +58,7 @@ def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = Non
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
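+    # operate on .data so the in-place division below bypasses autograd tracking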
+ tensor = tensor.data
tensor.div_(dist.get_world_size())
return tensor
diff --git a/examples/language/llama2/performance_evaluator.py b/examples/language/llama2/performance_evaluator.py
index a57c1e0e9ae3..6b1c92711d48 100644
--- a/examples/language/llama2/performance_evaluator.py
+++ b/examples/language/llama2/performance_evaluator.py
@@ -5,7 +5,9 @@
import torch.distributed as dist
from torch import Tensor
+import colossalai.utils.device as device_utils
from colossalai.cluster import DistCoordinator
+from colossalai.utils.device import get_current_device
def divide(x: float, y: float) -> float:
@@ -20,7 +22,7 @@ def divide(x: float, y: float) -> float:
def all_reduce_mean(x: float, world_size: int) -> float:
if world_size == 1:
return x
- tensor = torch.tensor([x], device=torch.cuda.current_device())
+ tensor = torch.tensor([x], device=get_current_device())
dist.all_reduce(tensor)
tensor = tensor / world_size
return tensor.item()
@@ -58,6 +60,9 @@ class PerformanceEvaluator:
def __init__(
self,
model_numel: int,
+ num_layers: int,
+ hidden_size: int,
+ vocab_size: int,
enable_grad_checkpoint: bool = False,
ignore_steps: int = 0,
dp_world_size: Optional[int] = None,
@@ -65,41 +70,52 @@ def __init__(
self.model_numel = model_numel
self.enable_grad_checkpoint = enable_grad_checkpoint
self.ignore_steps = ignore_steps
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size
+ self.vocab_size = vocab_size
self.coordinator = DistCoordinator()
self.dp_world_size = dp_world_size or self.coordinator.world_size
self.disable: bool = False
self.timer = Timer()
self.num_samples: int = 0
+ self.flop_megatron = 0
self.flop: int = 0
def on_step_start(self, step: int) -> None:
self.disable = self.ignore_steps > 0 and step < self.ignore_steps
if self.disable:
return
- torch.cuda.synchronize()
+ device_utils.synchronize()
self.timer.start()
def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
if self.disable:
return
- torch.cuda.synchronize()
+ device_utils.synchronize()
self.timer.end()
batch_size, seq_len = input_ids.shape
self.num_samples += batch_size
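+        # Megatron-LM style estimate: 24 * factor * B * s * L * h^2 * (1 + s/(6h) + V/(16*L*h));
+        # factor is 3 for forward+backward, plus 1 when gradient checkpointing recomputes the forward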
+ checkpoint_activations_factor = 3 + int(self.enable_grad_checkpoint)
+ self.flop_megatron += (
+ 24 * checkpoint_activations_factor * batch_size * seq_len * self.num_layers * (self.hidden_size**2)
+ ) * (
+ 1.0 + (seq_len / (6.0 * self.hidden_size)) + (self.vocab_size / (16.0 * self.num_layers * self.hidden_size))
+ )
self.flop += batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint))
def on_fit_end(self) -> None:
avg_duration = all_reduce_mean(self.timer.duration, self.coordinator.world_size)
avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
mp_world_size = self.coordinator.world_size // self.dp_world_size
+ avg_tflops_per_gpu_megatron = self.flop_megatron / 1e12 / (avg_duration + 1e-12) / mp_world_size
avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
self.coordinator.print_on_master(
- f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+ f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop_megatron: {self.flop_megatron}, flop: {self.flop}, avg_duration: {avg_duration}, "
f"avg_throughput: {avg_throughput}"
)
self.coordinator.print_on_master(
- f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
+ f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU by Megatron: {avg_tflops_per_gpu_megatron:.2f}, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}"
)
diff --git a/examples/language/llama2/pretrain.py b/examples/language/llama2/pretrain.py
index 6cc73b6265a4..bb10f7a00e8a 100644
--- a/examples/language/llama2/pretrain.py
+++ b/examples/language/llama2/pretrain.py
@@ -76,6 +76,7 @@ def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = Non
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
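+    # operate on .data so the in-place division below bypasses autograd tracking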
+ tensor = tensor.data
tensor.div_(dist.get_world_size())
return tensor
diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md
new file mode 100644
index 000000000000..a0821a5330a4
--- /dev/null
+++ b/examples/language/openmoe/README.md
@@ -0,0 +1,129 @@
+## OpenMoE
+[OpenMoE](https://github.com/XueFuzhao/OpenMoE) is the open-source community's first decoder-only MoE transformer. OpenMoE is implemented in Jax, and [Colossal-AI](https://github.com/hpcaitech/ColossalAI) has pioneered efficient open-source support for this model in PyTorch, enabling a broader range of users to participate in and use it. The following Colossal-AI example demonstrates finetuning and inference methods.
+
+## Usage
+
+### 1. Installation
+
+Please install the latest ColossalAI from source.
+
+```bash
+CUDA_EXT=1 pip install -U git+https://github.com/hpcaitech/ColossalAI
+```
+
+Then install dependencies.
+
+```bash
+cd ColossalAI/examples/language/openmoe
+pip install -r requirements.txt
+```
+
+Additionally, we recommend using torch 1.13.1. We've tested our code on torch 1.13.1 and found it compatible with both our code and flash attention.
+
+### 2. Install kernels (Optional)
+
+We use the `Triton`, `FlashAttention` and `Apex` kernels for better performance. They are not required, but we recommend installing them to fully utilize your hardware.
+```bash
+# install triton via pip
+pip install triton
+
+# install flash attention via pip
+pip install flash-attn==2.0.5
+
+# install apex from source
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+git checkout 741bdf50825a97664db08574981962d66436d16a
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext"
+```
+
+### 3. Train
+You can use colossalai run to launch single-node training:
+```bash
+colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS
+```
+You can also use colossalai run to launch multi-node training:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE train.py --OTHER_CONFIGURATIONS
+```
+
+Here is a sample hostfile:
+
+```text
+hostname1
+hostname2
+hostname3
+hostname4
+```
+
+Each hostname refers to the IP address of a node. Make sure the master node can reach all nodes (including itself) via passwordless SSH.
+
+Here are the details of the CLI arguments:
+
+- Model configuration: `--model_name`. `base` and `8b` are supported for OpenMoE.
+- Booster plugin: `--plugin`. `ep`, `ep_zero` and `hybrid` are supported. `ep_zero` is recommended for general cases, `ep` provides the lowest memory consumption, and `hybrid` suits large-scale training.
+- Output path: `--output_path`. The path to save your model. The default value is `./outputs`.
+- Number of epochs: `--num_epochs`. The default value is 1.
+- Local batch size: `--batch_size`. Batch size per GPU. The default value is 1.
+- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
+- Mixed precision: `--precision`. The default value is "bf16". "fp16", "bf16" and "fp32" are supported.
+- Max length: `--max_length`. Max sequence length. Defaults to 2048.
+- Dataset: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. Any dataset from `datasets` with the same data format is supported.
+- Task Name: `--task_name`. Task of the corresponding dataset. Defaults to `super_natural_instructions`.
+- Learning rate: `--lr`. The default value is 1e-5.
+- Weight decay: `--weight_decay`. The default value is 0.
+- Zero stage: `--zero_stage`. Zero stage. 2 is recommended for ep and 1 for ep_zero.
+- Extra dp size: `--extra_dp_size`. Extra MoE-parameter dp size for the ep_zero plugin. Recommended to be 2 or 4.
+- Use kernel: `--use_kernel`. Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.
+- Use layernorm kernel: `--use_layernorm_kernel`. Use layernorm kernel. Need to install apex. Raise error if not installed.
+- Router aux loss factor: `--router_aux_loss_factor`. MoE router aux loss factor. You can refer to ST-MoE for details.
+- Router z loss factor: `--router_z_loss_factor`. MoE router z loss factor. You can refer to ST-MoE for details.
+- Label smoothing: `--label_smoothing`. Label smoothing.
+- Z loss factor: `--z_loss_factor`. The final outputs' classification z loss factor.
+- Load balance: `--load_balance`. Expert load balance. Defaults to False. Enabling it is recommended.
+- Load balance interval: `--load_balance_interval`. Expert load balance interval.
+- Communication overlap: `--comm_overlap`. Use communication overlap for MoE. Recommended to enable for multi-node training.
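+
+For example, a typical ep_zero run (an illustration only, combining the recommendations above) could pass `--plugin ep_zero --zero_stage 1 --extra_dp_size 2`.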
+
+### 4. Shell Script Examples
+
+For your convenience, we provide some shell scripts to train with various configurations. Here we show an example of how to run OpenMoE training.
+
+#### a. Running environment
+This experiment was performed on a single compute node with 8 A800 80GB GPUs in total for OpenMoE-8B. The GPUs are fully connected with NVLink.
+
+#### b. Running command
+We demonstrate how to run three plugins in `train.sh`. You can choose any one of them and use your own args.
+
+```bash
+bash train.sh
+```
+
+#### c. Multi-Nodes Training
+
+To run on multi-nodes, you can modify the script as:
+```bash
+colossalai run --nproc_per_node YOUR_GPU_PER_NODE --hostfile YOUR_HOST_FILE \
+train.py --OTHER_CONFIGURATIONS
+```
+
+## Reference
+```bibtex
+@article{bian2021colossal,
+ title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
+ author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
+ journal={arXiv preprint arXiv:2110.14883},
+ year={2021}
+}
+```
+
+```bibtex
+@misc{openmoe2023,
+ author = {Fuzhao Xue, Zian Zheng, Yao Fu, Jinjie Ni, Zangwei Zheng, Wangchunshu Zhou and Yang You},
+ title = {OpenMoE: Open Mixture-of-Experts Language Models},
+ year = {2023},
+ publisher = {GitHub},
+ journal = {GitHub repository},
+ howpublished = {\url{https://github.com/XueFuzhao/OpenMoE}},
+}
+```
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.py b/examples/language/openmoe/benchmark/benchmark_cai.py
new file mode 100644
index 000000000000..65562b386cf9
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_cai.py
@@ -0,0 +1,297 @@
+import argparse
+import json
+import os
+
+import torch
+import torch.distributed as dist
+from huggingface_hub import snapshot_download
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
+from model.openmoe_policy import OpenMoeForCausalLMPolicy
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import T5Tokenizer
+from transformers.models.llama import LlamaConfig
+from utils import PerformanceEvaluator, get_model_numel
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.moe.layers import apply_load_balance
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import skip_init
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+def move_to_cuda(batch, device):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
+ ckpt_path = snapshot_download(repo_name)
+ # single ckpt
+ if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
+ ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
+ # shard ckpt
+ elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
+ ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+ else:
+ raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
+ booster.load_model(model, ckpt_path)
+
+
+class RandomDataset(Dataset):
+ def __init__(
+ self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 256384, tokenizer: T5Tokenizer = None
+ ):
+ self.num_samples = num_samples
+ self.max_length = max_length
+ if os.path.exists("./mock_data.json"):
+ self.input_ids = []
+ self.attention_mask = []
+ with open("./mock_data.json", "r") as f:
+ data = json.load(f)
+ for v in data.values():
+ d = v["text"]
+ encode = tokenizer(
+ "" + d,
+ return_tensors="pt",
+ add_special_tokens=False,
+ max_length=max_length,
+ truncation=True,
+ padding="max_length",
+ )
+ self.input_ids.append(encode["input_ids"])
+ self.attention_mask.append(encode["attention_mask"])
+ self.input_ids = torch.cat(self.input_ids, dim=0).to(get_current_device())
+ self.attention_mask = torch.cat(self.attention_mask, dim=0).to(get_current_device())
+ repeat_times = num_samples // self.input_ids.shape[0] + 1
+ self.input_ids = self.input_ids.repeat(repeat_times, 1)[:num_samples]
+ self.attention_mask = self.attention_mask.repeat(repeat_times, 1)[:num_samples]
+ else:
+ self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+ self.attention_mask = torch.ones_like(self.input_ids)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.input_ids[idx],
+ "attention_mask": self.attention_mask[idx],
+ "labels": self.input_ids[idx],
+ }
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="base",
+ choices=["base", "8b"],
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=4,
+ help="Batch size (per dp group) for the training dataloader.",
+ )
+ parser.add_argument(
+ "--seq_length",
+ type=int,
+ default=2048,
+ help="sequence length for the training dataloader.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="hybrid",
+ help="parallel plugin",
+ )
+ # hybrid plugin
+ parser.add_argument("--pp_size", type=int, default=2, help="pp size")
+ parser.add_argument("--dp_size", type=int, default=1, help="dp size")
+ parser.add_argument("--ep_size", type=int, default=2, help="ep size")
+ parser.add_argument("--zero_stage", type=int, default=2, help="zero stage in hybrid plugin")
+ parser.add_argument("--microbatch_size", type=int, default=1, help="microbatch size")
+ parser.add_argument("--extra_dp_size", type=int, default=1)
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention, apex, triton to enable all kernel optimizations.",
+ )
+ # bench
+ parser.add_argument("--warmup", type=int, default=20)
+ parser.add_argument("--active", type=int, default=20)
+ # load balance
+ parser.add_argument("--load_balance", action="store_true")
+
+ # overlap communication
+ parser.add_argument("--overlap_comm", action="store_true")
+ # hierarchical all-to-all
+ parser.add_argument("--hierarchical_alltoall", action="store_true")
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+
+ # Set plugin
+ booster_kwargs = {}
+ hybrid_dict = {
+ "tp_size": 1,
+ "custom_policy": OpenMoeForCausalLMPolicy(),
+ "enable_fused_normalization": args.use_kernel,
+ "enable_jit_fused": args.use_kernel,
+ "precision": "bf16",
+ "zero_stage": args.zero_stage,
+ }
+ mgr_dict = {}
+ if args.plugin == "ep":
+ dp_size = dist.get_world_size()
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ max_ep_size=dp_size,
+ **mgr_dict,
+ )
+ elif args.plugin == "ep_zero":
+ dp_size = dist.get_world_size()
+ use_ep_inside = False
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ extra_dp_size=args.extra_dp_size,
+ use_ep_inside=use_ep_inside,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ max_ep_size=dp_size // args.extra_dp_size,
+ use_ep_inside=use_ep_inside,
+ **mgr_dict,
+ )
+ elif args.plugin == "hybrid":
+ dp_size = dist.get_world_size() // args.pp_size
+ plugin = MoeHybridParallelPlugin(
+ pp_size=args.pp_size,
+ zero_stage=args.zero_stage,
+ microbatch_size=args.microbatch_size,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ mode="fixed",
+ fixed_dp_size=args.dp_size,
+ fixed_ep_size=args.ep_size,
+ fixed_pp_size=args.pp_size,
+ **mgr_dict,
+ )
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin}")
+
+ # Build OpenMoe model
+ repo_name = "hpcaitech/openmoe-" + args.model_name
+ config = LlamaConfig.from_pretrained(repo_name)
+ set_openmoe_args(
+ config,
+ num_experts=config.num_experts,
+ moe_layer_interval=config.moe_layer_interval,
+ enable_load_balance=args.load_balance,
+ enable_kernel=args.use_kernel,
+ enable_comm_overlap=args.overlap_comm,
+ enable_hierarchical_alltoall=args.hierarchical_alltoall,
+ )
+ with skip_init():
+ model = OpenMoeForCausalLM(config)
+    coordinator.print_on_master(f"Finished initializing model with config:\n{config}")
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Prepare tokenizer and dataloader
+ tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
+ dataset = RandomDataset(
+ num_samples=args.batch_size * (args.warmup + args.active + 1) * dp_size,
+ max_length=args.seq_length,
+ tokenizer=tokenizer,
+ )
+ dataloader = plugin.prepare_dataloader(dataset, batch_size=args.batch_size)
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(), weight_decay=0.01, lr=1e-5)
+
+ model_numel = get_model_numel(model)
+ performance_evaluator = PerformanceEvaluator(
+ model_numel,
+ enable_grad_checkpoint=True,
+ ignore_steps=args.warmup,
+ dp_world_size=dp_size,
+ )
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ load_ckpt(repo_name, model, booster)
+ model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+ use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    coordinator.print_on_master("Finished initializing booster")
+
+ # Start finetuning
+    coordinator.print_on_master("Start training")
+ model.train()
+ train_dataloader_iter = iter(dataloader)
+ total_len = len(train_dataloader_iter) - 1
+    example_data = next(train_dataloader_iter)
+ with tqdm(range(total_len), disable=not coordinator.is_master()) as pbar:
+ for step in pbar:
+ performance_evaluator.on_step_start(step)
+ if use_pipeline:
+ # Forward pass
+ outputs = booster.execute_pipeline(
+ train_dataloader_iter,
+ model,
+ lambda x, y: x.loss,
+ optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs["loss"]
+ pbar.set_postfix({"loss": loss.item()})
+ else:
+ # Forward pass
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data, torch.cuda.current_device())
+ outputs = model(**data)
+ loss = outputs["loss"]
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({"loss": loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+            performance_evaluator.on_step_end(example_data["input_ids"])
+ if (step == args.warmup // 2) and args.load_balance:
+                coordinator.print_on_master("Applying load balance")
+ apply_load_balance(model, optimizer)
+ performance_evaluator.on_fit_end()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/language/openmoe/benchmark/benchmark_cai.sh b/examples/language/openmoe/benchmark/benchmark_cai.sh
new file mode 100755
index 000000000000..f269e260d8db
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_cai.sh
@@ -0,0 +1,78 @@
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=8
+MODEL="8b"
+SEQ_LENGTH=2048
+WARMUP=20
+ACTIVE=4
+
+# HACK: make model importable
+example_dir=$(dirname $(realpath $(dirname $0)))
+if [ -z ${PYTHONPATH+x} ]; then
+ export PYTHONPATH=$example_dir
+else
+ export PYTHONPATH=$example_dir:$PYTHONPATH
+fi
+
+
+# ep
+echo -e "\n\n Naive EP \n\n"
+torchrun --standalone --nproc_per_node $NUM_GPU \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 8 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --plugin ep \
+ --zero_stage 2
+
+
+# ep_zero
+echo -e "\n\n EP-ZERO \n\n"
+torchrun --standalone --nproc_per_node $NUM_GPU \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 16 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --plugin ep_zero \
+ --use_kernel \
+ --extra_dp_size 2 \
+ --zero_stage 1 \
+ --load_balance
+
+echo -e "\n\n EP-ZERO + Overlap \n\n"
+torchrun --standalone --nproc_per_node $NUM_GPU \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 16 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --plugin ep_zero \
+ --use_kernel \
+ --extra_dp_size 2 \
+ --zero_stage 1 \
+ --load_balance \
+   --overlap_comm
+
+
+# hybrid
+torchrun --standalone --nproc_per_node $NUM_GPU \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 128 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --use_kernel \
+ --plugin hybrid \
+ --pp_size 2 \
+ --dp_size 1 \
+ --ep_size 4 \
+ --zero_stage 1 \
+ --microbatch_size 32
diff --git a/examples/language/openmoe/benchmark/benchmark_cai_dist.sh b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh
new file mode 100755
index 000000000000..06d57e4f0574
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_cai_dist.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=8
+MODEL="8b"
+SEQ_LENGTH=2048
+WARMUP=20
+ACTIVE=4
+
+# HACK: make model importable
+example_dir=$(dirname $(realpath $(dirname $0)))
+if [ -z ${PYTHONPATH+x} ]; then
+ export PYTHONPATH=$example_dir
+else
+ export PYTHONPATH=$example_dir:$PYTHONPATH
+fi
+
+
+# ep
+echo -e "\n\n Naive EP \n\n"
+colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 12 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --plugin ep \
+ --zero_stage 2
+
+
+# ep_zero
+echo -e "\n\n EP-ZERO \n\n"
+colossalai run --nproc_per_node $NUM_GPU --hostfile "hostfile.txt" \
+ $example_dir/benchmark/benchmark_cai.py \
+ --model_name $MODEL \
+ --batch_size 20 \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE \
+ --plugin ep_zero \
+ --use_kernel \
+ --extra_dp_size 2 \
+ --zero_stage 1 \
+ --load_balance \
+   --overlap_comm
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.py b/examples/language/openmoe/benchmark/benchmark_fsdp.py
new file mode 100644
index 000000000000..7f438fc5acce
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.py
@@ -0,0 +1,139 @@
+import argparse
+import functools
+import os
+
+import torch
+import torch.distributed as dist
+import tqdm
+from model.modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, set_openmoe_args
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+from torch.utils.data import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers.models.llama import LlamaConfig
+from utils import PerformanceEvaluator, get_model_numel
+
+from colossalai.moe.manager import MOE_MANAGER
+
+
+class RandomDataset(Dataset):
+ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000):
+ self.num_samples = num_samples
+ self.max_length = max_length
+ self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length))
+ self.attention_mask = torch.ones_like(self.input_ids)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.input_ids[idx],
+ "attention_mask": self.attention_mask[idx],
+ "labels": self.input_ids[idx],
+ }
+
+
+def fsdp_main(rank, world_size, args):
+    # initialize the process group
+ dist.init_process_group("nccl")
+
+ MOE_MANAGER.setup(parallel=None)
+
+ dp_size = dist.get_world_size()
+ dataset = RandomDataset(
+ max_length=args.seq_length,
+ num_samples=args.batch_size * (args.warmup + args.active) * dp_size,
+ )
+ sampler = DistributedSampler(dataset, rank=rank, num_replicas=world_size, shuffle=False)
+ train_kwargs = {"batch_size": args.batch_size, "sampler": sampler}
+ train_loader = torch.utils.data.DataLoader(dataset, **train_kwargs)
+ torch.cuda.set_device(rank)
+
+ config = LlamaConfig.from_pretrained("hpcaitech/openmoe-%s" % args.model_name)
+ set_openmoe_args(
+ config,
+ num_experts=config.num_experts,
+ moe_layer_interval=config.moe_layer_interval,
+ enable_load_balance=False,
+ enable_kernel=False,
+ enable_comm_overlap=False,
+ )
+ torch.set_default_dtype(torch.float16)
+ model = OpenMoeForCausalLM(config)
+ torch.set_default_dtype(torch.float32)
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ transformer_layer_cls={
+ OpenMoeDecoderLayer,
+ },
+ )
+ model = FSDP(
+ model,
+ mixed_precision=MixedPrecision(
+ param_dtype=torch.bfloat16,
+ reduce_dtype=torch.bfloat16,
+ buffer_dtype=torch.bfloat16,
+ ),
+ auto_wrap_policy=auto_wrap_policy,
+ device_id=torch.cuda.current_device(),
+ )
+ optimizer = torch.optim.Adam(model.parameters(), weight_decay=0.01, lr=1e-5)
+ model.train()
+
+ model_numel = get_model_numel(model)
+ performance_evaluator = PerformanceEvaluator(
+ model_numel,
+ enable_grad_checkpoint=True,
+ ignore_steps=args.warmup,
+ dp_world_size=dist.get_world_size(),
+ )
+
+ for step, data in tqdm.tqdm(enumerate(train_loader), total=len(train_loader)):
+ performance_evaluator.on_step_start(step)
+ input_ids, attention_mask, labels = (
+ data["input_ids"].cuda(),
+ data["attention_mask"].cuda(),
+ data["labels"].cuda(),
+ )
+
+ optimizer.zero_grad()
+ output = model(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=attention_mask,
+ chunk_head=False,
+ )
+ loss = output["loss"]
+ loss.backward()
+ optimizer.step()
+ performance_evaluator.on_step_end(input_ids)
+
+ performance_evaluator.on_fit_end()
+ if dist.get_rank() == 0:
+ print(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="base",
+ choices=["base", "8b"],
+ help="base or 8b",
+ )
+ parser.add_argument("--batch_size", type=int, default=1)
+ parser.add_argument("--seq_length", type=int, default=2048)
+ parser.add_argument("--warmup", type=int, default=20)
+ parser.add_argument("--active", type=int, default=20)
+ args = parser.parse_args()
+
+ torch.manual_seed(42)
+
+ world_size = int(os.environ["WORLD_SIZE"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ fsdp_main(local_rank, world_size, args)
diff --git a/examples/language/openmoe/benchmark/benchmark_fsdp.sh b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
new file mode 100755
index 000000000000..c6f5624dd746
--- /dev/null
+++ b/examples/language/openmoe/benchmark/benchmark_fsdp.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+set -xue
+
+MODEL="8b"
+BATCH_SIZE=1
+SEQ_LENGTH=2048
+WARMUP=8
+ACTIVE=4
+
+# HACK: make model importable
+example_dir=$(dirname $(realpath $(dirname $0)))
+if [ -z ${PYTHONPATH+x} ]; then
+ export PYTHONPATH=$example_dir
+else
+ export PYTHONPATH=$example_dir:$PYTHONPATH
+fi
+
+# single node
+torchrun --standalone --nproc_per_node 8 $example_dir/benchmark/benchmark_fsdp.py \
+ --model_name $MODEL \
+ --batch_size $BATCH_SIZE \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE
+
+# multi node
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=node_rank --master_addr=master_addr --master_port=master_port \
+ $example_dir/benchmark/benchmark_fsdp.py \
+ --model_name $MODEL \
+ --batch_size $BATCH_SIZE \
+ --seq_length $SEQ_LENGTH \
+ --warmup $WARMUP \
+ --active $ACTIVE
diff --git a/examples/language/openmoe/benchmark/hostfile.txt b/examples/language/openmoe/benchmark/hostfile.txt
new file mode 100644
index 000000000000..994b3e2cfc4f
--- /dev/null
+++ b/examples/language/openmoe/benchmark/hostfile.txt
@@ -0,0 +1,2 @@
+host1
+host2
diff --git a/examples/language/openmoe/benchmark/utils.py b/examples/language/openmoe/benchmark/utils.py
new file mode 100644
index 000000000000..7a0955bb028a
--- /dev/null
+++ b/examples/language/openmoe/benchmark/utils.py
@@ -0,0 +1,126 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch import Tensor
+
+from colossalai.logging import DistributedLogger
+
+
+def print_model_numel(logger: DistributedLogger, model: nn.Module) -> None:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ outputs = "Model param count: "
+ model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ if model_param >= B:
+ outputs += f"{model_param / B:.2f} B\n"
+ elif model_param >= M:
+ outputs += f"{model_param / M:.2f} M\n"
+ elif model_param >= K:
+ outputs += f"{model_param / K:.2f} K\n"
+ else:
+ outputs += f"{model_param}\n"
+ logger.info(outputs, ranks=[0])
+
+
+def get_model_numel(model: nn.Module) -> int:
+ model_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return model_param
+
+
+def divide(x: float, y: float) -> float:
+ if y == 0:
+ return float("inf")
+ elif y == float("inf"):
+ return float("nan")
+ return x / y
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+ if world_size == 1:
+ return x
+ tensor = torch.tensor([x], device=torch.cuda.current_device())
+ dist.all_reduce(tensor)
+ tensor = tensor / world_size
+ return tensor.item()
+
+
+class Timer:
+
+ def __init__(self) -> None:
+ self.start_time: Optional[float] = None
+ self.duration: float = 0.0
+
+ def start(self) -> None:
+ self.start_time = time()
+
+ def end(self) -> None:
+ assert self.start_time is not None
+ self.duration += time() - self.start_time
+ self.start_time = None
+
+ def reset(self) -> None:
+ self.duration = 0.0
+
+
+class PerformanceEvaluator:
+    """
+    Callback for evaluating the performance of the model.
+    Args:
+        model_numel: The number of trainable parameters of the model.
+        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
+        ignore_steps: The number of warmup steps to ignore when measuring performance.
+        dp_world_size: The data-parallel world size used to scale throughput.
+    """
+
+ def __init__(
+ self,
+ model_numel: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_steps: int = 0,
+ dp_world_size: Optional[int] = None,
+ ) -> None:
+ self.model_numel = model_numel
+ self.enable_grad_checkpoint = enable_grad_checkpoint
+ self.ignore_steps = ignore_steps
+ self.dp_world_size = dp_world_size
+ self.world_size = dist.get_world_size()
+ self.disable: bool = False
+ self.timer = Timer()
+ self.num_samples: int = 0
+ self.flop: int = 0
+
+ def on_step_start(self, step: int) -> None:
+ self.disable = self.ignore_steps > 0 and step < self.ignore_steps
+ if self.disable:
+ return
+ torch.cuda.synchronize()
+ self.timer.start()
+
+ def on_step_end(self, input_ids: Tensor, **kwargs) -> None:
+ if self.disable:
+ return
+ torch.cuda.synchronize()
+ self.timer.end()
+
+ batch_size, seq_len = input_ids.shape
+
+ self.num_samples += batch_size
+ self.flop += (batch_size * seq_len * self.model_numel * 2 * (3 + int(self.enable_grad_checkpoint)))
+
+ def on_fit_end(self) -> None:
+ avg_duration = all_reduce_mean(self.timer.duration, self.world_size)
+ avg_throughput = self.num_samples * self.dp_world_size / (avg_duration + 1e-12)
+ mp_world_size = self.world_size // self.dp_world_size
+ avg_tflops_per_gpu = self.flop / 1e12 / (avg_duration + 1e-12) / mp_world_size
+ if dist.get_rank() == 0:
+ print(
+ f"num_samples: {self.num_samples}, dp_world_size: {self.dp_world_size}, flop: {self.flop}, avg_duration: {avg_duration}, "
+ f"avg_throughput: {avg_throughput}")
+ print(f"Throughput: {avg_throughput:.2f} samples/sec, TFLOPS per GPU: {avg_tflops_per_gpu:.2f}")
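+
+
+def estimate_step_flops(batch_size: int, seq_len: int, model_numel: int, grad_ckpt: bool) -> int:
+    """Illustrative sketch of the FLOPs model used by PerformanceEvaluator above.
+
+    Added for exposition only. It assumes roughly 2 * model_numel FLOPs per token
+    for a forward pass and twice that for backward; gradient checkpointing
+    recomputes activations, adding one extra forward, hence the (3 + grad_ckpt)
+    factor. This is a rough estimate, not an exact count.
+    """
+    # e.g. batch 8, seq 2048, an assumed 8e9-param model, checkpointing on:
+    # 8 * 2048 * 8e9 * 2 * 4 ~= 1.05e15 FLOPs per step.
+    return batch_size * seq_len * model_numel * 2 * (3 + int(grad_ckpt))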
diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py
new file mode 100644
index 000000000000..db90c6e34507
--- /dev/null
+++ b/examples/language/openmoe/infer.py
@@ -0,0 +1,57 @@
+from argparse import ArgumentParser
+
+import torch
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
+from transformers import T5Tokenizer
+from transformers.models.llama import LlamaConfig
+
+
+def parse_args():
+ parser = ArgumentParser()
+    parser.add_argument("--model", default="base", type=str, help="OpenMoE model to load: base, 8b, or test (randomly initialized base config)", choices=["base", "8b", "test"])
+ return parser.parse_args()
+
+
+def inference(args):
+ tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
+ if args.model == "test":
+ config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+ set_openmoe_args(config,
+ num_experts=config.num_experts,
+ moe_layer_interval=config.moe_layer_interval,
+ enable_kernel=True)
+ model = OpenMoeForCausalLM(config)
+ else:
+ config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}")
+ set_openmoe_args(config,
+ num_experts=config.num_experts,
+ moe_layer_interval=config.moe_layer_interval,
+ enable_kernel=False)
+ model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config)
+ model = model.eval().bfloat16()
+ model = model.to(torch.cuda.current_device())
+
+ input_str = """```
+y = list(map(int, ['1', 'hello', '2']))
+```
+What error does this program produce?
+ValueError: invalid literal for int() with base 10: 'hello'
+
+```
+sum = 0
+for i in range(100):
+ sum += i
+```
+What is the value of sum immediately after the 10th time line 3 is executed?"""
+
+ # print("model config: ", model.config)
+    input_ids = tokenizer("<pad>" + input_str, return_tensors="pt", add_special_tokens=False)
+ input_ids = input_ids.input_ids.to(torch.cuda.current_device())
+ generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64)
+ out = tokenizer.decode(generation_output[0], skip_special_tokens=False)
+ print(f"output: \n{out}\n")
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ inference(args)
diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh
new file mode 100644
index 000000000000..a578203eba84
--- /dev/null
+++ b/examples/language/openmoe/infer.sh
@@ -0,0 +1 @@
+python infer.py --model "base"
diff --git a/examples/language/openmoe/model/__init__.py b/examples/language/openmoe/model/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py
new file mode 100644
index 000000000000..20b1e780d8b3
--- /dev/null
+++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py
@@ -0,0 +1,224 @@
+# coding=utf-8
+# Copyright 2022 Google LLC and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Convert T5X checkpoint to PyTorch
+
+Steps:
+- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
+- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
+ `gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
+- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
+ https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
+- Convert:
+ ```
+  python3 convert_openmoe_ckpt.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json \
+ --pytorch_dump_path=$HOME/t5_1_1_small_pt
+ ```
+"""
+
+import argparse
+import collections
+
+import torch
+from flax import traverse_util
+from modeling_openmoe import OpenMoeForCausalLM
+from t5x import checkpoints
+from transformers import LlamaConfig
+from transformers.utils import logging
+
+logging.set_verbosity_info()
+
+
+def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
+ """Returns the KOQV parameters of (self-)attention. Does not transpose."""
+ k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
+ o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
+ q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
+ v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
+ return k, o, q, v
+
+
+def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
+ """Returns the MLP parameters of a layer. Does not transpose."""
+ if split_mlp_wi:
+ wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
+ wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
+ wi = (wi_0, wi_1)
+ else:
+ wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
+
+ wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
+ return wi, wo
+
+
+def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the extra MLP parameters of an MoE layer. Does not transpose."""
+ if split_mlp_wi:
+ wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
+ wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
+ wi = (wi_0, wi_1)
+ else:
+ wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
+
+ wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
+ return wi, wo
+
+
+def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the expert MLP parameters of an MoE layer. Does not transpose."""
+ if split_mlp_wi:
+ wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
+ wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
+ wi = (wi_0, wi_1)
+ else:
+ wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
+
+ wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
+ return wi, wo
+
+
+def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
+    """Returns the router gate weights of an MoE layer. Does not transpose."""
+ return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
+
+
+def t5x_layer_norm_lookup(params, i, prefix, layer_name):
+ """Returns the layer norm param of a layer."""
+ return params[f"{prefix}/layers_{i}/{layer_name}/scale"]
+
+
+def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
+ """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
+ old = traverse_util.flatten_dict(variables["target"])
+ old = {"/".join(k): v for k, v in old.items()}
+
+ # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
+ split_mlp_wi = True
+ print("Split MLP:", split_mlp_wi)
+
+ new = collections.OrderedDict()
+ print(old.keys())
+ for key, value in old.items():
+ print(f"{key}: {value.shape}")
+
+ # Shared embeddings.
+ new["model.embed_tokens.weight"] = old["token_embedder/embedding"]
+
+ # Decoder.
+ for i in range(num_layers):
+ # Block i, layer 0 (Self Attention).
+ layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
+ k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
+ new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
+ new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
+ new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
+ new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
+ new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T
+
+        # Block i, layer 1 (MLP).
+ layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
+ new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
+
+ if (i + 1) % moe_interval == 0:
+ # moe
+ gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
+ new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
+ wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
+ new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
+ new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
+ new[f"model.layers.{i}.mlp.experts.wo"] = wo
+ # extra
+ layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
+ new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
+ wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
+ new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
+ new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
+ new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
+ else:
+ wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
+ new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
+ new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
+ new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T
+
+ new["model.norm.weight"] = old["decoder/decoder_norm/scale"]
+
+ # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead)
+ if "decoder/logits_dense/kernel" in old:
+ new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T
+
+ return new
+
+
+def make_state_dict(converted_params):
+ """Prepares a state dict for the PyTorch model."""
+ # Make a state dict with torch tensors.
+ state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
+
+ return state_dict
+
+
+def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
+    """Replaces the params in model with the T5X converted params."""
+ variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
+ converted = convert_t5x_to_pytorch(variables,
+ num_layers=config.num_hidden_layers,
+ moe_interval=config.moe_layer_interval)
+ state_dict = make_state_dict(converted)
+ model.load_state_dict(state_dict, strict=True)
+
+
+def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
+ """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
+ # Initialise PyTorch model
+ config = LlamaConfig.from_json_file(config_file)
+ print(f"Building PyTorch model from configuration: {config}")
+    # Adapted from the T5X conversion flow; the same procedure works for OpenMoE checkpoints.
+ model = OpenMoeForCausalLM(config)
+
+    # Load weights from the T5X checkpoint
+ load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)
+
+ # Save pytorch-model
+ print(f"Save PyTorch model to {pytorch_dump_path}")
+ model.save_pretrained(pytorch_dump_path)
+
+ # Verify that we can load the checkpoint.
+ model.from_pretrained(pytorch_dump_path)
+ print("Done")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
+ # Required parameters
+ parser.add_argument("--t5x_checkpoint_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the T5X checkpoint.")
+ parser.add_argument(
+ "--config_file",
+ default=None,
+ type=str,
+ required=True,
+ help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.",
+ )
+ parser.add_argument("--pytorch_dump_path",
+ default=None,
+ type=str,
+ required=True,
+ help="Path to the output PyTorch model.")
+ args = parser.parse_args()
+ convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)
diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/examples/language/openmoe/model/convert_openmoe_ckpt.sh
new file mode 100644
index 000000000000..c0d53f562e40
--- /dev/null
+++ b/examples/language/openmoe/model/convert_openmoe_ckpt.sh
@@ -0,0 +1 @@
+python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save
diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py
new file mode 100644
index 000000000000..ec7644317903
--- /dev/null
+++ b/examples/language/openmoe/model/modeling_openmoe.py
@@ -0,0 +1,1116 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch OpenMoE model."""
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+
+from colossalai.kernel.cuda_native.mha.flash_attn_2 import HAS_FLASH_ATTN
+from colossalai.kernel.triton.llama_act_combine_kernel import HAS_TRITON
+from colossalai.moe.layers import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_activation, set_moe_args
+
+if HAS_TRITON:
+ from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "LlamaConfig"
+
+
+def set_openmoe_args(
+ config: LlamaConfig,
+ num_experts: int,
+ moe_layer_interval: int,
+ router_topk: int = 2,
+ router_capacity_factor_train: float = 1.25,
+ router_capacity_factor_eval: float = 2.0,
+ router_min_capacity: int = 4,
+    router_noisy_policy: Optional[str] = None,
+ router_drop_tks: bool = True,
+ router_aux_loss_factor: float = 0.01,
+ router_z_loss_factor: float = 0.0001,
+ mlp_gated: bool = True,
+ label_smoothing: float = 0.001,
+ z_loss_factor: float = 0.01,
+ enable_load_balance: bool = False,
+ load_balance_tolerance: float = 0.1,
+ load_balance_beam_width: int = 8,
+ load_balance_group_swap_factor: float = 0.4,
+ enable_kernel: bool = False,
+ enable_comm_overlap: bool = False,
+ enable_hierarchical_alltoall: bool = False,
+) -> None:
+ """
+    Insert MoE-related arguments into the given Llama config.
+
+ Args:
+ config (LlamaConfig): Transformers Llama config.
+ num_experts (int, optional): Number of experts.
+        moe_layer_interval (int, optional): The interval between MoE layers; every moe_layer_interval-th layer is an MoE layer.
+ router_topk (int, optional): Moe router top k. Defaults to 2.
+ router_capacity_factor_train (float, optional): Moe router max capacity for train. Defaults to 1.25.
+ router_capacity_factor_eval (float, optional): Moe router max capacity for eval. Defaults to 2.0.
+ router_min_capacity (int, optional): Moe router min capacity. Defaults to 4.
+ router_noisy_policy (str, optional): Moe router noisy policy. You can choose [Jitter, Gaussian, None]. Defaults to None.
+ router_drop_tks (bool, optional): Whether moe router drop tokens which exceed max capacity. Defaults to True.
+ router_aux_loss_factor (float, optional): Moe router aux loss. You can refer to STMoE for details. Defaults to 0.01.
+        router_z_loss_factor (float, optional): Moe router z loss. You can refer to STMoE for details. Defaults to 0.0001.
+ mlp_gated (bool, optional): Use gate in mlp. Defaults to True.
+ label_smoothing (float, optional): Label smoothing. Defaults to 0.001.
+ z_loss_factor (float, optional): The final outputs' classification z loss factor. Defaults to 0.01.
+ enable_load_balance (bool, optional): Expert load balance. Defaults to False.
+ load_balance_tolerance (float, optional): Expert load balance search's difference tolerance. Defaults to 0.1.
+ load_balance_beam_width (int, optional): Expert load balance search's beam width. Defaults to 8.
+        load_balance_group_swap_factor (float, optional): Expert load balance group swap factor. Larger values discourage swapping. Defaults to 0.4.
+ enable_kernel (bool, optional): Use kernel optimization. Defaults to False.
+        enable_comm_overlap (bool, optional): Use communication overlap for MoE. Recommended for multi-node training. Defaults to False.
+ enable_hierarchical_alltoall (bool, optional): Use hierarchical alltoall for MoE. Defaults to False.
+ """
+ moe_args = dict(
+ num_experts=num_experts,
+ moe_layer_interval=moe_layer_interval,
+ router_topk=router_topk,
+ router_capacity_factor_train=router_capacity_factor_train,
+ router_capacity_factor_eval=router_capacity_factor_eval,
+ router_min_capacity=router_min_capacity,
+ router_noisy_policy=router_noisy_policy,
+ router_drop_tks=router_drop_tks,
+ router_aux_loss_factor=router_aux_loss_factor,
+ router_z_loss_factor=router_z_loss_factor,
+ mlp_gated=mlp_gated,
+ label_smoothing=label_smoothing,
+ z_loss_factor=z_loss_factor,
+ enable_load_balance=enable_load_balance,
+ load_balance_tolerance=load_balance_tolerance,
+ load_balance_beam_width=load_balance_beam_width,
+ load_balance_group_swap_factor=load_balance_group_swap_factor,
+ enable_kernel=enable_kernel,
+ enable_comm_overlap=enable_comm_overlap,
+ enable_hierarchical_alltoall=enable_hierarchical_alltoall,
+ )
+ set_moe_args(config, moe_args)
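+
+# Illustrative usage, mirroring how infer.py and the benchmark scripts call this
+# helper (the repo id below is taken from those scripts):
+#     config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+#     set_openmoe_args(config, num_experts=config.num_experts,
+#                      moe_layer_interval=config.moe_layer_interval)
+#     model = OpenMoeForCausalLM(config)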
+
+
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
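+
+# Worked example (illustrative): for tgt_len = 3 with no cached keys, the mask is
+#     [[0, m, m],
+#      [0, 0, m],
+#      [0, 0, 0]]
+# where m = torch.finfo(dtype).min, so each position attends only to itself and
+# to earlier positions.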
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0):
+ """Generate Sin/Cos for Rotary Embeddings.
+
+ Args:
+ features: an integer
+ length: an integer
+ min_timescale: an optional float
+ max_timescale: an optional float
+
+ Returns:
+ output_sin: a float32 Tensor with shape [length, features]
+ output_cos: a float32 Tensor with shape [length, features]
+ """
+ fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features
+ timescale = min_timescale * (max_timescale / min_timescale) ** fraction
+ rotational_frequency = 1.0 / timescale
+
+ sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency)
+
+ sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)
+
+ return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
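+
+# Shape sketch: generate_fixed_pos_embedding(64, 2048) returns two float32
+# tensors of shape [2048, 64] on the current CUDA device; features // 2
+# rotational frequencies are computed and then duplicated by the concatenation.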
+
+
+def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
+ """Helper function to apply Rotary Embeddings."""
+ cos = cos.to(q.dtype)
+ sin = sin.to(q.dtype)
+
+ if len(k.shape) == 3:
+ # for multi query attention
+ k = k.unsqueeze(2)
+ multiquery = True
+ else:
+ multiquery = False
+
+ batch, qlen, qheads, d = q.shape
+ kbatch, klen, kheads, kd = k.shape
+ assert batch == kbatch, f"{batch} != {kbatch}"
+ assert d == kd, f"{d} != {kd}"
+ if decode and qlen == 1 and rotary_index is not None:
+ qcos = cos[rotary_index + 1, :]
+ qsin = sin[rotary_index + 1, :]
+ qcos = qcos.unsqueeze(2)
+ qsin = qsin.unsqueeze(2)
+ kcos, ksin = cos[:klen, :], sin[:klen, :]
+ kcos = kcos.unsqueeze(0).unsqueeze(2)
+ ksin = ksin.unsqueeze(0).unsqueeze(2)
+ else:
+ qcos, qsin = cos[:qlen, :], sin[:qlen, :]
+ qcos = qcos.unsqueeze(0).unsqueeze(2)
+ qsin = qsin.unsqueeze(0).unsqueeze(2)
+ kcos, ksin = qcos, qsin
+
+ out_q = (q * qcos) + (rotate_half(q) * qsin)
+ out_k = (k * kcos) + (rotate_half(k) * ksin)
+
+ if multiquery:
+ out_k = out_k.squeeze(2)
+
+ return out_q, out_k
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
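+
+# e.g. for x = [x1, x2] split along the last dim, rotate_half(x) = [-x2, x1],
+# the 90-degree rotation consumed by apply_rotary_embedding above.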
+
+
+def SwiGLU(x):
+    """SwiGLU gated activation function.
+
+    Args:
+        x: input tensor; its last dimension must have an even size and is split in half
+    """
+ size = x.shape[-1]
+ assert size % 2 == 0, "axis size must be divisible by 2"
+ x1, x2 = torch.split(x, size // 2, -1)
+ return x1 * (x2 * torch.sigmoid(x2))
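+
+# Equivalently (illustrative): x1 * F.silu(x2) for x1, x2 = x.chunk(2, dim=-1),
+# since x2 * torch.sigmoid(x2) is the SiLU activation.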
+
+
+class OpenMoeMLP(nn.Module):
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.pretraining_tp = config.pretraining_tp
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.hidden_act = config.hidden_act
+ self.act_fn = get_activation(self.hidden_act)
+ self.use_kernel = config.enable_kernel
+
+ def forward(self, x):
+ if self.pretraining_tp > 1:
+ slice = self.intermediate_size // self.pretraining_tp
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
+
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
+
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
+ down_proj = sum(down_proj)
+ else:
+ if HAS_TRITON and self.use_kernel and self.hidden_act == "swiglu":
+ down_proj = self.down_proj(LlamaActCombine.apply(self.gate_proj(x), self.up_proj(x)))
+ else:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ return down_proj
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
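+
+# Shape sketch: hidden_states of shape (2, 4, 128, 64) with n_rep=3 becomes
+# (2, 12, 128, 64), matching torch.repeat_interleave(hidden_states, repeats=3, dim=1).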
+
+
+class OpenMoeAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.pretraining_tp = config.pretraining_tp
+ self.max_position_embeddings = config.max_position_embeddings
+
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ use_kernel: bool = True,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ if self.pretraining_tp > 1:
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
+ query_states = torch.cat(query_states, dim=-1)
+
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
+ key_states = torch.cat(key_states, dim=-1)
+
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
+ value_states = torch.cat(value_states, dim=-1)
+
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ max_length = max(query_states.shape[1], key_states.shape[1])
+ assert max_length <= self.sin.shape[0]
+ sin, cos = self.sin[:max_length], self.cos[:max_length]
+ # TODO: for inference, we can add emb kv into cache to avoid computation
+ query_states, key_states = apply_rotary_embedding(
+ query_states, key_states, cos, sin, decode=True if q_len == 1 else False, rotary_index=position_ids
+ )
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if HAS_FLASH_ATTN and use_kernel:
+ from flash_attn import flash_attn_func
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = flash_attn_func(query_states, key_states, value_states, softmax_scale=1.0, causal=True)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ if self.training:
+ attention_mask = attention_mask.clone().detach()
+ attention_mask[:, :, :, 0] = 0
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+
+ if self.pretraining_tp > 1:
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
+ else:
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class OpenMoeDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, moe: bool):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.moe = moe
+ self.self_attn = OpenMoeAttention(config=config)
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if self.moe:
+ self.mlp = SparseMLP(
+ num_experts=config.num_experts,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ router_top_k=config.router_topk,
+ router_capacity_factor_train=config.router_capacity_factor_train,
+ router_capacity_factor_eval=config.router_capacity_factor_eval,
+ router_min_capacity=config.router_min_capacity,
+ router_noisy_policy=config.router_noisy_policy,
+ router_drop_tks=config.router_drop_tks,
+ mlp_activation=config.hidden_act,
+ mlp_gated=config.mlp_gated,
+ enable_load_balance=config.enable_load_balance,
+ load_balance_tolerance=config.load_balance_tolerance,
+ load_balance_beam_width=config.load_balance_beam_width,
+ load_balance_group_swap_factor=config.load_balance_group_swap_factor,
+ enable_kernel=config.enable_kernel,
+ enable_comm_overlap=config.enable_comm_overlap,
+ )
+ self.pre_extra_mlp_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.extra_mlp = OpenMoeMLP(config)
+ else:
+ self.mlp = OpenMoeMLP(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ if self.moe:
+ residual = hidden_states
+ hidden_states = self.pre_extra_mlp_layernorm(hidden_states)
+ hidden_states = self.extra_mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+LLAMA_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`LlamaConfig`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class OpenMoePreTrainedModel(PreTrainedModel):
+ config_class = LlamaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["LlamaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, OpenMoeModel):
+ module.gradient_checkpointing = value
+
+
+LLAMA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
+ LLAMA_START_DOCSTRING,
+)
+class OpenMoeModel(OpenMoePreTrainedModel):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
+
+ Args:
+ config: LlamaConfig
+ """
+
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [
+ OpenMoeDecoderLayer(config, moe=True if (i + 1) % config.moe_layer_interval == 0 else False)
+ for i in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
+ inputs_embeds.device
+ )
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
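+        # the combined mask is additive: e.g. for seq_len=3 with no padding, row i is
+        # 0 at positions <= i (attend) and the dtype's minimum value (effectively -inf)
+        # elsewhere:
+        #   [[0, min, min],
+        #    [0,   0, min],
+        #    [0,   0,   0]]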
+ return combined_attention_mask
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class OpenMoeForCausalLM(OpenMoePreTrainedModel):
+ # _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = OpenMoeModel(config)
+ self.pretraining_tp = config.pretraining_tp
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ chunk_head: Optional[bool] = True,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ # reset moe loss
+ MOE_MANAGER.reset_loss()
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ if self.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+
+ loss = None
+        # if not training (no labels), just do the forward pass
+        if labels is None:
+            logits = self.lm_head(hidden_states)
+            logits = logits.float()
+        # the OpenMoE vocab size is very large (256384 tokens), which causes huge
+        # activation memory in training, up to 20GB for one sequence, so we chunk
+        # the lm_head computation and checkpoint it to reduce memory
+        else:
+            if chunk_head:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ logits = module(inputs[0])
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous().float()
+ shift_labels = inputs[1][..., 1:].contiguous()
+ # Flatten the tokens
+ loss = self._calculate_loss(shift_logits, shift_labels)
+ return loss
+
+ return custom_forward
+
+ aux_loss, z_loss = self._calculate_router_loss()
+ loss = aux_loss + z_loss
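+                # with seq_len 2048 and vocab 256384, a single sample's fp32 logits
+                # already occupy ~2GB, so materializing them one sample at a time and
+                # recomputing inside checkpoint keeps peak activation memory bounded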
+ for batch_idx in range(hidden_states.shape[0]):
+ loss = loss + torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.lm_head),
+ hidden_states[batch_idx : batch_idx + 1, :],
+ labels[batch_idx : batch_idx + 1, :],
+ )
+ logits = None
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ aux_loss, z_loss = self._calculate_router_loss()
+ loss = aux_loss + z_loss
+ loss = loss + self._calculate_loss(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
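+            # e.g. attention_mask [[0, 1, 1, 1]] -> cumsum(-1) - 1 gives [[-1, 0, 1, 2]],
+            # then the padded position is filled with 1, yielding [[1, 0, 1, 2]]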
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+ def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None):
+ if aux_loss is None or z_loss is None:
+ aux_loss, z_loss = MOE_MANAGER.get_loss()
+ assert len(aux_loss) == len(z_loss) == self.config.num_hidden_layers // self.config.moe_layer_interval
+ aux_loss = self.config.router_aux_loss_factor * sum(aux_loss) / len(aux_loss)
+ z_loss = self.config.router_z_loss_factor * sum(z_loss) / len(z_loss)
+ return aux_loss, z_loss
+
+ def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Compute cross entropy and entropy for log probs and targets.
+
+ Args:
+ logits: [batch, length, num_classes] float array.
+ targets: categorical targets [batch, length] int array.
+
+ Returns:
+            Scalar loss (summed over length, averaged over batch).
+ """
+ if len(logits.shape) != len(targets.shape) + 1:
+ raise ValueError(
+ "Incorrect shapes. Got shape %s logits and %s targets" % (str(logits.shape), str(targets.shape))
+ )
+ vocab_size = logits.shape[-1]
+ confidence = 1.0 - self.config.label_smoothing
+ low_confidence = (1.0 - confidence) / (vocab_size - 1)
+ normalizing_constant = -(
+ confidence * math.log(confidence) + (vocab_size - 1) * low_confidence * math.log(low_confidence + 1e-20)
+ )
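+        # the normalizing constant equals the entropy of the smoothed target
+        # distribution; subtracting it below turns the smoothed cross entropy into a
+        # KL divergence, which is zero for a perfectly calibrated prediction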
+
+ # one hot
+ soft_targets = targets[..., None] == torch.arange(vocab_size, device=targets.device).reshape(
+ (1,) * len(targets.shape) + (-1,)
+ )
+ soft_targets = torch.where(
+ soft_targets, torch.full_like(soft_targets, confidence), torch.full_like(soft_targets, low_confidence)
+ )
+ soft_targets = soft_targets.to(torch.float32)
+
+ # cross entropy
+ total_loss = ZLossCrossEntropy.apply(logits, soft_targets, self.config.z_loss_factor)
+ total_loss = total_loss - normalizing_constant
+ total_loss = torch.mean(torch.sum(total_loss, dim=-1), dim=0)
+ return total_loss
+
+
+class ZLossCrossEntropy(torch.autograd.Function):
+ """Computes cross entropy loss with stable custom gradient.
+
+    Computes a stabilized-gradient version of (written here in the notation of the
+    original JAX implementation):
+    -jnp.sum(targets * nn.log_softmax(logits), axis=-1)
+
+ If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2
+ will be added to the cross entropy loss (z = softmax normalization constant).
+ The two uses of z_loss are:
+ 1. To keep the logits from drifting too far from zero, which can cause
+ unacceptable roundoff errors in bfloat16.
+ 2. To encourage the logits to be normalized log-probabilities.
+
+ Args:
+ logits: [batch, length, num_classes] float array.
+ targets: categorical one-hot targets [batch, length, num_classes] float
+ array.
+        z_loss: coefficient for the auxiliary z-loss term.
+
+    Returns:
+        The total loss (cross entropy plus the z-loss term), a float array with
+        shape [batch, length].
+ """
+
+ @staticmethod
+ def forward(ctx, logits, targets, z_loss):
+ max_logit = torch.max(logits, dim=-1, keepdim=True)[0]
+ shifted = logits - max_logit
+ exp_shifted = torch.exp(shifted)
+        sum_exp = torch.sum(exp_shifted, dim=-1, keepdim=True)
+        sum_exp_log = torch.log(sum_exp)
+        log_softmax = shifted - sum_exp_log
+        loss = -torch.sum(targets * log_softmax, dim=-1)
+        # Add the auxiliary z-loss term.
+        log_z = torch.squeeze(sum_exp_log + max_logit, dim=-1)
+        total_z_loss = z_loss * torch.square(log_z)
+        loss += total_z_loss
+ ctx.z_loss = z_loss
+ ctx.save_for_backward(logits, targets, exp_shifted, sum_exp, log_softmax, log_z)
+ return loss
+
+ @staticmethod
+ def backward(ctx, *grad_outputs):
+ assert len(grad_outputs) == 1
+ g = grad_outputs[0]
+ z_loss = ctx.z_loss
+ logits, targets, exp_shifted, sum_exp, log_softmax, log_z = ctx.saved_tensors
+ # z-loss term adds the (2 * z_loss * log_z) factor.
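+        # derivation: d/dlogits [ -sum(t * log_softmax) + z_loss * log_z^2 ]
+        #   = softmax - t + 2 * z_loss * log_z * softmax     (targets t sum to 1)
+        #   = (1 + 2 * z_loss * log_z) * softmax - t,  with softmax = exp_shifted / sum_exp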
+ deriv = (1 + 2 * z_loss * log_z).unsqueeze(-1) * exp_shifted / sum_exp - targets
+ g_logits = g.unsqueeze(-1) * deriv
+ g_targets = -g.unsqueeze(-1) * log_softmax
+
+ return (
+ g_logits.to(logits.dtype),
+ g_targets.to(targets.dtype),
+ None,
+ )
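+
+
+# A minimal usage sketch (illustrative only, with a toy vocabulary):
+#   logits = torch.randn(2, 8, 16, requires_grad=True)
+#   one_hot = F.one_hot(torch.randint(0, 16, (2, 8)), 16).float()
+#   per_token_loss = ZLossCrossEntropy.apply(logits, one_hot, 1e-4)  # shape [2, 8]
+#   per_token_loss.mean().backward()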
diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/examples/language/openmoe/model/openmoe_8b_config.json
new file mode 100644
index 000000000000..248697c37d3c
--- /dev/null
+++ b/examples/language/openmoe/model/openmoe_8b_config.json
@@ -0,0 +1,24 @@
+{
+ "architectures": [
+ "OpenMoeForCausalLM"
+ ],
+ "intermediate_size": 8192,
+ "hidden_size": 2048,
+ "num_hidden_layers": 24,
+ "head_dim": 128,
+ "num_attention_heads": 24,
+ "dropout_rate": 0.0,
+ "layer_norm_epsilon": 1e-06,
+ "vocab_size": 256384,
+ "hidden_act": "swiglu",
+ "num_experts": 32,
+ "topk": 2,
+ "capacity_factor_train": 1.25,
+ "capacity_factor_eval": 2.0,
+ "min_capacity": 4,
+ "noisy_policy": null,
+ "drop_tks": true,
+ "expert_parallel": null,
+ "gated": true,
+ "moe_layer_interval": 6
+}
diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/examples/language/openmoe/model/openmoe_base_config.json
new file mode 100644
index 000000000000..5a7c97bd1916
--- /dev/null
+++ b/examples/language/openmoe/model/openmoe_base_config.json
@@ -0,0 +1,24 @@
+{
+ "architectures": [
+ "OpenMoeForCausalLM"
+ ],
+ "intermediate_size": 2048,
+ "hidden_size": 768,
+ "num_hidden_layers": 12,
+ "head_dim": 64,
+ "num_attention_heads": 12,
+ "dropout_rate": 0.0,
+ "layer_norm_epsilon": 1e-06,
+ "vocab_size": 256384,
+ "hidden_act": "swiglu",
+ "num_experts": 16,
+ "topk": 2,
+ "capacity_factor_train": 1.25,
+ "capacity_factor_eval": 2.0,
+ "min_capacity": 4,
+ "noisy_policy": null,
+ "drop_tks": true,
+ "expert_parallel": null,
+ "gated": true,
+ "moe_layer_interval": 4
+}
diff --git a/examples/language/openmoe/model/openmoe_policy.py b/examples/language/openmoe/model/openmoe_policy.py
new file mode 100644
index 000000000000..f354bbea990e
--- /dev/null
+++ b/examples/language/openmoe/model/openmoe_policy.py
@@ -0,0 +1,562 @@
+import warnings
+from functools import partial
+from typing import Callable, Dict, List, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import logging
+
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+from .modeling_openmoe import OpenMoeDecoderLayer, OpenMoeForCausalLM, OpenMoeModel
+
+__all__ = ["OpenMoePolicy", "OpenMoeForCausalLMPolicy"]
+
+
+class OpenMoePolicy(Policy):
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ if self.shard_config.enable_tensor_parallelism:
+ # Resize embedding
+ vocab_size = self.model.config.vocab_size
+ world_size = self.shard_config.tensor_parallel_size
+
+ if vocab_size % world_size != 0:
+ new_vocab_size = vocab_size + world_size - vocab_size % world_size
+ self.model.resize_token_embeddings(new_vocab_size)
+
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+ policy = {}
+
+        if self.shard_config.enable_sequence_parallelism:
+            self.shard_config.enable_sequence_parallelism = False
+            warnings.warn(
+                "OpenMoE doesn't support sequence parallelism now; the sequence parallelism flag will be ignored.")
+
+ if self.shard_config.enable_tensor_parallelism:
+ raise NotImplementedError("Tensor parallelism is not supported for openmoe model now.")
+
+ # optimization configuration
+ if self.shard_config.enable_fused_normalization:
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=FusedRMSNorm,
+ ),
+ SubModuleReplacementDescription(
+ suffix="pre_extra_mlp_layernorm",
+ target_module=FusedRMSNorm,
+ ignore_if_not_exist=True,
+ ),
+ ],
+ policy=policy,
+ target_key=OpenMoeDecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=FusedRMSNorm,
+ ),
+ policy=policy,
+ target_key=OpenMoeModel,
+ )
+
+ if self.shard_config.enable_flash_attention:
+ raise NotImplementedError("Flash attention has already been replaced in openmoe.")
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager:
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "OpenMoeModel":
+ module = self.model
+ else:
+ module = self.model.model
+
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+ method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ self.append_or_create_method_replacement(description=method_replacement,
+ policy=policy,
+ target_key=model_cls)
+
+ return
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "OpenMoeModel":
+ module = self.model
+ else:
+ module = self.model.model
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+ @staticmethod
+ def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
+ """Divide layers into stages
+
+ """
+ if num_layers == 24 and num_stages == 4:
+ return [7, 7, 7, 3]
+ elif num_layers == 24 and num_stages == 2:
+ return [15, 9]
+ elif num_layers == 12 and num_stages == 4:
+            return [4, 4, 3, 1]
+ elif num_layers == 12 and num_stages == 2:
+ return [8, 4]
+ else:
+ print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
+ return Policy.distribute_layers(num_layers, num_stages)
+
+
+class OpenMoeModelPolicy(OpenMoePolicy):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def module_policy(self):
+ policy = super().module_policy()
+ if self.pipeline_stage_manager:
+            # replace the forward method with the pipeline-aware version
+ self.set_pipeline_forward(
+ model_cls=OpenMoeModel,
+ new_forward=OpenMoePipelineForwards.openmoe_model_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in llama model"""
+ return []
+
+
+class OpenMoeForCausalLMPolicy(OpenMoePolicy):
+
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.shard_config.enable_tensor_parallelism:
+            # add a new item for causal lm
+ new_item = {
+ OpenMoeForCausalLM:
+ ModulePolicyDescription(sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(gather_output=True),
+ )
+ ])
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+            # replace the forward method with the pipeline-aware version
+ self.set_pipeline_forward(
+ model_cls=OpenMoeForCausalLM,
+ new_forward=OpenMoePipelineForwards.llama_for_causal_lm_forward,
+ policy=policy,
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_last_stage():
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ llama_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+ if (id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight)
+ and self.pipeline_stage_manager.num_stages > 1):
+ # tie weights
+ return [{
+ 0: llama_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }]
+ return []
+
+
+class OpenMoePipelineForwards:
+ """
+    This class serves as a micro library of substitute forward functions for OpenMoE
+    models under the pipeline-parallel setting.
+ """
+
+ @staticmethod
+ def openmoe_model_forward(
+ self: OpenMoeModel,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ past_router_aux_loss: Optional[torch.FloatTensor] = None,
+ past_router_z_loss: Optional[torch.FloatTensor] = None,
+ ):
+ # reset moe loss for different data
+ MOE_MANAGER.reset_loss()
+
+ logger = logging.get_logger(__name__)
+
+ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+        # TODO(jianghai): kv-cache recording is left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions, for the first stage, hidden_states is the input embeddings,
+ # for the other stages, hidden_states is the output of the previous stage
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past),
+ dtype=torch.bool,
+ device=hidden_states.device,
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ )
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = (past_key_values[idx] if past_key_values is not None else None)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ next_cache = next_decoder_cache if use_cache else None
+
+ # concat past losses with current ones
+ router_aux_loss, router_z_loss = MOE_MANAGER.get_loss()
+ if past_router_aux_loss is not None and past_router_z_loss is not None:
+ router_aux_loss = past_router_aux_loss + router_aux_loss
+ router_z_loss = past_router_z_loss + router_z_loss
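+        # under pipeline parallelism every stage produces router losses for its own
+        # MoE layers, so the running totals are threaded through the stages and only
+        # folded into the final loss on the last stage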
+
+ if stage_manager.is_last_stage():
+ return tuple([
+ hidden_states,
+ next_cache,
+ all_hidden_states,
+ all_self_attns,
+ router_aux_loss,
+ router_z_loss,
+ ])
+        # always return a dict for intermediate stages
+ return {
+ "hidden_states": hidden_states,
+ "router_aux_loss": router_aux_loss,
+ "router_z_loss": router_z_loss,
+ }
+
+ @staticmethod
+ def llama_for_causal_lm_forward(
+ self: OpenMoeForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ chunk_head: Optional[bool] = True,
+ past_router_aux_loss: Optional[torch.FloatTensor] = None,
+ past_router_z_loss: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+ output_attentions = (output_attentions if output_attentions is not None else self.config.output_attentions)
+ output_hidden_states = (output_hidden_states
+ if output_hidden_states is not None else self.config.output_hidden_states)
+ return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
+
+        # TODO(jianghai): kv-cache recording is left as () or None for now; this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = OpenMoePipelineForwards.openmoe_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ past_router_aux_loss=past_router_aux_loss,
+ past_router_z_loss=past_router_z_loss,
+ )
+
+ if stage_manager.is_last_stage():
+ (
+ hidden_states,
+ past_key_values,
+ all_hidden_states,
+ attentions,
+ router_aux_loss,
+ router_z_loss,
+ ) = outputs
+
+ if self.pretraining_tp > 1:
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
+ logits = torch.cat(logits, dim=-1)
+
+ loss = None
+            # if not training (no labels), just do the forward pass
+            if labels is None:
+                logits = self.lm_head(hidden_states)
+                logits = logits.float()
+            # the OpenMoE vocab size is very large (256384 tokens), which causes huge
+            # activation memory in training, up to 20GB for one sequence, so we chunk
+            # the lm_head computation and checkpoint it to reduce memory
+            else:
+                if chunk_head:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ logits = module(inputs[0])
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous().float()
+ shift_labels = inputs[1][..., 1:].contiguous()
+ # Flatten the tokens
+ loss = self._calculate_loss(shift_logits, shift_labels)
+ return loss
+
+ return custom_forward
+
+ aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
+ loss = aux_loss + z_loss
+ for batch_idx in range(hidden_states.shape[0]):
+ loss = loss + torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.lm_head),
+ hidden_states[batch_idx:batch_idx + 1, :],
+ labels[batch_idx:batch_idx + 1, :],
+ )
+ logits = None
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ # Shift so that tokens < n predict n
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ aux_loss, z_loss = self._calculate_router_loss(router_aux_loss, router_z_loss)
+ loss = aux_loss + z_loss
+ loss = loss + self._calculate_loss(shift_logits, shift_labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ attentions=attentions,
+ )
+ else:
+ hidden_states = outputs["hidden_states"]
+ router_aux_loss = outputs["router_aux_loss"]
+ router_z_loss = outputs["router_z_loss"]
+ return {
+ "hidden_states": hidden_states,
+ "past_router_aux_loss": router_aux_loss,
+ "past_router_z_loss": router_z_loss,
+ }
diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt
new file mode 100644
index 000000000000..6b9f807116df
--- /dev/null
+++ b/examples/language/openmoe/requirements.txt
@@ -0,0 +1,5 @@
+colossalai >= 0.3.3
+torch >= 1.8.1
+transformers >= 4.20.0, <= 4.34.0
+sentencepiece
+datasets
diff --git a/examples/language/openmoe/test_ci.sh b/examples/language/openmoe/test_ci.sh
new file mode 100644
index 000000000000..960c83adb489
--- /dev/null
+++ b/examples/language/openmoe/test_ci.sh
@@ -0,0 +1,37 @@
+pip install -r requirements.txt
+
+# inference
+python infer.py --model "test"
+
+# train
+torchrun --standalone --nproc_per_node 4 train.py \
+ --num_epoch 1 \
+ --model_name "test" \
+ --plugin "ep" \
+ --batch_size 1
+
+torchrun --standalone --nproc_per_node 4 train.py \
+ --num_epoch 1 \
+ --model_name "test" \
+ --plugin "ep_zero" \
+ --batch_size 1 \
+ --zero_stage 1 \
+    --extra_dp_size 2
+
+torchrun --standalone --nproc_per_node 4 train.py \
+ --num_epoch 1 \
+ --model_name "test" \
+ --plugin "ep_zero" \
+ --batch_size 1 \
+ --zero_stage 2 \
+    --extra_dp_size 2
+
+torchrun --standalone --nproc_per_node 4 train.py \
+ --model_name "test" \
+ --plugin "hybrid" \
+ --num_epoch 1 \
+ --pp_size 2 \
+ --dp_size 1 \
+ --ep_size 2 \
+ --zero_stage 1 \
+ --batch_size 1
diff --git a/examples/language/openmoe/train.py b/examples/language/openmoe/train.py
new file mode 100644
index 000000000000..b084361661ac
--- /dev/null
+++ b/examples/language/openmoe/train.py
@@ -0,0 +1,382 @@
+import argparse
+import os
+from functools import partial
+from typing import Dict
+
+import torch
+import torch.distributed as dist
+from datasets import load_dataset
+from huggingface_hub import snapshot_download
+from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args
+from model.openmoe_policy import OpenMoeForCausalLMPolicy
+from torch.utils.data import Dataset
+from tqdm import tqdm
+from transformers import T5Tokenizer
+from transformers.models.llama import LlamaConfig
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.moe.layers import apply_load_balance
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import skip_init
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+
+def move_to_cuda(batch, device):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def load_ckpt(repo_name: str, model: OpenMoeForCausalLM, booster: Booster):
+ ckpt_path = snapshot_download(repo_name)
+ # single ckpt
+ if os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin")):
+ ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin")
+ # shard ckpt
+ elif os.path.exists(os.path.join(ckpt_path, "pytorch_model.bin.index.json")):
+ ckpt_path = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
+ else:
+ raise ValueError(f"Invalid checkpoint path: {ckpt_path}")
+ booster.load_model(model, ckpt_path)
+
+
+def tokenize_data(batch, tokenizer: T5Tokenizer, max_length: int) -> Dict:
+ texts = ["" + sample["prompt"] + sample["completion"] for sample in batch]
+ data = tokenizer(
+ texts,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=max_length,
+ add_special_tokens=False,
+ )
+ data = {k: v.cuda() for k, v in data.items()}
+ data["labels"] = data["input_ids"].clone()
+ return data
+
+
+class RandomDataset(Dataset):
+ def __init__(self, num_samples: int = 1000, max_length: int = 2048, vocab_size: int = 32000, tokenizer=None):
+ self.num_samples = num_samples
+ self.max_length = max_length
+ self.input_ids = torch.randint(0, vocab_size, (num_samples, max_length), device=get_current_device())
+ self.attention_mask = torch.ones_like(self.input_ids)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, idx):
+ return {
+ "input_ids": self.input_ids[idx],
+ "attention_mask": self.attention_mask[idx],
+ "labels": self.input_ids[idx],
+ }
+
+
+def parse_args():
+ # basic settings
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_name",
+ type=str,
+ default="base",
+ choices=["base", "8b", "test"],
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="hybrid",
+ choices=["ep", "ep_zero", "hybrid"],
+ help="Parallel methos. ep_zero is recommended for general cases. ep can provides least memory consumption and hybrid suits large scale training.",
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./outputs",
+ help="The path of your saved model after finetuning.",
+ )
+ parser.add_argument("--num_epoch", type=int, default=1, help="Number of epochs.")
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="Batch size (per dp group) for the training dataloader.",
+ )
+ parser.add_argument(
+ "--save_interval",
+ type=int,
+ default=1000,
+ help=" The interval (steps) of saving checkpoints.",
+ )
+ parser.add_argument(
+ "--precision",
+ type=str,
+ default="bf16",
+ choices=["fp32", "bf16", "fp16"],
+ help="The mixed precision training.",
+ )
+ parser.add_argument("--max_length", type=int, default=2048, help="Max sequence length.")
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--dataset",
+ type=str,
+ default="yizhongw/self_instruct",
+ help="dataset name from `datasets` repo.",
+ )
+ parser.add_argument(
+ "--task_name",
+ type=str,
+ default="super_natural_instructions",
+ help="task of corresponding dataset.",
+ )
+
+ # optim
+ parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate.")
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
+
+ # zero stage for all plugins
+ parser.add_argument("--zero_stage", type=int, default=2, help="zero stage.")
+ # ep_zero plugin
+ parser.add_argument(
+ "--extra_dp_size", type=int, default=1, help="ep_zero plugin's moe dp size. Recommended to be 2 or 4."
+ )
+ # hybrid plugin
+ parser.add_argument("--pp_size", type=int, default=2, help="pp size for hybrid plugin")
+ parser.add_argument("--dp_size", type=int, default=1, help="dp size for hybrid plugin")
+ parser.add_argument("--ep_size", type=int, default=2, help="ep size for hybrid plugin")
+ parser.add_argument("--microbatch_size", type=int, default=1, help="Microbatch size in pipeline for hybrid plugin")
+
+ # kernel
+ parser.add_argument(
+ "--use_kernel",
+ action="store_true",
+ help="Use kernel optim. Need to install flash attention and triton to enable all kernel optimizations. Skip if not installed.",
+ )
+ parser.add_argument(
+ "--use_layernorm_kernel",
+ action="store_true",
+ help="Use layernorm kernel. Need to install apex. Raise error if not installed.",
+ )
+
+ # loss
+ parser.add_argument(
+ "--router_aux_loss_factor",
+ type=float,
+ default=0.01,
+ help="Moe router z loss. You can refer to STMoE for details.",
+ )
+ parser.add_argument(
+ "--router_z_loss_factor",
+ type=float,
+ default=0.0001,
+ help="Moe router aux loss. You can refer to STMoE for details.",
+ )
+ parser.add_argument("--label_smoothing", type=float, default=0.0, help="Label smoothing.")
+ parser.add_argument(
+ "--z_loss_factor", type=float, default=0.0001, help="The final outputs' classification z loss factor."
+ )
+
+ # load balance
+ parser.add_argument(
+ "--load_balance", action="store_true", help="Expert load balance. Defaults to False. Recommend to enable."
+ )
+ parser.add_argument("--load_balance_interval", type=int, default=1000, help="Expert load balance interval.")
+ # communicate overlap
+ parser.add_argument(
+ "--comm_overlap",
+ action="store_true",
+ help="Use communication overlap for MoE. Recommended to enable for muiti-node training.",
+ )
+ # hierarchical all-to-all
+ parser.add_argument(
+ "--hierarchical_alltoall",
+ action="store_true",
+ help="Use hierarchical all-to-all for MoE. Recommended to enable for muiti-node training.",
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ args = parse_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+ test_mode = args.model_name == "test"
+
+ # Set plugin
+ booster_kwargs = {}
+ hybrid_dict = {
+ "tp_size": 1,
+ "custom_policy": OpenMoeForCausalLMPolicy(),
+ "enable_fused_normalization": args.use_layernorm_kernel,
+ "enable_jit_fused": args.use_kernel,
+ "precision": args.precision,
+ "zero_stage": args.zero_stage,
+ }
+ mgr_dict = {}
+ if args.plugin == "ep":
+ dp_size = dist.get_world_size()
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ max_ep_size=dp_size,
+ **mgr_dict,
+ )
+ elif args.plugin == "ep_zero":
+ dp_size = dist.get_world_size()
+ use_ep_inside = False
+ plugin = MoeHybridParallelPlugin(
+ pp_size=1,
+ extra_dp_size=args.extra_dp_size,
+ use_ep_inside=use_ep_inside,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ max_ep_size=dp_size // args.extra_dp_size,
+ use_ep_inside=use_ep_inside,
+ **mgr_dict,
+ )
+ elif args.plugin == "hybrid":
+ dp_size = dist.get_world_size() // args.pp_size
+ plugin = MoeHybridParallelPlugin(
+ pp_size=args.pp_size,
+ microbatch_size=args.microbatch_size,
+ **hybrid_dict,
+ )
+ MOE_MANAGER.setup(
+ parallel="EP",
+ mode="fixed",
+ fixed_dp_size=args.dp_size,
+ fixed_ep_size=args.ep_size,
+ fixed_pp_size=args.pp_size,
+ **mgr_dict,
+ )
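+        # e.g. test_ci.sh runs this on 4 GPUs with pp_size=2, dp_size=1, ep_size=2:
+        # two pipeline stages, each sharding its MoE layers' experts across two
+        # expert-parallel ranks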
+ else:
+ raise ValueError(f"Invalid plugin {args.plugin}")
+ coordinator.print_on_master(f"Set plugin as {plugin.__class__.__name__}")
+
+ # Build OpenMoe model
+ if test_mode:
+ config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base")
+ config.hidden_size = 128
+ config.intermediate_size = 256
+ config.vocab_size = 32000
+ else:
+ repo_name = "hpcaitech/openmoe-" + args.model_name
+ config = LlamaConfig.from_pretrained(repo_name)
+ set_openmoe_args(
+ config,
+ num_experts=config.num_experts,
+ moe_layer_interval=config.moe_layer_interval,
+ router_aux_loss_factor=args.router_aux_loss_factor,
+ router_z_loss_factor=args.router_z_loss_factor,
+ z_loss_factor=args.z_loss_factor,
+ enable_load_balance=args.load_balance,
+ enable_comm_overlap=args.comm_overlap,
+ enable_hierarchical_alltoall=args.hierarchical_alltoall,
+ enable_kernel=args.use_kernel,
+ )
+ with skip_init():
+ model = OpenMoeForCausalLM(config)
+ coordinator.print_on_master(f"Finish init model with config:\n{config}")
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Prepare tokenizer and dataloader
+ tokenizer = T5Tokenizer.from_pretrained("google/umt5-small")
+ if test_mode:
+ dataset = RandomDataset(num_samples=20, tokenizer=tokenizer)
+ collate_fn = None
+ else:
+ dataset = load_dataset(args.dataset, args.task_name)
+ dataset = dataset["train"]
+ collate_fn = partial(tokenize_data, tokenizer=tokenizer, max_length=args.max_length)
+ dataloader = plugin.prepare_dataloader(
+ dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=collate_fn
+ )
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ if not test_mode:
+ load_ckpt(repo_name, model, booster)
+ model, optimizer, _, dataloader, _ = booster.boost(model=model, optimizer=optimizer, dataloader=dataloader)
+ use_pipeline = isinstance(booster.plugin, MoeHybridParallelPlugin) and booster.plugin.pp_size > 1
+ is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+ coordinator.print_on_master(f"Finish init booster")
+
+ # Start finetuning
+ coordinator.print_on_master(f"Start finetuning")
+ for epoch in range(args.num_epoch):
+ model.train()
+ train_dataloader_iter = iter(dataloader)
+ total_len = len(train_dataloader_iter)
+ with tqdm(
+ range(total_len),
+ desc=f"Epoch [{epoch + 1}/{args.num_epoch}]",
+ disable=not coordinator.is_master(),
+ ) as pbar:
+ for step in pbar:
+ if use_pipeline:
+ # Forward pass
+ outputs = booster.execute_pipeline(
+ train_dataloader_iter,
+ model,
+ lambda x, y: x.loss,
+ optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = outputs["loss"]
+ pbar.set_postfix({"loss": loss.item()})
+ else:
+ # Forward pass
+ data = next(train_dataloader_iter)
+ data = move_to_cuda(data, torch.cuda.current_device())
+ outputs = model(**data)
+ loss = outputs["loss"]
+ # Backward
+ booster.backward(loss, optimizer)
+ pbar.set_postfix({"loss": loss.item()})
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # Apply load balance
+ if (
+ args.load_balance
+ and args.load_balance_interval > 0
+ and (step + 1) % args.load_balance_interval == 0
+ ):
+ coordinator.print_on_master(f"Apply load balance")
+ apply_load_balance(model, optimizer)
+ # save ckeckpoint
+ if (step + 1) % args.save_interval == 0:
+ coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+ booster.save_model(model, args.output_path, shard=True)
+
+        # save a checkpoint at the end of each epoch
+        coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
+        booster.save_model(model, args.output_path, shard=True)
+
+ # Finish training
+ coordinator.print_on_master(f"Finish training")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/language/openmoe/train.sh b/examples/language/openmoe/train.sh
new file mode 100644
index 000000000000..91cd3db8d7ba
--- /dev/null
+++ b/examples/language/openmoe/train.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+set -xue
+
+NUM_GPU=8
+MODEL="8b"
+SEQ_LENGTH=2048
+BATCH_SIZE=1
+LR=0.00001
+
+# ep zero
+torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+ --num_epoch 1 \
+ --model_name $MODEL \
+ --plugin "ep_zero" \
+ --batch_size $BATCH_SIZE \
+ --lr $LR \
+ --zero_stage 1 \
+ --extra_dp_size 2
+
+# ep
+# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+# --num_epoch 1 \
+# --model_name $MODEL \
+# --plugin "ep_zero" \
+# --batch_size $BATCH_SIZE \
+# --lr $LR \
+# --zero_stage 1
+
+# hybrid
+# torchrun --standalone --nproc_per_node $NUM_GPU train.py \
+# --num_epoch 1 \
+# --model_name $MODEL \
+# --plugin "hybrid" \
+# --batch_size $BATCH_SIZE \
+# --lr $LR \
+# --zero_stage 1 \
+# --pp_size 2 \
+# --dp_size 1 \
+# --ep_size 2 \
diff --git a/inference b/inference
deleted file mode 160000
index 56b35f3c06ea..000000000000
--- a/inference
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 56b35f3c06eaac11b1bee633d1e836563f74bcea
diff --git a/op_builder/__init__.py b/op_builder/__init__.py
index 808559ec9c2d..21e216437c47 100644
--- a/op_builder/__init__.py
+++ b/op_builder/__init__.py
@@ -1,3 +1,4 @@
+from .arm_cpu_adam import ArmCPUAdamBuilder
from .cpu_adam import CPUAdamBuilder
from .fused_optim import FusedOptimBuilder
from .layernorm import LayerNormBuilder
@@ -29,4 +30,5 @@
"MultiTensorLambBuilder",
"MultiTensorScaleBuilder",
"MultiTensorL2NormBuilder",
+ "ArmCPUAdamBuilder",
]
diff --git a/op_builder/arm_cpu_adam.py b/op_builder/arm_cpu_adam.py
new file mode 100644
index 000000000000..18dd519fae46
--- /dev/null
+++ b/op_builder/arm_cpu_adam.py
@@ -0,0 +1,34 @@
+from .builder import Builder
+
+
+class ArmCPUAdamBuilder(Builder):
+ NAME = "arm_cpu_adam"
+ PREBUILT_IMPORT_PATH = "colossalai._C.arm_cpu_adam"
+ ext_type = "cpu"
+
+ def __init__(self):
+ super().__init__(name=ArmCPUAdamBuilder.NAME, prebuilt_import_path=ArmCPUAdamBuilder.PREBUILT_IMPORT_PATH)
+ self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]
+
+ # necessary 4 functions
+ def sources_files(self):
+ ret = [
+ self.csrc_abs_path("cpu_adam_arm.cpp"),
+ ]
+ return ret
+
+ def include_dirs(self):
+ return [self.csrc_abs_path("includes")]
+
+ def cxx_flags(self):
+ extra_cxx_flags = [
+ "-std=c++14",
+ "-std=c++17",
+ "-g",
+ "-Wno-reorder",
+ "-fopenmp",
+ ]
+ return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
+
+ def nvcc_flags(self):
+ return []
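+
+
+# Usage sketch (illustrative): like the other builders, this one JIT-compiles its
+# kernel on first use through the shared `Builder.load()` machinery, e.g.
+#   arm_cpu_adam_module = ArmCPUAdamBuilder().load()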
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 75823ef105c7..d804cb1602e4 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -7,7 +7,7 @@
import time
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
@@ -21,6 +21,8 @@ class Builder(ABC):
prebuilt_import_path (str): the path where the extension is installed during pip install
"""
+ ext_type: str = "cuda"
+
def __init__(self, name: str, prebuilt_import_path: str):
self.name = name
self.prebuilt_import_path = prebuilt_import_path
@@ -165,7 +167,8 @@ def load(self, verbose: Optional[bool] = None):
)
except ImportError:
# check environment
- self.check_runtime_build_environment()
+ if self.ext_type == "cuda":
+ self.check_runtime_build_environment()
# time the kernel compilation
start_build = time.time()
@@ -208,11 +211,19 @@ def load(self, verbose: Optional[bool] = None):
return op_module
- def builder(self) -> "CUDAExtension":
+ def builder(self) -> Union["CUDAExtension", "CppExtension"]:
"""
get a CUDAExtension instance used for setup.py
"""
- from torch.utils.cpp_extension import CUDAExtension
+ from torch.utils.cpp_extension import CppExtension, CUDAExtension
+
+ if self.ext_type == "cpp":
+ return CppExtension(
+ name=self.prebuilt_import_path,
+ sources=self.strip_empty_entries(self.sources_files()),
+ include_dirs=self.strip_empty_entries(self.include_dirs()),
+ extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
+ )
return CUDAExtension(
name=self.prebuilt_import_path,
diff --git a/op_builder/cpu_adam.py b/op_builder/cpu_adam.py
index 5a2a2e3e6a56..7988aae4be12 100644
--- a/op_builder/cpu_adam.py
+++ b/op_builder/cpu_adam.py
@@ -21,12 +21,22 @@ def include_dirs(self):
return [self.csrc_abs_path("includes"), self.get_cuda_home_include()]
def cxx_flags(self):
- extra_cxx_flags = ["-std=c++14", "-lcudart", "-lcublas", "-g", "-Wno-reorder", "-fopenmp", "-march=native"]
+ extra_cxx_flags = [
+ "-std=c++14",
+ "-std=c++17",
+ "-lcudart",
+ "-lcublas",
+ "-g",
+ "-Wno-reorder",
+ "-fopenmp",
+ "-march=native",
+ ]
return ["-O3"] + self.version_dependent_macros + extra_cxx_flags
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
+ "-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
diff --git a/op_builder/gptq.py b/op_builder/gptq.py
index bc4f445de067..a17801f8783c 100644
--- a/op_builder/gptq.py
+++ b/op_builder/gptq.py
@@ -37,12 +37,12 @@ def nvcc_flags(self):
extra_cuda_flags = [
"-v",
"-std=c++14",
+ "-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-DTHRUST_IGNORE_CUB_VERSION_CHECK",
"-lcublas",
- "-std=c++17",
]
for arch in torch.cuda.get_arch_list():
diff --git a/op_builder/multi_head_attn.py b/op_builder/multi_head_attn.py
index b70f041db7d6..cb8fc489ced1 100644
--- a/op_builder/multi_head_attn.py
+++ b/op_builder/multi_head_attn.py
@@ -35,6 +35,7 @@ def cxx_flags(self):
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
+ "-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
diff --git a/op_builder/scaled_masked_softmax.py b/op_builder/scaled_masked_softmax.py
index b2f1de7792c8..d9239a80eef6 100644
--- a/op_builder/scaled_masked_softmax.py
+++ b/op_builder/scaled_masked_softmax.py
@@ -25,6 +25,7 @@ def cxx_flags(self):
def nvcc_flags(self):
extra_cuda_flags = [
"-std=c++14",
+ "-std=c++17",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
diff --git a/pytest.ini b/pytest.ini
index 38ad7d76de50..598e0a74e71c 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,4 +2,4 @@
markers =
dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
-addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
+addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_fx --ignore=tests/test_legacy
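
The `dist` and `largedist` markers registered here gate the multi-GPU tests added below, and dropping `--ignore=tests/test_moe` re-enables the MoE suite by default. Marked tests follow the usual pytest pattern:

    import pytest

    @pytest.mark.dist       # needs at least 4 GPUs; selected with `pytest -m dist`
    def test_four_gpu_path():
        ...

    @pytest.mark.largedist  # needs at least 8 GPUs; selected with `pytest -m largedist`
    def test_eight_gpu_path():
        ...
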
diff --git a/requirements/requirements-infer.txt b/requirements/requirements-infer.txt
new file mode 100644
index 000000000000..f85f9d88e629
--- /dev/null
+++ b/requirements/requirements-infer.txt
@@ -0,0 +1,4 @@
+transformers==4.34.0
+auto-gptq==0.5.0
+git+https://github.com/ModelTC/lightllm.git@ece7b43f8a6dfa74027adc77c2c176cff28c76c8
+git+https://github.com/Dao-AILab/flash-attention.git@017716451d446e464dde9aca3a3c1ed2209caaa9
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index 467f83610eb0..61b58055e666 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -12,10 +12,12 @@ torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package
torchrec==0.2.0
contexttimer
einops
-triton==2.0.0.dev20221202
+triton==2.1.0
requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
SentencePiece
ninja
flash_attn==2.0.5
datasets
+pydantic
+ray
# auto-gptq does not support torch 1.12 yet
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 19cb7a154a01..095617d76355 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -11,6 +11,8 @@ ninja
torch>=1.12
safetensors
einops
+pydantic
+ray
sentencepiece
google
protobuf
diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
index 041de6b90f8d..4730642705ff 100644
--- a/tests/kit/model_zoo/transformers/llama.py
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -27,8 +27,10 @@ def data_gen():
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------
- input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
- attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
+ input_ids = torch.Tensor(
+ [[1, 15043, 29892, 590, 11203, 338, 274, 1082], [1, 15043, 29892, 590, 11203, 338, 274, 1082]]
+ ).long()
+ attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)
# label is needed for casual lm
diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py
index 00ff6cb37d2a..61debe47b599 100644
--- a/tests/test_booster/test_plugin/test_gemini_plugin.py
+++ b/tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -1,5 +1,6 @@
from contextlib import nullcontext
from typing import Optional
+import pytest
import torch
import torch.distributed as dist
@@ -10,18 +11,22 @@
from colossalai.fx import is_compatible_with_meta
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, tp_size) -> Optional[str]:
try:
if init_method == "lazy":
ctx = LazyInitContext()
else:
ctx = nullcontext()
- plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
+ extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        enable_all_optimization = tp_size > 1
+        plugin = GeminiPlugin(
+            max_norm=1.0,
+            initial_scale=2**5,
+            tp_size=tp_size,
+            extra_dp_size=extra_dp_size,
+            enable_all_optimization=enable_all_optimization,
+        )
booster = Booster(plugin=plugin)
with ctx:
model = model_fn()
@@ -46,6 +51,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
booster.backward(loss, optimizer)
optimizer.step()
+ except NotImplementedError:
+ print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
except Exception as e:
# raise e
return repr(e)
@@ -57,7 +64,9 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
@parameterize("subset", ["torchvision", "transformers", "diffusers"])
@parameterize("init_method", ["none"])
-def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True):
+@parameterize("zero_size", [2])
+@parameterize("tp_size", [2])
+def check_gemini_plugin(
+    subset: str, init_method: str = "none", early_stop: bool = True, zero_size: int = 1, tp_size: int = 1
+):
"""check gemini plugin over model zoo
Args:
@@ -116,7 +125,12 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool
"torchvision_efficientnet_v2_s",
]:
continue
- err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
+
+    # TODO: debug blip2 under tensor parallelism; shift_logits has a shape mismatch.
+    # Use a per-model tp size so the workaround does not leak into later iterations.
+    model_tp_size = 1 if "transformers_blip2" in name else tp_size
+
+    err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, zero_size, model_tp_size)
torch.cuda.empty_cache()
if err is None:
passed_models.append(name)
@@ -142,6 +156,11 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
def test_gemini_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+def test_gemini_plugin_3d(early_stop: bool = True):
+ spawn(run_dist, 8, early_stop=early_stop)
+
if __name__ == "__main__":
- test_gemini_plugin(early_stop=False)
+ test_gemini_plugin(early_stop=False)
\ No newline at end of file
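
The sizing logic assumes the process grid factorizes as world_size = zero_size * tp_size * extra_dp_size, which is why the new 3D test spawns 8 ranks. Worked out:

    # 8-GPU largedist run: zero_size = 2, tp_size = 2
    world_size, zero_size, tp_size = 8, 2, 2
    extra_dp_size = world_size // (zero_size * tp_size)
    assert extra_dp_size == 2  # a 2 x 2 x 2 process mesh

    # the 4-GPU dist run with the same parameters collapses the extra axis
    assert 4 // (2 * 2) == 1
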
diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
index 104ca254c572..3eaaf882c9ba 100644
--- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
+++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py
@@ -2,11 +2,14 @@
import torch
import torch.distributed as dist
+from torch.optim import Adam
import colossalai
+import colossalai.utils.device as device_utils
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
-from colossalai.nn.optimizer import HybridAdam
+
+# from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
@@ -19,16 +22,17 @@
def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+ device = device_utils.get_current_device()
try:
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model = model_fn()
- optimizer = HybridAdam(model.parameters(), lr=1e-3)
+ optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
- k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
+ k: v.to(device) if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
@@ -65,7 +69,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
continue
err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn)
- torch.cuda.empty_cache()
+ device_utils.empty_cache()
if err is None:
passed_models.append(name)
@@ -89,7 +93,7 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_low_level_zero_plugin(early_stop: bool = True):
- spawn(run_dist, 4, early_stop=early_stop)
+ spawn(run_dist, 2, early_stop=early_stop)
if __name__ == "__main__":
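
Replacing hard-coded "cuda" strings with colossalai.utils.device keeps the test runnable on non-CUDA accelerators. The pattern, assuming the module mirrors torch.cuda's cache API as the calls above suggest:

    import torch
    import colossalai.utils.device as device_utils

    device = device_utils.get_current_device()  # current accelerator, not literally "cuda"
    data = {"input_ids": torch.ones(2, 8, dtype=torch.long)}
    data = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in data.items()}
    device_utils.empty_cache()  # backend-appropriate cache flush
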
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index f876040384b3..8343c5f07e30 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -37,17 +37,21 @@
@parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_safetensors", [False, True])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
+    enable_all_optimization = tp_size > 1
with shared_tempdir() as tempdir:
pretrained_path = os.path.join(tempdir, "pretrained")
bert_model.config.save_pretrained(save_directory=pretrained_path)
- plugin = GeminiPlugin(**placement_config)
+ extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+        plugin = GeminiPlugin(
+            **placement_config,
+            tp_size=tp_size,
+            extra_dp_size=extra_dp_size,
+            enable_all_optimization=enable_all_optimization,
+        )
booster = Booster(plugin=plugin)
bert_model, _, _, _, _ = booster.boost(bert_model)
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -63,13 +67,17 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
@clear_cache_before_run()
@parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
-@parameterize("shard", [False, True])
+@parameterize("shard", [True, False])
@parameterize("model_name", ["transformers_gpt"])
@parameterize("size_per_shard", [32])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
+@parameterize("tp_size", [1, 2])
+@parameterize("zero_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
- plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
+    enable_all_optimization = tp_size > 1
+    extra_dp_size = dist.get_world_size() // (zero_size * tp_size)
+    plugin = GeminiPlugin(
+        **placement_config,
+        precision="fp16",
+        initial_scale=(2**14),
+        tp_size=tp_size,
+        extra_dp_size=extra_dp_size,
+        enable_all_optimization=enable_all_optimization,
+    )
booster = Booster(plugin=plugin)
model = model_fn()
@@ -148,7 +156,13 @@ def run_dist(rank, world_size, port):
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_gemini_ckpIO(world_size):
spawn(run_dist, world_size)
+
+@pytest.mark.largedist
+@pytest.mark.parametrize("world_size", [8])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO_3d(world_size):
+ spawn(run_dist, world_size)
\ No newline at end of file
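
With tp_size and zero_size both set to 2, the 4-rank dist variant exercises a 2D (zero x tp) layout and the 8-rank largedist variant adds the extra data-parallel axis. The shard/size_per_shard parameters control how a checkpoint is split on disk; a save/load round trip in the shape the test exercises, with keyword names assumed from the parameters above rather than taken from a spec:

    # Sketch: operates on an already-boosted model/optimizer as in exam_state_dict.
    def checkpoint_round_trip(booster, model, optimizer, new_model, new_optimizer):
        # keyword names below are assumptions based on the test parameters
        booster.save_model(model, "model_ckpt", shard=True, size_per_shard=32)
        booster.save_optimizer(optimizer, "optim_ckpt", shard=True, size_per_shard=32)
        booster.load_model(new_model, "model_ckpt")
        booster.load_optimizer(new_optimizer, "optim_ckpt")
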
diff --git a/tests/test_infer/_utils.py b/tests/test_infer/_utils.py
index 2ddc8b6e68e4..310c214f4533 100644
--- a/tests/test_infer/_utils.py
+++ b/tests/test_infer/_utils.py
@@ -19,7 +19,6 @@ def build_model(
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused,
- inference_only=True,
)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py
deleted file mode 100644
index ba978ad9bf0d..000000000000
--- a/tests/test_infer/test_bloom_infer.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import pytest
-import torch
-from packaging import version
-from transformers import BloomForCausalLM
-from transformers.models.bloom.configuration_bloom import BloomConfig
-
-import colossalai
-from colossalai.inference.tensor_parallel import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-TP_SIZE = 2
-MAX_BATCH_SIZE = 4
-MAX_INPUT_LEN = 16
-MAX_OUTPUT_LEN = 32
-
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-
-@parameterize(
- "test_config",
- [
- {
- "tp_size": TP_SIZE,
- }
- ],
-)
-def run(test_config):
- bloom_config = BloomConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
- model = BloomForCausalLM(bloom_config)
- model = model.half()
-
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
- )
- infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
- generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- "attention_mask": torch.ones((MAX_BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- }
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
-
- assert outputs is not None
-
-
-def check_bloom(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run()
-
-
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_bloom_infer():
- spawn(check_bloom, TP_SIZE)
-
-
-if __name__ == "__main__":
- test_bloom_infer()
diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py
deleted file mode 100644
index 399b70e1460e..000000000000
--- a/tests/test_infer/test_chatglm2_infer.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import os
-
-import pytest
-import torch
-from packaging import version
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 1
-BATCH_SIZE = 8
-MAX_INPUT_LEN = 12
-MAX_OUTPUT_LEN = 100
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-
-@parameterize(
- "test_config",
- [
- {
- "tp_size": TPSIZE,
- }
- ],
-)
-def run_chatglm2_test(test_config):
- chatglm_config = ChatGLMConfig(
- num_layers=2,
- vocab_size=1200,
- use_cache=True,
- multi_query_attention=True,
- multi_query_group_num=2,
- num_attention_heads=8,
- hidden_size=1024,
- )
- model = ChatGLMForConditionalGeneration(chatglm_config)
- model = model.half()
-
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
- )
- infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
- generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- }
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
-
- assert outputs is not None
-
-
-def check_chatglm2(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_chatglm2_test()
-
-
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_chatglm2():
- spawn(check_chatglm2, TPSIZE)
-
-
-if __name__ == "__main__":
- test_chatglm2()
diff --git a/tests/test_infer/test_hybrid_bloom.py b/tests/test_infer/test_hybrid_bloom.py
new file mode 100644
index 000000000000..8cad06dca6d9
--- /dev/null
+++ b/tests/test_infer/test_hybrid_bloom.py
@@ -0,0 +1,121 @@
+import importlib.util
+
+import pytest
+import torch
+import torch.distributed as dist
+import transformers
+from packaging import version
+
+import colossalai
+from colossalai.inference import InferenceEngine
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+HAS_LIGHTLLM_KERNEL = True
+
+if importlib.util.find_spec("lightllm") is None:
+ HAS_LIGHTLLM_KERNEL = False
+
+
+def data_gen():
+ input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+inputs = data_gen()
+for k, v in inputs.items():
+ if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+ new_shape = [1] * v.dim()
+ new_shape[0] = 16
+ inputs[k] = v.to("cuda").repeat(*new_shape)
+
+
+def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ model = transformers.BloomForCausalLM(
+ transformers.BloomConfig(vocab_size=20000, hidden_size=512, n_head=4, n_layer=4)
+ )
+
+ engine = InferenceEngine(
+ tp_size=tp_size,
+ pp_size=pp_size,
+ model=model,
+ max_output_len=max_output_len,
+ micro_batch_size=micro_batch_size,
+ )
+ output = engine.generate(inputs)
+ if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output[0])}, {max_output_len}"
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+def check_tp_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_pipeline_inference_test()
+
+
+def check_tp_or_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_inference_test()
+ run_pipeline_inference_test()
+
+
+def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_single_inference_test()
+
+
+@pytest.mark.skipif(
+ not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+ reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_pipeline_inference():
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)
+
+
+if __name__ == "__main__":
+ test_pipeline_inference()
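
The module-level loop that prepares `inputs` grows the single example row into a batch of 16 by repeating only the leading dimension, leaving the sequence length untouched. In isolation:

    import torch

    v = torch.ones(1, 8, dtype=torch.int64)  # one sequence of length 8
    new_shape = [1] * v.dim()                # [1, 1]
    new_shape[0] = 16                        # repeat along the batch dim only
    batch = v.repeat(*new_shape)
    assert batch.shape == (16, 8)
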
diff --git a/tests/test_infer/test_hybrid_chatglm2.py b/tests/test_infer/test_hybrid_chatglm2.py
new file mode 100644
index 000000000000..b53bb25f442f
--- /dev/null
+++ b/tests/test_infer/test_hybrid_chatglm2.py
@@ -0,0 +1,129 @@
+import importlib.util
+
+import pytest
+import torch
+import torch.distributed as dist
+from packaging import version
+
+import colossalai
+from colossalai.inference import InferenceEngine
+from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
+from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+HAS_LIGHTLLM_KERNEL = True
+
+if importlib.util.find_spec("lightllm") is None:
+ HAS_LIGHTLLM_KERNEL = False
+
+
+def data_gen():
+ input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+inputs = data_gen()
+for k, v in inputs.items():
+ if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+ new_shape = [1] * v.dim()
+ new_shape[0] = 16
+ inputs[k] = v.to("cuda").repeat(*new_shape)
+
+
+def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ chatglm_config = ChatGLMConfig(
+ num_layers=2,
+ vocab_size=20000,
+ use_cache=True,
+ multi_query_attention=True,
+ multi_query_group_num=2,
+ num_attention_heads=8,
+ hidden_size=1024,
+ )
+ model = ChatGLMForConditionalGeneration(chatglm_config)
+
+ engine = InferenceEngine(
+ tp_size=tp_size,
+ pp_size=pp_size,
+ model=model,
+ max_output_len=max_output_len,
+ micro_batch_size=micro_batch_size,
+ )
+ output = engine.generate(inputs)
+ if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output[0])}, {max_output_len}"
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+def check_tp_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_pipeline_inference_test()
+
+
+def check_tp_or_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_inference_test()
+ run_pipeline_inference_test()
+
+
+def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_single_inference_test()
+
+
+@pytest.mark.skipif(
+ not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+ reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_pipeline_inference():
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)
+
+
+if __name__ == "__main__":
+ test_pipeline_inference()
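
Each check_* entry point runs once per spawned process, so nprocs must equal the largest tp_size * pp_size it drives: 4 for the combined TP+PP grid, 2 for TP-only or PP-only, and 1 for the single-device path. The contract assumed from the colossalai.testing usage here is:

    # Assumed contract: spawn(fn, nprocs) forks nprocs workers and invokes
    # fn(rank, world_size, port) in each, matching the check_* signatures.
    def check_fn(rank, world_size, port):
        assert 0 <= rank < world_size

    # spawn(check_fn, nprocs=4)  # world_size == 4 == tp_size * pp_size
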
diff --git a/tests/test_infer/test_hybrid_llama.py b/tests/test_infer/test_hybrid_llama.py
new file mode 100644
index 000000000000..30b8b0a991d0
--- /dev/null
+++ b/tests/test_infer/test_hybrid_llama.py
@@ -0,0 +1,126 @@
+import importlib.util
+
+import pytest
+import torch
+import torch.distributed as dist
+import transformers
+from packaging import version
+
+import colossalai
+from colossalai.inference import InferenceEngine
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+
+CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
+HAS_LIGHTLLM_KERNEL = True
+
+if importlib.util.find_spec("lightllm") is None:
+ HAS_LIGHTLLM_KERNEL = False
+
+
+def data_gen():
+ input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
+ attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+inputs = data_gen()
+for k, v in inputs.items():
+ if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
+ new_shape = [1] * v.dim()
+ new_shape[0] = 16
+ inputs[k] = v.to("cuda").repeat(*new_shape)
+
+
+def pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ model = transformers.LlamaForCausalLM(
+ transformers.LlamaConfig(
+ vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4
+ )
+ )
+
+ engine = InferenceEngine(
+ tp_size=tp_size,
+ pp_size=pp_size,
+ model=model,
+ max_output_len=max_output_len,
+ micro_batch_size=micro_batch_size,
+ )
+ output = engine.generate(inputs)
+ if dist.get_rank() == 0:
+        assert len(output[0]) == max_output_len, f"{len(output[0])}, {max_output_len}"
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [2])
+@parameterize("max_output_len", [4])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [2])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_tp_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+@parameterize("tp_size", [1])
+@parameterize("pp_size", [1])
+@parameterize("max_output_len", [2])
+@parameterize("micro_batch_size", [1])
+@clear_cache_before_run()
+def run_single_inference_test(tp_size, pp_size, max_output_len, micro_batch_size):
+ pipeline_inference_test(tp_size, pp_size, max_output_len, micro_batch_size)
+ torch.cuda.empty_cache()
+
+
+def check_tp_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_pipeline_inference_test()
+
+
+def check_tp_or_pp_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_tp_inference_test()
+ run_pipeline_inference_test()
+
+
+def check_single_inference(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_single_inference_test()
+
+
+@pytest.mark.skipif(
+ not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL,
+ reason="kv-cache manager engine requires cuda version to be higher than 11.5",
+)
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_pipeline_inference():
+ spawn(check_tp_pp_inference, nprocs=4)
+ spawn(check_tp_or_pp_inference, nprocs=2)
+ spawn(check_single_inference, nprocs=1)
+
+
+if __name__ == "__main__":
+ test_pipeline_inference()
diff --git a/tests/test_infer/test_infer_engine.py b/tests/test_infer/test_infer_engine.py
deleted file mode 100644
index f24160820e71..000000000000
--- a/tests/test_infer/test_infer_engine.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from itertools import accumulate
-
-import pytest
-import torch
-from packaging import version
-from transformers import BloomConfig, BloomForCausalLM
-from transformers.tokenization_utils_base import BatchEncoding
-
-import colossalai
-from colossalai.inference.tensor_parallel import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-TP_SIZE = 2
-MAX_BATCH_SIZE = 4
-MAX_INPUT_LEN = 16
-MAX_OUTPUT_LEN = 8
-
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-
-@parameterize(
- "test_config",
- [
- {
- "tp_size": TP_SIZE,
- }
- ],
-)
-def run(test_config):
- model_config = BloomConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256, num_attention_heads=4)
- model = BloomForCausalLM(model_config)
- model = model.half()
- model.to(torch.cuda.current_device())
-
- # 1. check TPInferEngine init and model optimization
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
- )
- infer_engine = TPInferEngine(model, shard_config, MAX_BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
-
- assert infer_engine.cache_manager is not None
- assert infer_engine.tp_size == TP_SIZE
- assert infer_engine.head_num == model_config.num_attention_heads // TP_SIZE
-
- # 2. check data preparation
- input_ids_list = [
- [80540, 15473, 3331, 11970, 90472, 361, 61335],
- [80540, 15473, 3331, 11970],
- [80540, 15473, 3331, 11970],
- [80540, 15473],
- ]
- batch_size = len(input_ids_list)
- max_seq_len = max(len(li) for li in input_ids_list)
- attention_mask = [[0] * max_seq_len for _ in range(batch_size)]
- for i, li in enumerate(input_ids_list):
- attention_mask[i][max_seq_len - len(li) :] = [1 for _ in range(len(li))]
- data = dict(input_ids=input_ids_list, attention_mask=attention_mask)
- inputs_batch_encoding = BatchEncoding(data=data)
- seq_lengths = [len(li) for li in input_ids_list]
- start_loc = list(accumulate([0] + seq_lengths[:-1]))
- seq_lengths = torch.tensor(seq_lengths, dtype=torch.int32)
- start_loc = torch.tensor(start_loc, dtype=torch.int32)
- # input token id list as inputs
- batch_state_out1 = infer_engine.prepare_batch_state(inputs_batch_encoding)
- # BatchEncoding as inputs
- batch_state_out2 = infer_engine.prepare_batch_state(input_ids_list)
-
- assert batch_state_out1.batch_size == batch_state_out2.batch_size == batch_size
- assert torch.equal(batch_state_out1.seq_len, batch_state_out2.seq_len)
-
- # The following tests are discarded for now, and will be reused after all features are added
- # assert torch.equal(batch_state_out1.seq_len.to(seq_lengths.device), seq_lengths)
- # assert torch.equal(batch_state_out2.seq_len.to(seq_lengths.device), seq_lengths)
- # assert torch.equal(batch_state_out1.start_loc.to(start_loc.device), start_loc)
- # assert torch.equal(batch_state_out2.start_loc.to(start_loc.device), start_loc)
-
- # 3. check optimized model generate
- input_ids = torch.randint(low=10, high=1000, size=(MAX_BATCH_SIZE, MAX_INPUT_LEN))
- generate_kwargs = dict(do_sample=False)
- infer_engine.generate(input_ids, **generate_kwargs)
-
- torch.cuda.empty_cache()
-
-
-def check_engine(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run()
-
-
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_engine():
- spawn(check_engine, TP_SIZE)
-
-
-if __name__ == "__main__":
- test_engine()
diff --git a/tests/test_infer/test_kvcache_manager.py b/tests/test_infer/test_kvcache_manager.py
index f3e2cdf1e18f..e8765317291a 100644
--- a/tests/test_infer/test_kvcache_manager.py
+++ b/tests/test_infer/test_kvcache_manager.py
@@ -4,7 +4,7 @@
import torch
from packaging import version
-from colossalai.inference.tensor_parallel import MemoryManager
+from colossalai.inference.kv_cache import MemoryManager
from colossalai.logging import disable_existing_loggers
from colossalai.testing import rerun_if_address_is_in_use, spawn
diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py
deleted file mode 100644
index 0eebed8892ea..000000000000
--- a/tests/test_infer/test_llama2_infer.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-
-import pytest
-import torch
-from packaging import version
-from transformers import LlamaForCausalLM
-from transformers.models.llama.configuration_llama import LlamaConfig
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 2
-BATCH_SIZE = 8
-MAX_INPUT_LEN = 12
-MAX_OUTPUT_LEN = 100
-
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-
-@parameterize(
- "test_config",
- [
- {
- "tp_size": TPSIZE,
- }
- ],
-)
-def run_llama_test(test_config):
- llama_config = LlamaConfig(
- num_hidden_layers=2, num_key_value_heads=8, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024
- )
- model = LlamaForCausalLM(llama_config)
- model = model.half()
-
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
- )
- infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
- generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- }
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
-
- assert outputs is not None
-
-
-def check_llama(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_llama_test()
-
-
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama():
- spawn(check_llama, TPSIZE)
-
-
-if __name__ == "__main__":
- test_llama()
diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py
deleted file mode 100644
index b424525a3719..000000000000
--- a/tests/test_infer/test_llama_infer.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import os
-
-import pytest
-import torch
-from packaging import version
-from transformers import LlamaForCausalLM
-from transformers.models.llama.configuration_llama import LlamaConfig
-
-import colossalai
-from colossalai.inference.tensor_parallel.engine import TPInferEngine
-from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer import ShardConfig
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
-TPSIZE = 2
-BATCH_SIZE = 8
-MAX_INPUT_LEN = 12
-MAX_OUTPUT_LEN = 100
-
-CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.5")
-
-
-@parameterize(
- "test_config",
- [
- {
- "tp_size": TPSIZE,
- }
- ],
-)
-def run_llama_test(test_config):
- llama_config = LlamaConfig(num_hidden_layers=2, bos_token_id=0, eos_token_id=1, vocab_size=1200, hidden_size=1024)
- model = LlamaForCausalLM(llama_config)
- model = model.half()
-
- shard_config = ShardConfig(
- enable_tensor_parallelism=True if test_config["tp_size"] > 1 else False, inference_only=True
- )
- infer_engine = TPInferEngine(model, shard_config, BATCH_SIZE, MAX_INPUT_LEN, MAX_OUTPUT_LEN)
- generate_kwargs = dict(max_new_tokens=MAX_OUTPUT_LEN, do_sample=False)
-
- input_tokens = {
- "input_ids": torch.randint(1, 1000, (BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- "attention_mask": torch.ones((BATCH_SIZE, MAX_INPUT_LEN), device="cuda"),
- }
- outputs = infer_engine.generate(input_tokens, **generate_kwargs)
-
- assert outputs is not None
-
-
-def check_llama(rank, world_size, port):
- disable_existing_loggers()
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_llama_test()
-
-
-@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5")
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_llama():
- spawn(check_llama, TPSIZE)
-
-
-if __name__ == "__main__":
- test_llama()
diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py
deleted file mode 100644
index ad8e32b48bae..000000000000
--- a/tests/test_infer/test_pipeline_infer.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-import transformers
-
-import colossalai
-from colossalai.inference.pipeline.engine import PPInferEngine
-from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
-from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
-
-
-def data_gen():
- input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
- attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
- return dict(input_ids=input_ids, attention_mask=attention_mask)
-
-
-inputs = data_gen()
-for k, v in inputs.items():
- if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
- new_shape = [1] * v.dim()
- new_shape[0] = 16
- inputs[k] = v.to("cuda").repeat(*new_shape)
-
-
-def pipeline_inference_test(pp_size, new_length, micro_batch_size):
- model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
- engine = PPInferEngine(
- pp_size=pp_size,
- model=model,
- model_policy=GPT2LMHeadModelPipelinePolicy(),
- new_length=new_length,
- micro_batch_size=micro_batch_size,
- )
- output = engine.inference([inputs])
- if dist.get_rank() == 0:
- assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
-
-
-@parameterize("pp_size", [4])
-@parameterize("new_length", [4, 8, 16])
-@parameterize("micro_batch_size", [1, 4])
-@clear_cache_before_run()
-def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
- pipeline_inference_test(pp_size, new_length, micro_batch_size)
- torch.cuda.empty_cache()
-
-
-def check_pipeline_inference(rank, world_size, port):
- colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- run_pipeline_inference_test()
-
-
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-@clear_cache_before_run()
-def test_pipeline_inference():
- spawn(check_pipeline_inference, nprocs=4)
-
-
-if __name__ == "__main__":
- test_pipeline_inference()
diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py
deleted file mode 100644
index a4d893f8e830..000000000000
--- a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-import pytest
-import torch
-from torch import nn
-
-try:
- from vllm import layernorm_ops
-
- rms_norm = layernorm_ops.rms_norm
- HAS_VLLM_KERNERL = True
-except:
- print("please install vllm kernels to install rmsnorm")
- print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
- HAS_VLLM_KERNERL = False
-
-
-class LlamaRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- LlamaRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, hidden_states):
- input_dtype = hidden_states.dtype
- hidden_states = hidden_states.to(torch.float32)
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- return self.weight * hidden_states.to(input_dtype)
-
-
-def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon):
- x = hidden_states
- out = torch.empty_like(x)
- rms_norm(
- out,
- x,
- weight,
- variance_epsilon,
- )
- return out
-
-
-@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
-def test_rmsnorm():
- data = torch.randn((1024, 64), dtype=torch.float16, device="cuda")
- hg_rms = LlamaRMSNorm(64)
- hg_rms = hg_rms.half().cuda()
- out_torch = hg_rms(data)
- out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon)
-
- check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5)
- assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward"
-
-
-if __name__ == "__main__":
- test_rmsnorm()
diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py
deleted file mode 100644
index 40451ef6636d..000000000000
--- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py
+++ /dev/null
@@ -1,153 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-from typing import Tuple
-
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half
-
-try:
- from vllm import pos_encoding_ops
-
- rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox
- HAS_VLLM_KERNERL = True
-except:
- print("fall back to original rotary_embedding_neox of huggingface")
- print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference")
- HAS_VLLM_KERNERL = False
-
-
-def rotate_half(x: torch.Tensor) -> torch.Tensor:
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class RefRotaryEmbeddingNeox(nn.Module):
- """Reference implementation of the GPT-NeoX style rotary embedding."""
-
- def __init__(
- self,
- dim: int,
- max_position_embeddings: int = 2048,
- base: int = 10000,
- ) -> None:
- super().__init__()
- self.rotary_dim = dim
- self.max_position_embeddings = max_position_embeddings
-
- # Create cos and sin embeddings.
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim))
- t = torch.arange(max_position_embeddings).float()
- freqs = torch.einsum("i,j->ij", t, inv_freq.float())
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos().to(dtype=inv_freq.dtype)
- sin = emb.sin().to(dtype=inv_freq.dtype)
- self.register_buffer("cos_cached", cos, persistent=False)
- self.register_buffer("sin_cached", sin, persistent=False)
-
- def forward(
- self,
- positions: torch.Tensor, # [num_tokens]
- query: torch.Tensor, # [num_tokens, num_heads, head_size]
- key: torch.Tensor, # [num_tokens, num_heads, head_size]
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- query_rot = query[..., : self.rotary_dim]
- query_pass = query[..., self.rotary_dim :]
- key_rot = key[..., : self.rotary_dim]
- key_pass = key[..., self.rotary_dim :]
-
- query_rot = query_rot.transpose(0, 1)
- key_rot = key_rot.transpose(0, 1)
- cos = F.embedding(positions, self.cos_cached)
- sin = F.embedding(positions, self.sin_cached)
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
- query_rot = query_rot.transpose(0, 1).contiguous()
- key_rot = key_rot.transpose(0, 1).contiguous()
-
- query = torch.cat((query_rot, query_pass), dim=-1)
- key = torch.cat((key_rot, key_pass), dim=-1)
-
- # Output query/key shape: [num_tokens, num_tokens, head_size]
- return query, key
-
-
-def run_rotary_embedding_neox(
- num_tokens: int,
- num_heads: int,
- head_size: int,
- max_position: int,
- rotary_dim: int,
- dtype: torch.dtype,
- base: int = 10000,
-) -> None:
- positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
- query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
- key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda")
-
- # Create the rotary embedding.
- inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim))
- t = torch.arange(max_position).float()
- freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
- cos = freqs.cos()
- sin = freqs.sin()
- cos_sin_cache = torch.cat((cos, sin), dim=-1)
- cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
- # Run the kernel. The kernel is in-place, so we need to clone the inputs.
- out_query = query.clone()
- out_key = key.clone()
- rotary_embedding_neox(
- positions,
- out_query,
- out_key,
- head_size,
- cos_sin_cache,
- )
-
- # Run the reference implementation.
- ref_rotary_embedding = RefRotaryEmbeddingNeox(
- dim=rotary_dim,
- max_position_embeddings=max_position,
- base=base,
- ).to(dtype=dtype, device="cuda")
- ref_query, ref_key = ref_rotary_embedding(
- positions,
- query.view(num_tokens, num_heads, head_size),
- key.view(num_tokens, num_heads, head_size),
- )
- ref_query = ref_query.view(num_tokens, num_heads * head_size)
- ref_key = ref_key.view(num_tokens, num_heads * head_size)
-
- # Compare the results.
- assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5)
- assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5)
-
-
-@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test")
-def test_rotary_embedding():
- run_rotary_embedding_neox(
- num_tokens=1024,
- num_heads=8,
- head_size=64,
- max_position=8192,
- rotary_dim=64,
- dtype=torch.float16,
- )
-
-
-if __name__ == "__main__":
- test_rotary_embedding()
diff --git a/tests/test_infer_ops/triton/test_llama_act_combine.py b/tests/test_infer_ops/triton/test_llama_act_combine.py
new file mode 100644
index 000000000000..5341aa35ab90
--- /dev/null
+++ b/tests/test_infer_ops/triton/test_llama_act_combine.py
@@ -0,0 +1,56 @@
+import pytest
+import torch
+from packaging import version
+from torch import nn
+
+from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine
+
+try:
+ import triton
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+ print("please install triton from https://github.com/openai/triton")
+
+TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
+
+BATCH_SIZE = 4
+SEQ_LEN = 16
+HIDDEN_SIZE = 32
+
+
+def SwiGLU(x):
+    """SwiGLU activation: split the last dimension in half and gate one half
+    with the SiLU of the other.
+
+    Args:
+        x: input tensor whose last dimension has an even size
+    """
+    size = x.shape[-1]
+    assert size % 2 == 0, "last dimension size must be divisible by 2"
+ x1, x2 = torch.split(x, size // 2, -1)
+ return x1 * (x2 * torch.sigmoid(x2.to(torch.float32)).to(x.dtype))
+
+
+@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
+def test_llama_act_combine(dtype: torch.dtype):
+ x_gate = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE * 2, dtype=dtype).cuda()
+ x_gate_torch = nn.Parameter(x_gate.detach().clone())
+ x_gate_kernel = nn.Parameter(x_gate.detach().clone())
+ x_up = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, dtype=dtype).cuda()
+ x_up_torch = nn.Parameter(x_up.detach().clone())
+ x_up_kernel = nn.Parameter(x_up.detach().clone())
+
+ torch_out = SwiGLU(x_gate_torch) * x_up_torch
+ kernel_out = LlamaActCombine.apply(x_gate_kernel, x_up_kernel)
+ atol = 1e-5 if dtype == torch.float32 else 5e-2
+ assert torch.allclose(torch_out, kernel_out, atol=atol)
+
+ torch_out.mean().backward()
+ kernel_out.mean().backward()
+ assert all(grad is not None for grad in [x_gate_torch.grad, x_up_torch.grad, x_gate_kernel.grad, x_up_kernel.grad])
+ assert torch.allclose(x_gate_torch.grad, x_gate_kernel.grad, atol=atol)
+ assert torch.allclose(x_up_torch.grad, x_up_kernel.grad, atol=atol)
+
+
+if __name__ == "__main__":
+ test_llama_act_combine(torch.float16)
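
For reference, the activation under test splits the gate projection in half along the last dimension and gates one half with the SiLU of the other:

    \mathrm{SwiGLU}(x) = x_1 \odot \big(x_2 \cdot \sigma(x_2)\big),
    \qquad [x_1, x_2] = \mathrm{split}(x, \tfrac{d}{2}, \text{dim} = -1)

where \sigma is the logistic sigmoid, so the gate factor is exactly \mathrm{SiLU}(x_2); the test then checks that LlamaActCombine fuses this product with the subsequent multiplication by x_up, in both forward and backward.
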
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index be6de6db2471..95fe50cf1d9c 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -41,7 +41,6 @@ def test_llama_context_attention():
llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len)
torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
-
assert torch.allclose(
torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3
), "outputs from triton and torch are not matched"
diff --git a/tests/test_infer_ops/triton/test_token_attn_fwd.py b/tests/test_infer_ops/triton/test_token_attn_fwd.py
index a7fc3d29b77a..4ee1a5fb1234 100644
--- a/tests/test_infer_ops/triton/test_token_attn_fwd.py
+++ b/tests/test_infer_ops/triton/test_token_attn_fwd.py
@@ -4,11 +4,20 @@
try:
from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd
+
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
print("please install triton from https://github.com/openai/triton")
+
+import importlib.util
+
+HAS_LIGHTLLM_KERNEL = True
+
+if importlib.util.find_spec("lightllm") is None:
+ HAS_LIGHTLLM_KERNEL = False
+
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")
@@ -25,7 +34,8 @@ def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
@pytest.mark.skipif(
- not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
+ not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_LIGHTLLM_KERNEL,
+ reason="triton requires cuda version to be higher than 11.4 or not install lightllm",
)
def test():
Z, head_num, seq_len, head_dim = 22, 112 // 8, 2048, 128
diff --git a/tests/test_legacy/test_utils/test_memory.py b/tests/test_legacy/test_utils/test_memory.py
index 9416ac86e325..9df7cf75aae5 100644
--- a/tests/test_legacy/test_utils/test_memory.py
+++ b/tests/test_legacy/test_utils/test_memory.py
@@ -3,7 +3,7 @@
import colossalai
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.testing import spawn
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
diff --git a/tests/test_moe/moe_utils.py b/tests/test_moe/moe_utils.py
new file mode 100644
index 000000000000..721a4796abfd
--- /dev/null
+++ b/tests/test_moe/moe_utils.py
@@ -0,0 +1,90 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.legacy.engine.gradient_handler._base_gradient_handler import BaseGradientHandler
+from colossalai.legacy.engine.gradient_handler.utils import bucket_allreduce
+from colossalai.legacy.registry import GRADIENT_HANDLER
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import get_moe_epsize_param_dict
+
+
+class MoeModel(nn.Module):
+ def __init__(self, enable_load_balance: bool = False):
+ class TestSubModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.moe = SparseMLP(
+ num_experts=8, hidden_size=16, intermediate_size=32, enable_load_balance=enable_load_balance
+ )
+ self.proj = nn.Linear(16, 4)
+
+ def forward(self, x):
+ x = self.moe(x)
+ x = self.proj(x)
+ return x
+
+ super().__init__()
+ self.test_embed = nn.Linear(4, 16)
+ self.test_transform = TestSubModule()
+
+ def forward(self, x):
+ MOE_MANAGER.reset_loss()
+
+ x = self.test_embed(x)
+ x = self.test_transform(x)
+
+ return x
+
+
+@GRADIENT_HANDLER.register_module
+class MoeGradientHandler(BaseGradientHandler):
+ """A helper class to handle all-reduce operations in a data parallel group and
+ moe model parallel. A all-reduce collective communication will be operated in
+ :func:`handle_gradient` among a data parallel group.
+ For better performance, it bucketizes the gradients of all parameters that are
+ the same type to improve the efficiency of communication.
+
+ Args:
+ model (Module): Model where the gradients accumulate.
+ optimizer (Optimizer): Optimizer for updating the parameters.
+ """
+
+ def __init__(self, model, optimizer=None):
+ super().__init__(model, optimizer)
+
+ def handle_gradient(self):
+ """A method running an all-reduce operation in a data parallel group.
+ Then running an all-reduce operation for all parameters in experts
+ across moe model parallel group
+ """
+ if dist.get_world_size() > 1:
+ epsize_param_dict = get_moe_epsize_param_dict(self._model)
+
+ # epsize is 1, indicating the params are replicated among processes in data parallelism
+ # use the ParallelMode.DATA to get data parallel group
+ # reduce gradients for all parameters in data parallelism
+ if 1 in epsize_param_dict:
+ bucket_allreduce(param_list=epsize_param_dict[1])
+
+ for ep_size in epsize_param_dict:
+ if ep_size != 1 and ep_size != MOE_MANAGER.world_size:
+ bucket_allreduce(
+ param_list=epsize_param_dict[ep_size], group=MOE_MANAGER.parallel_info_dict[ep_size].dp_group
+ )
+
+
+def assert_not_equal_in_group(tensor, process_group=None):
+ # all gather tensors from different ranks
+ world_size = dist.get_world_size(process_group)
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+ dist.all_gather(tensor_list, tensor, group=process_group)
+
+ # check if they are equal one by one
+ for i in range(world_size - 1):
+ a = tensor_list[i]
+ b = tensor_list[i + 1]
+        assert not torch.allclose(a, b), (
+            f"expected tensors on rank {i} and {i + 1} not to be equal "
+            f"but they are, {a} vs {b}"
+        )
diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py
index 8742e5f41136..3fac624729db 100644
--- a/tests/test_moe/test_grad_handler.py
+++ b/tests/test_moe/test_grad_handler.py
@@ -4,40 +4,48 @@
import torch.nn as nn
import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
-from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, UniformNoiseGenerator
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
+from tests.test_moe.moe_utils import MoeGradientHandler
BATCH_SIZE = 4
DIM = 16
-CONFIG = dict()
def run_test(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- expert_module = nn.Linear
- expert_factor = dict(in_features=DIM, out_features=DIM, device=get_current_device())
-
- MOE_CONTEXT.setup(42) # MOE initialization
- noisy_func = UniformNoiseGenerator()
- router = Top1Router(noisy_func=noisy_func)
+ colossalai.launch(
+ config=dict(),
+ rank=rank,
+ world_size=world_size,
+ host="localhost",
+ port=port,
+ backend="nccl",
+ )
+
+ MOE_MANAGER.setup(parallel="EP") # MOE initialization
num_experts_list = [1, 2, 4]
layer_list = []
for num_experts in num_experts_list:
- exp = Experts(expert_module, num_experts, **expert_factor)
- moe_layer = MoeLayer(DIM, num_experts, router, exp)
+ moe_layer = SparseMLP(
+ hidden_size=DIM,
+ intermediate_size=DIM * 4,
+ num_experts=num_experts,
+ router_top_k=1,
+ router_noisy_policy="Jitter",
+ )
layer_list.append(moe_layer)
model = nn.ModuleList(layer_list)
model = model.to(get_current_device())
- sync_moe_model_param(model)
-
- dist_dict = MOE_CONTEXT.parallel_info_dict
- assert_equal_in_group(layer_list[0].experts.experts[0].weight.data, dist_dict[1].dp_group)
- assert_equal_in_group(layer_list[1].experts.experts[0].weight.data, dist_dict[2].dp_group)
+ dist_dict = MOE_MANAGER.parallel_info_dict
+ assert_equal_in_group(layer_list[0].experts.wi.data, dist_dict[1].dp_group)
+ assert_equal_in_group(layer_list[0].experts.wo.data, dist_dict[1].dp_group)
+ assert_equal_in_group(layer_list[1].experts.wi.data, dist_dict[2].dp_group)
+ assert_equal_in_group(layer_list[1].experts.wo.data, dist_dict[2].dp_group)
+ assert_equal_in_group(layer_list[2].experts.wi.data, dist_dict[4].dp_group)
+ assert_equal_in_group(layer_list[2].experts.wo.data, dist_dict[4].dp_group)
# MoE model synchronization passed
grad_handler = MoeGradientHandler(model, 0)
@@ -47,17 +55,18 @@ def run_test(rank, world_size, port):
data = torch.randn(BATCH_SIZE, DIM, device=get_current_device())
grad = torch.randn_like(data)
- MOE_CONTEXT.reset_loss()
+ MOE_MANAGER.reset_loss()
for layer in layer_list:
- data, _ = layer(data)
+ data = layer(data)
data.backward(grad)
grad_handler.handle_gradient()
- assert_equal_in_group(layer_list[0].experts.experts[0].weight.grad, dist_dict[1].dp_group)
- assert_equal_in_group(layer_list[0].experts.experts[0].bias.grad, dist_dict[1].dp_group)
-
- assert_equal_in_group(layer_list[1].experts.experts[0].weight.grad, dist_dict[2].dp_group)
- assert_equal_in_group(layer_list[1].experts.experts[0].bias.grad, dist_dict[2].dp_group)
+ assert_equal_in_group(layer_list[0].experts.wi.grad, dist_dict[1].dp_group)
+ assert_equal_in_group(layer_list[0].experts.wo.grad, dist_dict[1].dp_group)
+ assert_equal_in_group(layer_list[1].experts.wi.grad, dist_dict[2].dp_group)
+ assert_equal_in_group(layer_list[1].experts.wo.grad, dist_dict[2].dp_group)
+ assert_equal_in_group(layer_list[2].experts.wi.grad, dist_dict[4].dp_group)
+ assert_equal_in_group(layer_list[2].experts.wo.grad, dist_dict[4].dp_group)
# MoE grad handler test passed
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index 7a9c551d679d..255ec7444a2c 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -1,49 +1,49 @@
import pytest
import torch
-import torch.nn as nn
+import torch.distributed as dist
import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.legacy.context import ParallelMode
-from colossalai.legacy.core import global_context as gpc
-from colossalai.nn.layer.moe import Experts, MoeLayer, Top1Router, Top2Router
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
-BATCH_SIZE = 16
+BATCH_SIZE = 4
NUM_EXPERTS = 4
-CONFIG = dict()
def check_equal(tensor_a, tensor_b, atol=1e-06):
assert torch.allclose(tensor_a, tensor_b, rtol=0, atol=atol) is True
-def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, router=Top2Router):
+def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.float32, topk=1):
    # TF32 is not needed here, since it introduces absolute error in the results
torch.backends.cuda.matmul.allow_tf32 = False
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- local_rank = gpc.get_local_rank(ParallelMode.GLOBAL)
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ local_rank = dist.get_rank()
- MOE_CONTEXT.setup(42) # MOE environment initialization
- MOE_CONTEXT.reset_loss()
+ MOE_MANAGER.setup(parallel="EP") # MOE environment initialization
+ MOE_MANAGER.reset_loss()
    torch.manual_seed(rs + local_rank)    # give each process a different random seed
# get randomized data
tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
- expert_module = nn.Linear
- expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
- expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
- layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+ layer = SparseMLP(
+ hidden_size=hidden_size,
+ intermediate_size=hidden_size * 2,
+ num_experts=NUM_EXPERTS,
+ router_top_k=topk,
+ router_capacity_factor_train=1.0,
+ )
layer = layer.to(get_current_device())
if data_type == torch.float16:
layer = layer.half()
# use matrix multiplication instead of COL_MOE_KERNEL in MOE dispatch and combine
- layer.use_kernel = False
- old_out, _ = layer(tokens)
+ layer.enable_kernel = False
+ old_out = layer(tokens)
ech = old_out.shape
grad = torch.randn(ech, device=get_current_device())
old_out.backward(grad) # get gradient
@@ -56,8 +56,8 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
tokens.grad.zero_()
layer.gate_weight.grad.zero_()
- layer.use_kernel = True
- new_out, _ = layer(tokens) # get outputs through colossal kernel
+ layer.enable_kernel = True
+ new_out = layer(tokens) # get outputs through colossal kernel
if data_type == torch.float32:
check_equal(old_out, new_out)
@@ -86,11 +86,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
@pytest.mark.parametrize("rs", [131])
@pytest.mark.parametrize("hidden_size", [32, 144])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16])
-@pytest.mark.parametrize("router", [Top1Router, Top2Router])
+@pytest.mark.parametrize("topk", [1, 2])
@rerun_if_address_is_in_use()
-def test_moe_kernel(rs, hidden_size, data_type, router):
- spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, router=router)
+def test_moe_kernel(rs, hidden_size, data_type, topk):
+ spawn(run_routing, 4, rs=rs, hidden_size=hidden_size, data_type=data_type, topk=topk)
if __name__ == "__main__":
- test_moe_kernel(2, 256, torch.float16, Top2Router)
+ test_moe_kernel(2, 256, torch.float16, 2)
diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py
index b7024f32b1cf..bd1103df30d3 100644
--- a/tests/test_moe/test_moe_checkpoint.py
+++ b/tests/test_moe/test_moe_checkpoint.py
@@ -1,50 +1,219 @@
+import importlib
import os
+import shutil
+import sys
import pytest
import torch
import torch.distributed as dist
+from transformers.models.llama import LlamaConfig
import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.nn.layer.moe import load_moe_model, save_moe_model
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.booster import Booster
+from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.testing import DummyDataloader, check_state_dict_equal, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
-from tests.test_moe.test_moe_zero_init import MoeModel
-from tests.test_zero.test_legacy.common import CONFIG
+sys.path.append(
+ os.path.join(
+ os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
+ "examples/language/openmoe",
+ )
+)
-def exam_moe_checkpoint():
- with ColoInitContext(device=get_current_device()):
- model = MoeModel(checkpoint=True)
- save_moe_model(model, "temp_path.pth")
+OpenMoeForCausalLM = importlib.import_module("model.modeling_openmoe").OpenMoeForCausalLM
+set_openmoe_args = importlib.import_module("model.modeling_openmoe").set_openmoe_args
+OpenMoeForCausalLMPolicy = importlib.import_module("model.openmoe_policy").OpenMoeForCausalLMPolicy
- with ColoInitContext(device=get_current_device()):
- other_model = MoeModel(checkpoint=True)
- load_moe_model(other_model, "temp_path.pth")
- state_0 = model.state_dict()
- state_1 = other_model.state_dict()
- for k, v in state_0.items():
- u = state_1.get(k)
- assert torch.equal(u.data, v.data)
+def data_gen_fn(batch_size: int = 2, max_length: int = 4, vocab_size: int = 20):
+ input_ids = torch.randint(0, vocab_size, (batch_size, max_length), device=get_current_device())
+ attention_mask = torch.ones_like(input_ids)
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "labels": input_ids,
+ }
+
+
+def run_fwd_bwd(
+ model, data, label, criterion, optimizer, enable_autocast=False, pipeline=False, booster=None, plugin=None
+):
+ model.train()
+ if pipeline:
+ train_dataloader_iter = DummyDataloader(data_gen_fn, length=1)
+ is_pp_last_stage = booster.plugin.stage_manager.is_last_stage()
+ y = booster.execute_pipeline(
+ train_dataloader_iter,
+ model,
+ lambda x, y: x.loss,
+ optimizer,
+ return_loss=True,
+ return_outputs=True,
+ )
+ # Backward and optimize
+ if is_pp_last_stage:
+ loss = y["loss"]
+ else:
+ if criterion:
+ y = model(data).logits
+ loss = criterion(y)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if optimizer is not None:
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
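+
+# Note: with pipeline parallelism the loss is materialized only on the last
+# stage (hence the is_pp_last_stage branch above); in the non-pipeline path the
+# loss is computed locally and backpropagated through the optimizer if given.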
+
+
+def get_config():
+ config = LlamaConfig(
+ vocab_size=300,
+ hidden_size=16,
+ intermediate_size=32,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ head_dim=4,
+ dropout_rate=0.0,
+ hidden_act="swiglu",
+ )
+ set_openmoe_args(config, num_experts=8, moe_layer_interval=1)
+ return config
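+
+# The config above builds a deliberately tiny OpenMoE model for testing:
+# 2 hidden layers, each an MoE layer (moe_layer_interval=1) with 8 experts.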
+
+
+def get_model(parallel):
+ config = get_config()
+ model = OpenMoeForCausalLM(config)
+ optim = torch.optim.Adam(model.parameters())
+
+    if parallel is None:
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=1,
+ zero_stage=2,
+ custom_policy=OpenMoeForCausalLMPolicy(),
+ )
+ elif parallel == "ep":
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=1,
+ zero_stage=2,
+ custom_policy=OpenMoeForCausalLMPolicy(),
+ )
+ elif parallel == "ep_zero":
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=1,
+ zero_stage=2,
+ extra_dp_size=2,
+ custom_policy=OpenMoeForCausalLMPolicy(),
+ )
+ elif parallel == "hybrid":
+ plugin = MoeHybridParallelPlugin(
+ precision="bf16",
+ tp_size=1,
+ pp_size=2,
+ zero_stage=1,
+ microbatch_size=1,
+ custom_policy=OpenMoeForCausalLMPolicy(),
+ )
+ booster = Booster(plugin=plugin)
+ model, optim, _, _, _ = booster.boost(model=model, optimizer=optim)
+ return model, booster, optim
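+
+# Note: "ep" currently maps to the same plugin arguments as parallel=None
+# (pure ZeRO-2); "ep_zero" adds an extra data parallel group of size 2, and
+# "hybrid" switches to a 2-stage pipeline with ZeRO-1 and micro-batches.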
+
+
+def _test_moe_checkpoint(rank, parallel):
+    if parallel is None:
+ MOE_MANAGER.setup(
+ parallel=None,
+ )
+ elif parallel == "ep":
+ MOE_MANAGER.setup(
+ parallel="EP",
+ )
+ elif parallel == "ep_zero":
+ MOE_MANAGER.setup(
+ parallel="EP",
+ max_ep_size=2,
+ )
+ elif parallel == "hybrid":
+ MOE_MANAGER.setup(
+ parallel="EP",
+ mode="fixed",
+ fixed_dp_size=1,
+ fixed_ep_size=2,
+ fixed_pp_size=2,
+ )
+ model1, booster1, optim1 = get_model(parallel)
+ model2, booster2, optim2 = get_model(parallel)
+ model3, booster3, optim3 = get_model(parallel)
+
+ # param ckpt
+ # shard
+ booster1.save_model(model1, "./tmp_ckpt1", shard=True, size_per_shard=1)
+ booster2.load_model(model2, "./tmp_ckpt1")
+ # unshard
+ booster1.save_model(model1, "./tmp_ckpt1.pth")
+ booster3.load_model(model3, "./tmp_ckpt1.pth")
+ # check
+ check_state_dict_equal(model1.state_dict(), model2.state_dict(), False)
+ check_state_dict_equal(model1.state_dict(), model3.state_dict(), False)
+
+ # optim ckpt
+ criterion = lambda x: x.mean()
+ data = torch.randint(0, 4, (2, 4)).cuda()
+ label = torch.randint(0, 4, (2,)).cuda()
+ if parallel == "hybrid":
+ kwargs = {"pipeline": True, "booster": booster1, "plugin": booster1.plugin}
+ else:
+ kwargs = {}
+ run_fwd_bwd(model1, data, label, criterion, optim1, **kwargs)
+ optim1.step()
+ optim1.zero_grad()
+ # shard
+ booster1.save_optimizer(optim1, "./tmp_ckpt2", shard=True, size_per_shard=1)
+ dist.barrier()
+ booster2.load_optimizer(optim2, "./tmp_ckpt2")
+ # unshard
+ booster1.save_optimizer(optim1, "./tmp_ckpt2.pth")
+ booster3.load_optimizer(optim3, "./tmp_ckpt2.pth")
+ # check
+ check_state_dict_equal(optim1.optim.state_dict(), optim2.optim.state_dict(), False)
+ check_state_dict_equal(optim1.optim.state_dict(), optim3.optim.state_dict(), False)
if dist.get_rank() == 0:
- os.remove("temp_path.pth")
+ shutil.rmtree("./tmp_ckpt1")
+ shutil.rmtree("./tmp_ckpt2")
+ os.remove("./tmp_ckpt1.pth")
+ os.remove("./tmp_ckpt2.pth")
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_CONTEXT.setup(seed=42)
- exam_moe_checkpoint()
+def _run_dist(rank, world_size, port, parallel):
+ colossalai.launch(
+ config=dict(),
+ rank=rank,
+ world_size=world_size,
+ host="localhost",
+ port=port,
+ backend="nccl",
+ )
+ _test_moe_checkpoint(rank, parallel)
@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("world_size", [4])
+@pytest.mark.parametrize("parallel", [None, "ep", "ep_zero", "hybrid"])
@rerun_if_address_is_in_use()
-def test_moe_checkpoint(world_size):
- spawn(_run_dist)
+def test_moe_checkpoint(world_size, parallel):
+ spawn(_run_dist, world_size, parallel=parallel)
if __name__ == "__main__":
- test_moe_checkpoint(world_size=4)
+ test_moe_checkpoint(world_size=4, parallel="hybrid")
diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py
deleted file mode 100644
index 488573b733b1..000000000000
--- a/tests/test_moe/test_moe_colo_init.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import pytest
-import torch
-import torch.distributed as dist
-
-import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.tensor import ColoParameter
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
-from tests.test_moe.test_moe_zero_init import MoeModel
-from tests.test_zero.test_legacy.common import CONFIG
-
-
-@parameterize("init_device_type", ["cpu", "cuda"])
-def exam_moe_colo_init(init_device_type):
- world_size = dist.get_world_size()
-
- if init_device_type == "cuda":
- init_device = get_current_device()
- elif init_device_type == "cpu":
- init_device = torch.device("cpu")
- else:
- raise NotImplementedError("Unknown device found.")
-
- with ColoInitContext(device=init_device):
- model = MoeModel(checkpoint=True)
-
- for name, param in model.named_parameters():
- assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name)
-
- if hasattr(param, "moe_info"):
- param.set_process_group(param.moe_info.pg)
-
- if hasattr(param, "moe_info"):
- assert param.process_group.dp_world_size() == param.moe_info.dp_size
- else:
- assert param.process_group.dp_world_size() == world_size
-
-
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_CONTEXT.setup(seed=42)
- exam_moe_colo_init()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
-@rerun_if_address_is_in_use()
-def test_moe_colo_init(world_size):
- spawn(_run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_moe_colo_init(world_size=4)
diff --git a/tests/test_moe/test_moe_ep_tp.py b/tests/test_moe/test_moe_ep_tp.py
new file mode 100644
index 000000000000..f87d4c792155
--- /dev/null
+++ b/tests/test_moe/test_moe_ep_tp.py
@@ -0,0 +1,232 @@
+import os
+import warnings
+from typing import Dict
+
+import pytest
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.moe import SparseMLP
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
+from colossalai.tensor.moe_tensor.api import get_ep_group, get_ep_rank, get_ep_size, is_moe_tensor
+from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
+from colossalai.utils import get_current_device
+from tests.test_moe.moe_utils import MoeGradientHandler
+
+
+def sync_tp_from_local(tp_model: SparseMLP, local_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+ """Sync the parameters of tp model from local model
+
+ Args:
+ tp_model (MoeModule)
+ local_model (MoeModule)
+ """
+ for (tp_name, tp_param), (local_name, local_param) in \
+ zip(tp_model.named_parameters(), local_model.named_parameters()):
+ assert tp_name == local_name
+ if not is_moe_tensor(tp_param):
+ if assert_grad_flag:
+ assert torch.allclose(tp_param, local_param)
+ assert torch.allclose(tp_param.grad, local_param.grad)
+ else:
+ tp_param.data.copy_(local_param.data)
+ continue
+
+ tp_rank = get_ep_rank(tp_param)
+ tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape, local_param.shape)) if d1 != d2][0]
+ tp_slice = [slice(None)] * tp_dim + [
+ slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
+ ]
+
+ if assert_grad_flag:
+ assert torch.allclose(tp_param, local_param[tuple(tp_slice)])
+ assert torch.allclose(tp_param.grad, local_param.grad[tuple(tp_slice)])
+ else:
+ tp_param.data.copy_(local_param[tuple(tp_slice)].data)
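+
+# Slicing sketch (illustrative shapes): if the local wi weight is
+# (num_experts, hidden, intermediate) and TP shards the intermediate dim over
+# 2 ranks, tp_dim resolves to the first mismatching dim and rank 1 copies the
+# slice [intermediate // 2 : intermediate] along that dim.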
+
+
+def sync_tp_from_ep(tp_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+ """Sync the parameters of tp model from ep model
+
+ Args:
+ tp_model (MoeModule)
+ ep_model (MoeModule)
+ """
+ for (tp_name, tp_param), (ep_name, ep_param) in \
+ zip(tp_model.named_parameters(), ep_model.named_parameters()):
+ assert tp_name == ep_name
+ if not is_moe_tensor(tp_param):
+ if assert_grad_flag:
+ assert torch.allclose(tp_param, ep_param)
+ assert torch.allclose(tp_param.grad, ep_param.grad)
+ else:
+ tp_param.data.copy_(ep_param.data)
+ continue
+
+ # gather param from ep model
+ param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+ all_param = torch.cat(param_list, dim=0)
+ if assert_grad_flag:
+ grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+ all_grad = torch.cat(grad_list, dim=0)
+
+ # get tp param
+ tp_dim = [i for i, (d1, d2) in enumerate(zip(tp_param.shape[1:], all_param.shape[1:])) if d1 != d2][0] + 1
+ tp_rank = get_ep_rank(tp_param)
+ tp_slice = [slice(None)] * tp_dim + [
+ slice(tp_param.shape[tp_dim] * tp_rank, tp_param.shape[tp_dim] * (tp_rank + 1))
+ ]
+        new_tp_param = all_param[tuple(tp_slice)]
+        if assert_grad_flag:
+            new_grad = all_grad[tuple(tp_slice)]
+            assert torch.allclose(tp_param, new_tp_param)
+            assert torch.allclose(tp_param.grad, new_grad)
+        else:
+            tp_param.data.copy_(new_tp_param.data)
+
+
+def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_flag: bool = False) -> None:
+ """Sync the parameters of tp model from ep model
+
+ Args:
+ local_model (MoeModule)
+ ep_model (MoeModule)
+ """
+ for (local_name, local_param), (ep_name, ep_param) in \
+ zip(local_model.named_parameters(), ep_model.named_parameters()):
+ assert local_name == ep_name
+ if "experts" not in local_name:
+ if assert_grad_flag:
+ assert torch.allclose(local_param, ep_param)
+ assert torch.allclose(local_param.grad, ep_param.grad)
+ else:
+ local_param.data.copy_(ep_param.data)
+ continue
+
+ # gather param from ep model
+ param_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(param_list, ep_param, group=get_ep_group(ep_param))
+ all_param = torch.cat(param_list, dim=0)
+ if assert_grad_flag:
+ grad_list = [torch.zeros_like(ep_param) for _ in range(get_ep_size(ep_param))]
+ dist.all_gather(grad_list, ep_param.grad, group=get_ep_group(ep_param))
+ all_grad = torch.cat(grad_list, dim=0)
+
+ if assert_grad_flag:
+ assert torch.allclose(local_param, all_param)
+ assert torch.allclose(local_param.grad, all_grad)
+ else:
+ local_param.data.copy_(all_param.data)
+
+
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
+ assert batch_size % world_size == 0
+
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+
+    MOE_MANAGER.__init__()    # re-initialize to reset global MoE state before each setup
+ MOE_MANAGER.setup(parallel=None)
+ local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel="EP")
+ enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+ if enable_hierarchical_comm:
+ os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
+ ep_model = SparseMLP(
+ num_experts=num_experts,
+ hidden_size=dim,
+ intermediate_size=dim * 2,
+ enable_hierarchical_comm=enable_hierarchical_comm
+ )
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel="TP")
+ tp_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
+ ep_model = ep_model.to(get_current_device())
+ tp_model = tp_model.to(get_current_device())
+ local_model = local_model.to(get_current_device())
+
+ # sync ep param
+ sync_moe_model_param(ep_model)
+ dist_dict = MOE_MANAGER.parallel_info_dict
+ assert_equal_in_group(ep_model.experts.wi.data, dist_dict[world_size].dp_group)
+ assert_equal_in_group(ep_model.experts.wo.data, dist_dict[world_size].dp_group)
+ ep_grad_handler = MoeGradientHandler(ep_model)
+ # sync local param
+ sync_local_from_ep(local_model, ep_model)
+ # sync tp param
+ sync_tp_from_ep(tp_model, ep_model)
+ tp_grad_handler = MoeGradientHandler(tp_model)
+
+ rank = dist.get_rank()
+ input_data = torch.randn(batch_size, dim, device=get_current_device())
+ micro_batch_size = batch_size // world_size
+ index = rank * micro_batch_size
+ # NOTE: ep & tp takes in sharded data for each process
+ shard_data = input_data.detach()[index:index + micro_batch_size]
+
+ out_local = local_model(input_data)
+ MOE_MANAGER.reset_loss()
+ out_tp = tp_model(shard_data)
+ MOE_MANAGER.reset_loss()
+ out_ep = ep_model(shard_data)
+ MOE_MANAGER.reset_loss()
+
+ assert torch.allclose(out_tp, out_ep, atol=1e-6), \
+ f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_tp - out_ep))}"
+ try:
+ out_local_slice = out_local[index:index + micro_batch_size]
+ assert torch.allclose(out_ep, out_local_slice, atol=1e-6), \
+ f"Rank {rank} failed, max diff: {torch.max(torch.abs(out_ep - out_local_slice))}"
+    except AssertionError:
+        """
+        e.g., in the local model, tokens = 4, capacity = 2, experts = 2, topk = 1;
+        the router yields [01] --> [0], [23] --> [1], which is valid since the capacity is 2.
+        However, in ep mode there are 2 separate routers dealing with sharded data,
+        and for each of them the capacity is only 1.
+        Assume router 0 handles tokens [01] and router 1 handles tokens [23]:
+        router 0 may yield [0] --> [0] or [1] --> [0], but not both,
+        and the same holds for router 1, so some tokens are dropped due to sharding.
+        """
+ warnings.warn(
+ "EP & TP may result in different behavior from local model. "
+ "Please check the comments for details."
+ )
+
+ out_local.mean().backward()
+ out_tp.mean().backward()
+ tp_grad_handler.handle_gradient()
+ out_ep.mean().backward()
+ ep_grad_handler.handle_gradient()
+
+ assert_equal_in_group(ep_model.experts.wi.grad, dist_dict[world_size].dp_group)
+ assert_equal_in_group(ep_model.experts.wo.grad, dist_dict[world_size].dp_group)
+ sync_tp_from_ep(tp_model, ep_model, assert_grad_flag=True)
+ try:
+ sync_local_from_ep(local_model, ep_model, assert_grad_flag=True)
+    except AssertionError:
+ warnings.warn(
+ "EP & TP may result in different behavior from local model. "
+ "Please check the comments for details."
+ )
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("num_experts", [4, 64])
+@pytest.mark.parametrize("batch_size", [16])
+@pytest.mark.parametrize("dim", [64])
+@pytest.mark.parametrize("config", [
+ {"enable_hierarchical_comm": False},
+ {"enable_hierarchical_comm": True},
+])
+@rerun_if_address_is_in_use()
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+ spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)
+
+
+if __name__ == "__main__":
+ test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py
index 300fb6c99b7b..95c0e715dc34 100644
--- a/tests/test_moe/test_moe_group.py
+++ b/tests/test_moe/test_moe_group.py
@@ -3,66 +3,80 @@
import torch.nn as nn
import colossalai
-from colossalai.context.moe_context import MOE_CONTEXT
-from colossalai.nn.layer.moe import Experts
+from colossalai.moe.experts import MLPExperts
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.moe.utils import sync_moe_model_param
from colossalai.testing import assert_equal_in_group, rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
-from colossalai.utils.moe import sync_moe_model_param
-D_MODEL = 4
-D_FF = 8
-CONFIG = dict()
-
-
-def run_test(rank, world_size, port):
- world_size = 4
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- expert_module = nn.Linear
- expert_factor = dict(in_features=D_MODEL, out_features=D_FF, device=get_current_device())
-
- MOE_CONTEXT.setup(42) # MOE environment initialization
- exp0 = Experts(expert_module, 1, **expert_factor)
- exp1 = Experts(expert_module, 2, **expert_factor)
- exp2 = Experts(expert_module, 4, **expert_factor)
- exp3 = Experts(expert_module, 8, **expert_factor)
-
- assert exp0.num_local_experts == 1
- assert exp1.num_local_experts == 1
- assert exp2.num_local_experts == 1
- assert exp3.num_local_experts == 2
- # experts deployment passed
-
- parallel_info_dict = MOE_CONTEXT.parallel_info_dict
+HIDDEN_SIZE = 4
+INTERMEDIATE_SIZE = 8
+
+
+def run_moe_init(expert_parallel):
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel=expert_parallel)
+ expert_args = dict(
+ hidden_size=HIDDEN_SIZE,
+ intermediate_size=INTERMEDIATE_SIZE,
+ expert_parallel=expert_parallel,
+ )
+ exp0 = MLPExperts(1, **expert_args)
+ exp1 = MLPExperts(2, **expert_args)
+ exp2 = MLPExperts(4, **expert_args)
+
+ if expert_parallel == "EP":
+ assert exp0.num_local_experts == 1
+ assert exp1.num_local_experts == 1
+ assert exp2.num_local_experts == 2
+ else:
+ assert exp0.num_local_experts == 1
+ assert exp1.num_local_experts == 2
+ assert exp2.num_local_experts == 4
+
+ parallel_info_dict = MOE_MANAGER.parallel_info_dict
rank = dist.get_rank()
- assert len(parallel_info_dict) == 3
- assert dist.get_rank(parallel_info_dict[4].ep_group) == rank
+    # check group creation
+ assert len(parallel_info_dict) == 2
assert dist.get_rank(parallel_info_dict[2].ep_group) == rank % 2
assert dist.get_rank(parallel_info_dict[1].ep_group) == 0
- assert dist.get_rank(parallel_info_dict[4].dp_group) == 0
assert dist.get_rank(parallel_info_dict[2].dp_group) == rank // 2
assert dist.get_rank(parallel_info_dict[1].dp_group) == rank
- # group creation passed
- model = nn.ModuleList([exp0, exp1, exp2, exp3])
+ model = nn.ModuleList([exp0, exp1, exp2])
model = model.to(get_current_device())
sync_moe_model_param(model)
- assert_equal_in_group(exp0.experts[0].weight.data, parallel_info_dict[1].dp_group)
- assert_equal_in_group(exp0.experts[0].bias.data, parallel_info_dict[1].dp_group)
# MOE experts layout success when ep_size = 1
+ assert_equal_in_group(exp0.wi.data, parallel_info_dict[1].dp_group)
+ assert_equal_in_group(exp0.wo.data, parallel_info_dict[1].dp_group)
- assert_equal_in_group(exp1.experts[0].weight.data, parallel_info_dict[2].dp_group)
- assert_equal_in_group(exp1.experts[0].bias.data, parallel_info_dict[2].dp_group)
# MOE experts layout success when ep_size = 2
+ assert_equal_in_group(exp1.wi.data, parallel_info_dict[2].dp_group)
+ assert_equal_in_group(exp1.wo.data, parallel_info_dict[2].dp_group)
+
+
+def _run_test(rank, world_size, port, expert_parallel):
+ colossalai.launch(
+ config=dict(),
+ rank=rank,
+ world_size=world_size,
+ host="localhost",
+ port=port,
+ backend="nccl",
+ )
+ run_moe_init(expert_parallel)
@pytest.mark.dist
+@pytest.mark.parametrize("expert_parallel", ["EP", "TP"])
@rerun_if_address_is_in_use()
-def test_moe_initialization():
- spawn(run_test, 4)
+def test_moe_initialization(expert_parallel):
+ spawn(_run_test, 2, expert_parallel=expert_parallel)
if __name__ == "__main__":
- test_moe_initialization()
+ test_moe_initialization("EP")
+ test_moe_initialization("TP")
diff --git a/tests/test_moe/test_moe_hybrid_zero.py b/tests/test_moe/test_moe_hybrid_zero.py
new file mode 100644
index 000000000000..7ada4090fb47
--- /dev/null
+++ b/tests/test_moe/test_moe_hybrid_zero.py
@@ -0,0 +1,97 @@
+import pytest
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from tests.test_moe.moe_utils import MoeModel
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss / 2)
+ else:
+ loss.backward()
+ return y
+
+
+def run_zero_optim_test(local_rank, world_size, stage=1):
+ criterion = torch.nn.CrossEntropyLoss()
+ data = torch.randn(16, 4).cuda()
+ label = torch.randint(0, 4, (16,)).cuda()
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel=None)
+ torch_model = MoeModel()
+ torch_optimizer = torch.optim.Adam(torch_model.parameters())
+ torch_model = torch_model.cuda()
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(max_ep_size=2, use_ep_inside=False, parallel="EP")
+ zero_model = MoeModel()
+ extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
+ ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
+ ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
+ for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+ if is_moe_tensor(zero_param):
+ num_expert = torch_param.data.shape[0]
+ zero_param.data.copy_(
+ torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
+ .detach()
+ .clone()
+ )
+ else:
+ zero_param.data.copy_(torch_param.data.detach().clone())
+ zero_optimizer = torch.optim.Adam(zero_model.parameters())
+ plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
+ plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
+ booster = Booster(plugin=plugin)
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ torch_optimizer.step()
+ run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ zero_optimizer.step()
+
+ for (torch_name, torch_param), (zero_name, zero_param) in zip(
+ torch_model.named_parameters(), zero_model.named_parameters()
+ ):
+ if is_moe_tensor(zero_param):
+ num_expert = torch_param.data.shape[0]
+ torch_param.data = torch_param.data[
+ ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)
+ ]
+ assert torch.allclose(
+ torch_param.data, zero_param.data, atol=1e-4
+ ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_zero_optim_test(rank, world_size, stage=1)
+ run_zero_optim_test(rank, world_size, stage=2)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_moe_zero_optim(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_moe_zero_optim(world_size=4)
diff --git a/tests/test_moe/test_moe_load_balance.py b/tests/test_moe/test_moe_load_balance.py
new file mode 100644
index 000000000000..717bb99fb830
--- /dev/null
+++ b/tests/test_moe/test_moe_load_balance.py
@@ -0,0 +1,188 @@
+import pytest
+import torch
+import torch.distributed as dist
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.moe.layers import apply_load_balance
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.tensor.moe_tensor.api import is_moe_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+
+
+def split_ddp_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+        split_grad = grad.split(grad.numel() // world_size)
+    return split_grad
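+
+# e.g. a flattened gradient of 10 elements on world_size = 4 is padded with
+# 2 zeros and split into 4 shards of 3 elements, mirroring ZeRO's partitioning.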
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
+
+
+def run_zero_optim_test(local_rank, world_size, stage=1):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(
+ parallel="EP",
+ )
+ zero_model = MoeModel(enable_load_balance=True)
+ zero_optimizer = torch.optim.Adam(zero_model.parameters())
+ plugin = LowLevelZeroPlugin(stage=stage, precision="bf16", verbose=True)
+ booster = Booster(plugin=plugin)
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel="EP")
+ torch_model = MoeModel()
+ for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+ torch_param.data.copy_(zero_param.data)
+ torch_optimizer = torch.optim.Adam(torch_model.parameters())
+ torch_model = torch_model.cuda().bfloat16()
+ grad_handler = MoeGradientHandler(torch_model)
+
+ # run to update expert load
+ data = torch.randn(16, 4).cuda().bfloat16() / 1000 / (local_rank + 1)
+ label = torch.randint(0, 4, (16,)).cuda()
+
+ # run torch model twice
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ grad_handler.handle_gradient()
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ grad_handler.handle_gradient()
+
+    # update optimizer state and record expert load in the zero model
+ run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ zero_optimizer.step()
+ zero_optimizer.zero_grad()
+ with torch.no_grad():
+ origin_out = zero_model(data)
+
+ # load balance
+ apply_load_balance(zero_model, zero_optimizer)
+
+ # run again to test
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    assert torch.allclose(origin_out, zero_out)
+
+ # assert optim
+ torch_optimizer.step()
+ torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
+ zero_optimizer.step()
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ assert torch.allclose(zero_out, torch_out, atol=3e-5), f"zero_out:{zero_out}\ntorch_out{torch_out}"
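+
+# apply_load_balance is expected to redistribute experts across EP ranks based
+# on the recorded load without changing the module's function, which is why
+# origin_out and zero_out are compared before and after it above.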
+
+
+def run_hybrid_zero_optim_test(local_rank, world_size, stage=1):
+ criterion = torch.nn.CrossEntropyLoss()
+ data = torch.randn(16, 4).cuda()
+ label = torch.randint(0, 4, (16,)).cuda()
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(parallel=None)
+ torch_model = MoeModel()
+ torch_optimizer = torch.optim.Adam(torch_model.parameters())
+ torch_model = torch_model.cuda()
+
+ MOE_MANAGER.__init__()
+ MOE_MANAGER.setup(
+ max_ep_size=2,
+ use_ep_inside=False,
+ parallel="EP",
+ )
+ zero_model = MoeModel(enable_load_balance=True)
+ extra_dp_group = MOE_MANAGER.parallel_info_dict[2].dp_group
+ ep_rank = dist.get_rank(MOE_MANAGER.parallel_info_dict[2].ep_group)
+ ep_size = MOE_MANAGER.parallel_info_dict[2].ep_size
+ for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+ if is_moe_tensor(zero_param):
+ num_expert = torch_param.data.shape[0]
+ zero_param.data.copy_(
+ torch_param.data[ep_rank * (num_expert // ep_size) : (ep_rank + 1) * (num_expert // ep_size)]
+ .detach()
+ .clone()
+ )
+ else:
+ zero_param.data.copy_(torch_param.data.detach().clone())
+ zero_optimizer = torch.optim.Adam(zero_model.parameters())
+ plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
+ plugin.zero_optim_kwargs["moe_extra_dp_process_group"] = extra_dp_group
+ booster = Booster(plugin=plugin)
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+    # run the torch model twice
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ torch_optimizer.step()
+ torch_optimizer.zero_grad()
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ torch_optimizer.step()
+
+ # run zero
+ run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ zero_optimizer.step()
+ zero_optimizer.zero_grad()
+ with torch.no_grad():
+ origin_out = zero_model(data)
+
+ # load balance
+ apply_load_balance(zero_model, zero_optimizer)
+
+ # assert out
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+    assert torch.allclose(origin_out, zero_out)
+
+ # assert optim
+ zero_optimizer.step()
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
+ torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
+ # TODO: high atol, check if bug exists
+ assert torch.allclose(zero_out, torch_out, atol=8e-4), f"zero_out:{zero_out}\ntorch_out{torch_out}"
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(
+ config=dict(),
+ rank=rank,
+ world_size=world_size,
+ host="localhost",
+ port=port,
+ backend="nccl",
+ )
+ run_zero_optim_test(rank, world_size, stage=1)
+ run_zero_optim_test(rank, world_size, stage=2)
+ run_hybrid_zero_optim_test(rank, world_size, stage=1)
+ run_hybrid_zero_optim_test(rank, world_size, stage=2)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_moe_load_balance(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_moe_load_balance(world_size=4)
diff --git a/tests/test_moe/test_moe_router.py b/tests/test_moe/test_moe_router.py
new file mode 100644
index 000000000000..7ba7fa6f6b7d
--- /dev/null
+++ b/tests/test_moe/test_moe_router.py
@@ -0,0 +1,41 @@
+import pytest
+import torch
+
+from colossalai.moe.routers import MoeRouter, Top1Router, Top2Router, TopKRouter
+
+
+@pytest.mark.parametrize(["router", "num_groups"], [
+ (Top1Router(), 1),
+ (Top2Router(), 1),
+ # (TopKRouter(num_selected_experts=3), 4),
+])
+@pytest.mark.parametrize(["batch_size", "seq_len", "num_experts"], [
+ (4, 5, 8),
+ (3, 4, 4),
+])
+def test_router_forward(router: MoeRouter, batch_size: int, seq_len: int, num_experts: int, num_groups: int):
+ x = torch.randn((batch_size * seq_len, num_experts)).cuda()
+ if num_groups > 1:
+ x = x.expand(num_groups, -1, -1)
+
+ router.train()
+ if isinstance(router, TopKRouter):
+ _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+ else:
+ _, combine_array, dispatch_mask = router(x)
+ assert combine_array.shape[:-1] == x.shape
+ assert dispatch_mask.shape[:-1] == x.shape
+ assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
+
+ router.eval()
+ if isinstance(router, TopKRouter):
+ _, combine_array, dispatch_mask = router(x, expert_capacity=2)
+ else:
+ _, combine_array, dispatch_mask = router(x)
+ assert combine_array.shape[:-1] == x.shape
+ assert dispatch_mask.shape[:-1] == x.shape
+ assert torch.all(dispatch_mask.sum(-1).sum(-1) <= router.k_value)
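+
+# Shape notes (from the assertions above): x is (tokens, experts); combine_array
+# and dispatch_mask carry one extra trailing capacity dim, and each token is
+# dispatched to at most k_value expert slots.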
+
+
+if __name__ == "__main__":
+ test_router_forward(Top2Router(), 4, 4, 4, 1)
diff --git a/tests/test_moe/test_moe_zero_fwd_bwd.py b/tests/test_moe/test_moe_zero_fwd_bwd.py
new file mode 100644
index 000000000000..f0795a4c738f
--- /dev/null
+++ b/tests/test_moe/test_moe_zero_fwd_bwd.py
@@ -0,0 +1,105 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing.random import seed_all
+from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
+
+
+def split_ddp_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+        split_grad = grad.split(grad.numel() // world_size)
+    return split_grad
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
+ optimizer.backward(loss)
+ else:
+ loss.backward()
+ return y
+
+
+def run_zero_test(local_rank, world_size, stage=1):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ zero_model = MoeModel()
+ optimizer = torch.optim.Adam(zero_model.parameters())
+ plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
+ booster = Booster(plugin=plugin)
+ zero_model, optimizer, _, _, _ = booster.boost(zero_model, optimizer)
+
+ torch_model = MoeModel()
+ for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+ torch_param.data.copy_(zero_param.data)
+ torch_model = torch_model.cuda()
+ grad_handler = MoeGradientHandler(torch_model)
+
+ # assert zero model
+ for (torch_name, torch_param), (zero_name, zero_param) in zip(
+ torch_model.named_parameters(), zero_model.module.named_parameters()
+ ):
+ assert zero_name == torch_name
+ assert torch.allclose(zero_param.data, torch_param.data)
+
+ data = torch.randn(16, 4).cuda()
+ label = torch.randint(0, 4, (16,)).cuda()
+
+ torch_out = run_fwd_bwd(torch_model, data, label, criterion, None)
+ zero_out = run_fwd_bwd(zero_model, data, label, criterion, optimizer)
+ assert torch.allclose(torch_out, zero_out)
+ grad_handler.handle_gradient()
+
+ for (zero_name, zero_param), (torch_name, torch_param) in zip(
+ zero_model.module.named_parameters(), torch_model.named_parameters()
+ ):
+ assert zero_name == torch_name
+ zero_grad_list = optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(zero_param))
+ if hasattr(zero_param, "moe_info"):
+ assert len(zero_grad_list) == 0
+ assert torch.allclose(zero_param.grad, torch_param.grad)
+ else:
+ assert len(zero_grad_list) > 0
+ torch_grad_list = split_ddp_grad(torch_param.grad, world_size)
+ if stage == 2:
+ torch_grad_list = torch_grad_list[local_rank : local_rank + 1]
+ assert len(zero_grad_list) == len(torch_grad_list)
+ for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list):
+ assert torch.allclose(zero_grad, torch_grad)
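+
+# For MoE parameters ZeRO keeps the gradient on the parameter itself (the
+# partitioned grad store stays empty), while regular parameters are flattened
+# and sharded across ranks, hence the two branches above.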
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ MOE_MANAGER.setup(parallel="EP")
+ seed_all(42 + rank)
+ run_zero_test(rank, world_size, stage=1)
+ run_zero_test(rank, world_size, stage=2)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [2])
+@rerun_if_address_is_in_use()
+def test_moe_zero_model(world_size):
+ spawn(run_dist, world_size)
+
+
+if __name__ == "__main__":
+ test_moe_zero_model(world_size=2)
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
deleted file mode 100644
index c48f9a3557ce..000000000000
--- a/tests/test_moe/test_moe_zero_init.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-
-import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.logging import get_dist_logger
-from colossalai.nn import CheckpointModule
-from colossalai.nn.layer import MoeModule
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from tests.test_zero.test_legacy.common import CONFIG
-
-
-class MoeModel(nn.Module):
- def __init__(self, checkpoint: bool = False):
- class TestSubModule(CheckpointModule):
- def __init__(self):
- super().__init__(checkpoint)
- expert_cls = nn.Linear
- expert_args_dict = dict(in_features=16, out_features=16)
- self.moe = MoeModule(
- dim_model=16, num_experts=8, use_residual=True, expert_cls=expert_cls, **expert_args_dict
- )
- self.proj = nn.Linear(16, 4)
-
- def _forward(self, x):
- x, y = self.moe(x)
- x = self.proj(x)
- return x, y
-
- super().__init__()
- self.test_embed = nn.Linear(4, 16)
- self.test_transform = TestSubModule()
-
- def forward(self, x):
- MOE_CONTEXT.reset_loss()
-
- x = self.test_embed(x)
- x, y = self.test_transform(x)
-
- MOE_CONTEXT.add_loss(y)
- return x
-
-
-@parameterize("init_device_type", ["cpu", "cuda"])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_moe_zero_init(init_device_type, shard_strategy_class):
- get_dist_logger("test_moe_zero_init")
-
- if init_device_type == "cuda":
- init_device = get_current_device()
- elif init_device_type == "cpu":
- init_device = torch.device("cpu")
- else:
- raise NotImplementedError("Unknown device found.")
-
- model_numel_tensor = torch.zeros(1, dtype=torch.int)
- with ZeroInitContext(
- target_device=init_device,
- shard_strategy=shard_strategy_class(),
- shard_param=True,
- model_numel_tensor=model_numel_tensor,
- ):
- model = MoeModel(checkpoint=True)
-
- for name, param in model.named_parameters():
- assert hasattr(param, "colo_attr")
-
- # the parameters in moe experts and its gate should not be sharded
- if ("experts" in name) or ("gate" in name) or ("residual_combine" in name):
- assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
- else:
- assert param.colo_attr.sharded_data_tensor.is_sharded
-
- # the parameters in moe experts is not replicated
- if "experts" in name:
- assert not param.colo_attr.is_replicated
- else:
- assert param.colo_attr.is_replicated
-
- if param.colo_attr.param_is_sharded:
- assert (
- param.colo_attr.data_payload.device.type == init_device.type
- ), f"{param.colo_attr.data_payload.device.type} vs. {init_device.type}"
- else:
- assert param.colo_attr.data_payload.device.type == "cuda"
-
-
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_CONTEXT.setup(seed=42)
- run_moe_zero_init()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2, 4])
-@rerun_if_address_is_in_use()
-def test_moe_zero_init(world_size):
- spawn(_run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_moe_zero_init(world_size=2)
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
deleted file mode 100644
index 724d70d77bc6..000000000000
--- a/tests/test_moe/test_moe_zero_model.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import pytest
-import torch
-
-import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
-from colossalai.nn import MoeLoss
-from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_moe.test_moe_zero_init import MoeModel
-from tests.test_zero.test_legacy.common import CONFIG, check_grads_padding, run_fwd_bwd
-
-
-@parameterize("enable_autocast", [False])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(enable_autocast, shard_strategy_class):
- shard_strategy = shard_strategy_class()
-
- get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
- _, train_dataloader, _, optimizer_class, _ = get_components_func()
- criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
-
- with ZeroInitContext(
- target_device=torch.device("cuda", torch.cuda.current_device()), shard_strategy=shard_strategy, shard_param=True
- ):
- zero_model = MoeModel(checkpoint=True)
- zero_model = ShardedModelV2(zero_model, shard_strategy)
-
- # check whether parameters are identical in ddp
- for name, p in zero_model.named_parameters():
- if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
- assert_equal_in_group(p.colo_attr.data_payload)
-
- model = MoeModel(checkpoint=True).half()
- col_model_deepcopy(zero_model, model)
- model = model.cuda()
- grad_handler = MoeGradientHandler(model)
-
- for i, (data, label) in enumerate(train_dataloader):
- if i > 5:
- break
-
- data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
- run_fwd_bwd(model, data, label, criterion, enable_autocast)
- run_fwd_bwd(zero_model, data, label, criterion, enable_autocast)
- grad_handler.handle_gradient()
-
- check_grads_padding(model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_CONTEXT.setup(seed=42)
- run_model_test()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
-@rerun_if_address_is_in_use()
-def test_moe_zero_model(world_size):
- spawn(run_dist, world_size)
-
-
-if __name__ == "__main__":
- test_moe_zero_model(world_size=2)
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index bb9822daee05..0d2e2fb1b2d8 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -2,120 +2,91 @@
import torch
import colossalai
-from colossalai.context import MOE_CONTEXT
-from colossalai.legacy.amp import convert_to_apex_amp
-from colossalai.legacy.engine.gradient_handler import MoeGradientHandler
-from colossalai.nn import MoeLoss
-from colossalai.nn.optimizer import CPUAdam
-from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
-from colossalai.zero.legacy.init_ctx import ZeroInitContext
-from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy
-from colossalai.zero.legacy.sharded_model import ShardedModelV2
-from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy
-from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2
-from colossalai.zero.low_level._utils import has_inf_or_nan
-from tests.components_to_test.registry import non_distributed_component_funcs
-from tests.test_moe.test_moe_zero_init import MoeModel
-from tests.test_zero.test_legacy.common import CONFIG, check_sharded_model_params
-
-
-def _run_step(model, optimizer, data, label, criterion, grad_handler):
- model.train()
- optimizer.zero_grad()
+from colossalai.booster import Booster
+from colossalai.booster.plugin import LowLevelZeroPlugin
+from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
+from colossalai.moe.manager import MOE_MANAGER
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+from tests.test_moe.moe_utils import MoeGradientHandler, MoeModel
- if criterion:
- y = model(data)
- loss = criterion(y, label)
- else:
- loss = model(data, label)
- loss = loss.float()
- if isinstance(model, ShardedModelV2):
+def split_ddp_grad(grad, world_size):
+ with torch.no_grad():
+ grad = grad.clone().detach().flatten()
+ padding_size = (world_size - grad.numel() % world_size) % world_size
+ if padding_size > 0:
+ grad = torch.nn.functional.pad(grad, [0, padding_size])
+        split_grad = grad.split(grad.numel() // world_size)
+    return split_grad
+
+
+def run_fwd_bwd(model, data, label, criterion, optimizer, enable_autocast=False):
+ model.train()
+ with torch.cuda.amp.autocast(enabled=enable_autocast):
+ if criterion:
+ y = model(data)
+ loss = criterion(y, label)
+ else:
+ loss = model(data, label)
+ loss = loss.float()
+
+ if isinstance(model, LowLevelZeroModel):
optimizer.backward(loss)
else:
loss.backward()
-
- if grad_handler is not None:
+ return y
+
+
+def run_zero_optim_test(local_rank, world_size, stage=1):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ zero_model = MoeModel()
+ zero_optimizer = torch.optim.Adam(zero_model.parameters())
+ plugin = LowLevelZeroPlugin(stage=stage, precision="fp32")
+ booster = Booster(plugin=plugin)
+ zero_model, zero_optimizer, _, _, _ = booster.boost(zero_model, zero_optimizer)
+
+ torch_model = MoeModel()
+ for zero_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+ torch_param.data.copy_(zero_param.data)
+ torch_optimizer = torch.optim.Adam(torch_model.parameters())
+ torch_model = torch_model.cuda()
+ grad_handler = MoeGradientHandler(torch_model)
+
+ for _ in range(2):
+ data = torch.randn(16, 4).cuda() / (local_rank + 1)
+ label = torch.randint(0, 4, (16,)).cuda()
+ run_fwd_bwd(torch_model, data, label, criterion, None)
+ run_fwd_bwd(zero_model, data, label, criterion, zero_optimizer)
grad_handler.handle_gradient()
- optimizer.step()
-
-
-@parameterize("cpu_offload", [True])
-@parameterize("use_cpuadam", [True]) # We do not use Hybrid Adam right now, since it has a little bug
-@parameterize("reuse_fp16_shard", [True, False])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def _run_test_sharded_optim_v2(
- cpu_offload, shard_strategy_class, use_cpuadam, reuse_fp16_shard, gpu_margin_mem_ratio=0.0
-):
- shard_strategy = shard_strategy_class()
- if use_cpuadam and cpu_offload is False:
- return
- MOE_CONTEXT.reset_loss()
- get_components_func = non_distributed_component_funcs.get_callable("hanging_param_model")
- _, train_dataloader, _, optimizer_class, _ = get_components_func()
- criterion = MoeLoss(aux_weight=0.01, loss_fn=torch.nn.CrossEntropyLoss)
-
- with ZeroInitContext(
- target_device=torch.device("cpu") if cpu_offload else get_current_device(),
- shard_strategy=shard_strategy,
- shard_param=True,
- ):
- zero_model = MoeModel(checkpoint=True)
-
- zero_model = ShardedModelV2(
- zero_model,
- shard_strategy,
- tensor_placement_policy="cpu" if cpu_offload else "cuda",
- reuse_fp16_shard=reuse_fp16_shard,
- )
-
- # check whether parameters are identical in ddp
- for name, p in zero_model.named_parameters():
- if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
- assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
-
- model = MoeModel(checkpoint=True).half()
- col_model_deepcopy(zero_model, model)
- model = model.cuda().float()
-
- if use_cpuadam:
- optimizer_class = CPUAdam
- optim = optimizer_class(model.parameters(), lr=1e-3)
- sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
- sharded_optim = ShardedOptimizerV2(
- zero_model, sharded_optim, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio
- )
-
- amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False)
- apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
- apex_grad_handler = MoeGradientHandler(model)
-
- for i, (data, label) in enumerate(train_dataloader):
- if i > 5:
- break
- data, label = data.cuda(), label.cuda()
- _run_step(apex_model, apex_optimizer, data, label, criterion, apex_grad_handler)
- _run_step(zero_model, sharded_optim, data, label, criterion, None)
- check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
- for param in model.parameters():
- assert not has_inf_or_nan(param)
-
-
-def _run_dist(rank, world_size, port):
- colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
- MOE_CONTEXT.setup(seed=42)
- _run_test_sharded_optim_v2()
-
-
-# use_cpuadam = True can be used with cpu_offload = False
+ torch_optimizer.step()
+ zero_optimizer.step()
+
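+ # after each optimizer step, both models should hold identical parameters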
+ for (torch_name, torch_param), (zero_name, zero_param) in zip(
+ torch_model.named_parameters(), zero_model.named_parameters()
+ ):
+ assert torch.allclose(
+ torch_param.data, zero_param.data
+ ), f"{torch_name}\ntorch_param {torch_param.data}\nzero_param {zero_param.data}"
+
+ torch_optimizer.zero_grad()
+ zero_optimizer.zero_grad()
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ MOE_MANAGER.setup(parallel="EP")
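+ # exercise both ZeRO stage 1 and stage 2 under expert parallelism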
+ run_zero_optim_test(rank, world_size, stage=1)
+ run_zero_optim_test(rank, world_size, stage=2)
+
+
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [2])
@rerun_if_address_is_in_use()
def test_moe_zero_optim(world_size):
- spawn(_run_dist, world_size)
+ spawn(run_dist, world_size)
if __name__ == "__main__":
- test_moe_zero_optim(world_size=4)
+ test_moe_zero_optim(world_size=2)
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
index 0192afc99ae4..f652d18e9494 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_amp_optimizer.py
@@ -124,57 +124,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 4,
- "enable_all_optimization": False,
- "use_lazy_init": True,
- "precision": "fp16",
- "max_norm": 5,
- "initial_scale": 1,
- },
- {
- "tp_size": 2,
- "pp_size": 1,
- "enable_all_optimization": False,
- "use_lazy_init": False,
- "precision": "fp16",
- "max_norm": 5,
- "initial_scale": 1,
- },
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
"initial_scale": 1,
},
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 4,
- "enable_all_optimization": False,
- "use_lazy_init": True,
- "precision": "bf16",
- "max_norm": 5,
- },
- {
- "tp_size": 2,
- "pp_size": 1,
- "enable_all_optimization": False,
- "use_lazy_init": False,
- "precision": "bf16",
- "max_norm": 5,
- },
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@@ -199,7 +163,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@@ -208,7 +172,7 @@ def run_test(test_config):
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
index da298f5c0be1..a749a2966fde 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_naive_optimizer.py
@@ -102,23 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 4,
- "enable_all_optimization": False,
- "use_lazy_init": True,
- "precision": "fp32",
- "max_norm": 5,
- },
- {
- "tp_size": 2,
- "pp_size": 1,
- "enable_all_optimization": False,
- "use_lazy_init": False,
- "precision": "fp32",
- "max_norm": 5,
- },
{
"tp_size": 2,
"pp_size": 2,
diff --git a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
index f1ac1de1acc9..41f06a4c3888 100644
--- a/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
+++ b/tests/test_shardformer/test_hybrid_parallel_grad_clip_norm/test_zero_optimizer.py
@@ -107,16 +107,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"num_microbatches": 4,
"zero_stage": 1,
"enable_all_optimization": False,
- "use_lazy_init": True,
- "precision": "fp16",
- "max_norm": 5,
- "initial_scale": 1,
- },
- {
- "tp_size": 2,
- "pp_size": 1,
- "zero_stage": 1,
- "enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
@@ -132,16 +122,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"max_norm": 5,
"initial_scale": 1,
},
- {
- "tp_size": 1,
- "pp_size": 2,
- "num_microbatches": 4,
- "zero_stage": 1,
- "enable_all_optimization": False,
- "use_lazy_init": True,
- "precision": "bf16",
- "max_norm": 5,
- },
{
"tp_size": 2,
"pp_size": 1,
@@ -151,15 +131,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"precision": "bf16",
"max_norm": 5,
},
- {
- "tp_size": 2,
- "pp_size": 1,
- "zero_stage": 2,
- "enable_all_optimization": False,
- "use_lazy_init": False,
- "precision": "bf16",
- "max_norm": 5,
- },
],
)
def run_test(test_config):
@@ -181,7 +152,7 @@ def run_test(test_config):
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "bf16",
"max_norm": 5,
@@ -191,7 +162,7 @@ def run_test(test_config):
"pp_size": 2,
"num_microbatches": 4,
"zero_stage": 1,
- "enable_all_optimization": False,
+ "enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp16",
"max_norm": 5,
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index 31fd58d06f77..b38793b7c388 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -34,6 +34,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bert = unwrap_model(org_model, "BertModel", "bert")
sharded_bert = unwrap_model(sharded_model, "BertModel", "bert")
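+ # LayerNorm gradients are now verified in addition to the row- and column-parallel layers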
+ norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
col_layer_for_check = ["encoder.layer[0].output.dense"]
row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
@@ -50,8 +51,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check(
bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
+
+ norm_layer_grads = get_grad_tensors_for_check(
+ bert,
+ sharded_bert,
+ norm_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
+ grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
@@ -85,6 +99,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
@parameterize(
"test_config",
[
+ {
+ "tp_size": 2,
+ "pp_size": 1,
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "precision": "fp32",
+ },
{
"tp_size": 1,
"pp_size": 2,
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 7fe791db6d5e..b70cba8b4a53 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
bloom = unwrap_model(org_model, "BloomModel", "transformer")
sharded_bloom = unwrap_model(sharded_model, "BloomModel", "transformer")
+ norm_layer_for_check = ["word_embeddings_layernorm", "h[0].input_layernorm"]
row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
col_layer_for_check = ["h[0].self_attention.dense"]
@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
col_layer_grads = get_grad_tensors_for_check(
bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
+
+ norm_layer_grads = get_grad_tensors_for_check(
+ bloom,
+ sharded_bloom,
+ norm_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
+ grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index bdf5b79fc498..29d3592bf34e 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
chatglm_model = unwrap_model(org_model, "ChatGLMModel", "transformer")
shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")
+ norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"]
col_layer_for_check = ["encoder.layers[0].self_attention.dense"]
@@ -66,8 +67,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False,
)
+
+ norm_layer_grads = get_grad_tensors_for_check(
+ chatglm_model,
+ shard_chatglm_model,
+ norm_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
+ grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 69a15166a54c..66b30641acc8 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -35,6 +35,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
gpt2 = unwrap_model(org_model, "GPT2Model", "transformer")
sharded_gpt2 = unwrap_model(sharded_model, "GPT2Model", "transformer")
+ norm_layer_for_check = ["h[0].ln_1", "h[0].ln_2"]
col_layer_for_check = ["h[0].mlp.c_fc"]
row_layer_for_check = ["wte", "h[0].mlp.c_proj"]
@@ -51,8 +52,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
row_layer_grads = get_grad_tensors_for_check(
gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
+
+ norm_layer_grads = get_grad_tensors_for_check(
+ gpt2,
+ sharded_gpt2,
+ norm_layer_for_check,
+ tp_group,
+ atol=atol,
+ rtol=rtol,
+ dim=1,
+ verbose=False,
+ )
+
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
+ grads_to_check.update(norm_layer_grads)
# optimizer executes step
org_optimizer.step()
diff --git a/tests/test_zero/test_gemini/test_chunkv2.py b/tests/test_zero/test_gemini/test_chunkv2.py
index a31c888e966d..5977c706fdd1 100644
--- a/tests/test_zero/test_gemini/test_chunkv2.py
+++ b/tests/test_zero/test_gemini/test_chunkv2.py
@@ -39,7 +39,7 @@ def exam_chunk_basic(init_device, keep_gathered, pin_memory):
pg = _get_default_group()
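+ # Chunk's process-group argument was renamed from process_group to zero_group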
my_chunk = Chunk(
chunk_size=1024,
- process_group=pg,
+ zero_group=pg,
dtype=torch.float32,
init_device=init_device,
cpu_shard_init=True,
diff --git a/tests/test_zero/test_gemini/test_fwd_bwd.py b/tests/test_zero/test_gemini/test_fwd_bwd.py
index b8d3f45e0f34..21afff753ae6 100644
--- a/tests/test_zero/test_gemini/test_fwd_bwd.py
+++ b/tests/test_zero/test_gemini/test_fwd_bwd.py
@@ -9,7 +9,7 @@
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
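+# get_current_device now lives in colossalai.utils.device (moved from colossalai.utils.cuda)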
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
diff --git a/tests/test_zero/test_gemini/test_grad_accum.py b/tests/test_zero/test_gemini/test_grad_accum.py
index 5e36b18389b1..35323e516071 100644
--- a/tests/test_zero/test_gemini/test_grad_accum.py
+++ b/tests/test_zero/test_gemini/test_grad_accum.py
@@ -9,7 +9,7 @@
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd
@@ -49,7 +49,10 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
-def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
+@parameterize("use_grad_checkpoint", [False, True])
+def exam_gemini_grad_acc(
+ placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
+):
init_device = get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
@@ -63,6 +66,10 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data)
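+ # optionally enable activation checkpointing on both models to cover grad accumulation with recomputation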
+ if use_grad_checkpoint:
+ gemini_model.gradient_checkpointing_enable()
+ torch_model.gradient_checkpointing_enable()
+
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(gemini_model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
@@ -77,7 +84,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
**placement_config,
)
optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
- gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1)
+ gemini_optim = GeminiOptimizer(optimizer, gemini_model, initial_scale=1, max_norm=1.0)
rank = dist.get_rank()
@@ -112,6 +119,7 @@ def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str,
check_grad(gemini_model, torch_model)
if (i + 1) % accum_iter == 0:
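+ # clip the torch baseline to the same max_norm=1.0 that GeminiOptimizer now enforces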
+ torch.nn.utils.clip_grad_norm_(amp.master_params(torch_optim), 1.0)
torch_optim.step()
gemini_optim.step()
torch_optim.zero_grad()
diff --git a/tests/test_zero/test_gemini/test_grad_clip.py b/tests/test_zero/test_gemini/test_grad_clip.py
index c3a36d3bafa1..23b3504fdb7c 100644
--- a/tests/test_zero/test_gemini/test_grad_clip.py
+++ b/tests/test_zero/test_gemini/test_grad_clip.py
@@ -88,7 +88,7 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
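+ # GeminiOptimizer's clipping argument is now max_norm (previously clipping_norm)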
- zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, clipping_norm=1.0)
+ zero_optim = GeminiOptimizer(optimizer, model, initial_scale=32, max_norm=1.0)
model.train()
torch_model.train()
diff --git a/tests/test_zero/test_gemini/test_inference.py b/tests/test_zero/test_gemini/test_inference.py
index e20428b67b41..152bf289502a 100644
--- a/tests/test_zero/test_gemini/test_inference.py
+++ b/tests/test_zero/test_gemini/test_inference.py
@@ -11,7 +11,7 @@
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py
index 887e495e6187..405d7d789b01 100644
--- a/tests/test_zero/test_gemini/test_optim.py
+++ b/tests/test_zero/test_gemini/test_optim.py
@@ -9,7 +9,7 @@
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
-from colossalai.utils.cuda import get_current_device
+from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py
index 3c5baea138e0..351ae5f67ff7 100644
--- a/tests/test_zero/test_low_level/test_grad_acc.py
+++ b/tests/test_zero/test_low_level/test_grad_acc.py
@@ -9,7 +9,7 @@
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
-from colossalai.utils import conditional_context
+from colossalai.utils import conditional_context, get_current_device
from colossalai.zero import LowLevelZeroOptimizer
@@ -28,9 +28,9 @@ def forward(self, x):
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
-
+ device = get_current_device()
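+ # use the accelerator-agnostic device instead of hard-coding .cuda()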
# create model
- zero1_model = MlpModel().cuda()
+ zero1_model = MlpModel().to(device)
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
)
# create data
seed_all(2021 + local_rank)
- input_data1 = torch.randn(32, 128).cuda()
- input_data2 = torch.randn(32, 128).cuda()
+ input_data1 = torch.randn(32, 128, device=device)
+ input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
@@ -71,14 +71,15 @@ def fwd_bwd_func(number, cur_data, check_flag):
def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank()
seed_all(2008)
+ device = get_current_device()
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
seed_all(2008)
- zero_model = zero_model.cuda()
- torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+ zero_model = zero_model.to(device)
+ torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
# create data
seed_all(2022 + local_rank)
- input_data1 = torch.randn(32, 128).cuda()
- input_data2 = torch.randn(32, 128).cuda()
+ input_data1 = torch.randn(32, 128, device=device)
+ input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(no_sync, cur_data, check_flag):
# zero1 fwd and bwd
diff --git a/version.txt b/version.txt
index 1c09c74e221c..42045acae20f 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.3.3
+0.3.4