diff --git a/.compatibility b/.compatibility
index c8ac4083d2a2..32da32be5521 100644
--- a/.compatibility
+++ b/.compatibility
@@ -1,3 +1,3 @@
1.12.0-11.3.0
-1.11.0-11.3.0
-1.10.1-11.3.0
+1.13.0-11.6.0
+2.0.0-11.7.0
diff --git a/.github/workflows/build_on_pr.yml b/.github/workflows/build_on_pr.yml
index a2807859b591..513de40b7353 100644
--- a/.github/workflows/build_on_pr.yml
+++ b/.github/workflows/build_on_pr.yml
@@ -60,6 +60,9 @@ jobs:
defaults:
run:
shell: bash
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- name: Copy testmon cache
run: | # branch name may contain slash, we need to replace it with space
@@ -83,6 +86,9 @@ jobs:
changedLibraryFiles: ${{ steps.find-lib-change.outputs.all_changed_files }}
anyLibraryFileChanged: ${{ steps.find-lib-change.outputs.any_changed }}
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:
@@ -140,6 +146,9 @@ jobs:
defaults:
run:
shell: bash
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- name: Checkout TensorNVMe
uses: actions/checkout@v2
@@ -271,7 +280,6 @@ jobs:
PR_NUMBER: ${{ github.event.pull_request.number }}
- name: Remove testmon cache
- if: github.event.pull_request.merged != true
run: |
rm -rf /github/home/testmon_cache/_pull/${PR_NUMBER}
env:
diff --git a/.github/workflows/compatiblity_test_on_pr.yml b/.github/workflows/compatiblity_test_on_pr.yml
index 94a723388872..5098b8e364d0 100644
--- a/.github/workflows/compatiblity_test_on_pr.yml
+++ b/.github/workflows/compatiblity_test_on_pr.yml
@@ -12,6 +12,9 @@ jobs:
runs-on: ubuntu-latest
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v3
- id: set-matrix
@@ -40,6 +43,9 @@ jobs:
image: ${{ matrix.container }}
options: --gpus all --rm -v /data/scratch/cifar-10:/data/scratch/cifar-10
timeout-minutes: 120
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- name: Install dependencies
run: |
diff --git a/.github/workflows/doc_check_on_pr.yml b/.github/workflows/doc_check_on_pr.yml
index 992cc93b008c..848991bd3a82 100644
--- a/.github/workflows/doc_check_on_pr.yml
+++ b/.github/workflows/doc_check_on_pr.yml
@@ -16,6 +16,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v2
@@ -31,6 +34,9 @@ jobs:
github.event.pull_request.draft == false &&
github.event.pull_request.base.repo.full_name == 'hpcaitech/ColossalAI'
runs-on: ubuntu-latest
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v2
with:
diff --git a/.github/workflows/doc_test_on_pr.yml b/.github/workflows/doc_test_on_pr.yml
index 325e2a7c95a4..2a07a2297bfb 100644
--- a/.github/workflows/doc_test_on_pr.yml
+++ b/.github/workflows/doc_test_on_pr.yml
@@ -19,6 +19,9 @@ jobs:
outputs:
any_changed: ${{ steps.changed-files.outputs.any_changed }}
changed_files: ${{ steps.changed-files.outputs.all_changed_files }}
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
name: Detect changed example files
steps:
- uses: actions/checkout@v3
@@ -59,6 +62,9 @@ jobs:
defaults:
run:
shell: bash
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- name: Checkout ColossalAI-Documentation
uses: actions/checkout@v2
diff --git a/.github/workflows/example_check_on_pr.yml b/.github/workflows/example_check_on_pr.yml
index 31dbf7540091..ee456c25f2b5 100644
--- a/.github/workflows/example_check_on_pr.yml
+++ b/.github/workflows/example_check_on_pr.yml
@@ -20,6 +20,9 @@ jobs:
matrix: ${{ steps.setup-matrix.outputs.matrix }}
anyChanged: ${{ steps.setup-matrix.outputs.anyChanged }}
name: Detect changed example files
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v3
with:
@@ -77,6 +80,9 @@ jobs:
image: hpcaitech/pytorch-cuda:1.12.0-11.3.0
options: --gpus all --rm -v /data/scratch/examples-data:/data/
timeout-minutes: 10
+ concurrency:
+ group: ${{ github.head_ref }}
+ cancel-in-progress: false
steps:
- uses: actions/checkout@v3
diff --git a/.github/workflows/release_docker_after_publish.yml b/.github/workflows/release_docker_after_publish.yml
index 22698ca192ed..6c8df9730b0d 100644
--- a/.github/workflows/release_docker_after_publish.yml
+++ b/.github/workflows/release_docker_after_publish.yml
@@ -23,8 +23,11 @@ jobs:
run: |
version=$(cat version.txt)
tag=hpcaitech/colossalai:$version
+ latest=hpcaitech/colossalai:latest
docker build --build-arg http_proxy=http://172.17.0.1:7890 --build-arg https_proxy=http://172.17.0.1:7890 --build-arg VERSION=v${version} -t $tag ./docker
+ docker tag $tag $latest
echo "tag=${tag}" >> $GITHUB_OUTPUT
+ echo "latest=${latest}" >> $GITHUB_OUTPUT
- name: Log in to Docker Hub
uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9
@@ -36,6 +39,7 @@ jobs:
id: docker-push
run: |
docker push ${{ steps.build.outputs.tag }}
+ docker push ${{ steps.build.outputs.latest }}
notify:
name: Notify Lark via webhook
diff --git a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
index 16b8957c1d88..d8f6c8fe309e 100644
--- a/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
+++ b/.github/workflows/scripts/generate_leaderboard_and_send_to_lark.py
@@ -38,7 +38,7 @@ def plot_bar_chart(x: List[Any], y: List[Any], xlabel: str, ylabel: str, title:
def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str, int]:
"""
- Retrive the issue/PR comments made by our members in the last 7 days.
+ Retrieve the issue/PR comments made by our members in the last 7 days.
Args:
github_token (str): GitHub access token for API calls
@@ -89,7 +89,7 @@ def get_issue_pull_request_comments(github_token: str, since: str) -> Dict[str,
def get_discussion_comments(github_token, since) -> Dict[str, int]:
"""
- Retrive the discussion comments made by our members in the last 7 days.
+ Retrieve the discussion comments made by our members in the last 7 days.
This is only available via the GitHub GraphQL API.
Args:
@@ -194,7 +194,7 @@ def _call_graphql_api(query):
discussion_updated_at = datetime.strptime(discussion['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
# check if the updatedAt is within the last 7 days
- # if yes, add it to dicussion_numbers
+ # if yes, add it to discussion_numbers
if discussion_updated_at > since:
if discussion['authorAssociation'] != 'MEMBER':
discussion_numbers.append(discussion['number'])
@@ -207,14 +207,14 @@ def _call_graphql_api(query):
# update cursor
cursor = edges[-1]['cursor']
- # get the dicussion comments and replies made by our member
+ # get the discussion comments and replies made by our member
user_engagement_count = {}
- for dicussion_number in discussion_numbers:
+ for discussion_number in discussion_numbers:
cursor = None
num_per_request = 10
while True:
- query = _generate_comment_reply_count_for_discussion(dicussion_number, num_per_request, cursor)
+ query = _generate_comment_reply_count_for_discussion(discussion_number, num_per_request, cursor)
data = _call_graphql_api(query)
# get the comments
@@ -249,7 +249,7 @@ def _call_graphql_api(query):
reply = reply_edge['node']
if reply['authorAssociation'] == 'MEMBER':
# check if the updatedAt is within the last 7 days
- # if yes, add it to dicussion_numbers
+ # if yes, add it to discussion_numbers
reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ")
if reply_updated_at > since:
member_name = reply['author']['login']
diff --git a/applications/Chat/evaluate/README.md b/applications/Chat/evaluate/README.md
index e3510e3522f6..077193b63ce0 100644
--- a/applications/Chat/evaluate/README.md
+++ b/applications/Chat/evaluate/README.md
@@ -12,12 +12,13 @@ pip install -r requirements.txt
## Evaluation Pipeline
-The whole evaluation pipeline consists of two methods:
+The whole evaluation pipeline consists of three methods:
1. `GPT Evaluation`: evaluates model predictions using GPT models.
* Compare the performance of two different models (battle).
* Rate the model according to pre-defined metrics using prompting design.
2. `Automatic Evaluation`: evaluates model predictions using automatic metrics.
+3. `UniEval`: evaluates model predictions using UniEval models (English only).
### Evaluation Category
@@ -75,7 +76,9 @@ GPT evaluation uses GPT models to evaluate the prediction of different models an
GPT models evaluate the quality of model predictions based on the given prompt words and give a score between 1 and 5.
-> **NOTE:** Even for the same metric, the details of its prompt words and CoT(Chain-of-Thought) can differ based on which category you want to evaluate. For example, prompt words for metric `correctness` showed here is "The answer should be in line with common sense, life experience, etc."(this is for category `brainstorming`), but for category `extraction`, prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT(Chain-of-Thought) in `prompt/evaluation_prompt`.
+> **NOTE 1:** Even for the same metric, the details of its prompt words and CoT (Chain-of-Thought) can differ depending on which category you want to evaluate. For example, the prompt words for the metric `correctness` shown here are "The answer should be in line with common sense, life experience, etc." (this is for the category `brainstorming`), but for the category `extraction`, the prompt words can be "Answers should extract the required information accurately and should not contain any incorrect or misleading information." You can find all the prompt words and CoT (Chain-of-Thought) in `prompt/evaluation_prompt`.
+
+> **NOTE 2:** To add customized metrics, you can refer to [FAQ](#faq).
#### Automatic Evaluation
@@ -85,7 +88,7 @@ There are two ways to obtain reference answers:
* For instructions coming from human-designed problems, the reference answers are generated by GPT-3.5 (e.g. roleplay, chat).
* For instructions related to classic NLP problems, the reference answers are collected from open-source datasets with target answers (e.g. classification, extraction, summarization).
-There are 5 types of automatic evaluation metrics listed in the table below:
+There are 6 types of automatic evaluation metrics listed in the table below:
| Automatic Evaluation Metric | Description |
| :---------------------------------: | :----------------------------------------------------------- |
@@ -94,6 +97,25 @@ There are 5 types of automatic evaluation metrics listed in the table below:
| Distinct | Measure the diversity of generation text by counting the unique n-grams. |
| BERTScore | Measure the semantic similarity between tokens of predictions and references with BERT. |
| Precision<br>Recall<br>F1 Score | Measure the number of overlaps between prediction and reference (designed for classification and extraction categories). |
+| CHRF | Measure the similarity of character n-grams between prediction and reference. |
+
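+A minimal sketch of how a CHRF score can be computed with NLTK's `sentence_chrf` (the same call used by the new `chrf_score` in `metrics.py`); the sample sentences are made up for illustration:
+
+```python
+from nltk.translate.chrf_score import sentence_chrf
+
+prediction = "a cat sat on the mat".split()
+reference = "the cat sat on the mat".split()
+
+# character n-gram F-score in [0, 1]; the pipeline averages this over all samples
+print(sentence_chrf(reference, prediction))
+```
+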
+#### UniEval Evaluation
+
+UniEval converts all evaluation tasks of different dimensions (metrics) into Boolean QA problems and utilizes the model to answer with “Yes” or “No”. Compared with similarity-based metrics such as ROUGE and BLEU, UniEval can achieve a more comprehensive evaluation. In addition, UniEval also demonstrates its ability to transfer to unseen dimensions and tasks.
+
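+To make the Boolean QA formulation concrete, here is a minimal sketch of how a single “Yes”/“No” score can be computed (mirroring `unieval/scorer.py`); the checkpoint path and sample response are placeholders:
+
+```python
+import torch
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+model_path = "path/to/unieval-dialog"    # placeholder for a local or hub checkpoint
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_path).eval()
+
+text = "question: Is this a natural response in the dialogue? response: Sure, happy to help!"
+enc = tokenizer(text, return_tensors="pt")
+# T5 requires decoder input ids; a one-token dummy target works because only the first-step logits matter
+dummy = tokenizer("No", return_tensors="pt").input_ids[:, :1]
+
+with torch.no_grad():
+    logits = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask, labels=dummy).logits[:, 0, :]
+probs = torch.softmax(logits, dim=-1)
+pos = probs[0, tokenizer("Yes").input_ids[0]]
+neg = probs[0, tokenizer("No").input_ids[0]]
+print((pos / (pos + neg)).item())    # score in (0, 1); higher means "Yes"
+```
+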
+In our evaluation pipeline, two pre-trained UniEval evaluators are used. One is [unieval-sum](https://huggingface.co/MingZhong/unieval-sum) and the other is [unieval-dialog](https://huggingface.co/MingZhong/unieval-dialog). The two models cover three tasks, `summarization`, `dialogue` and `data2text`, and each task has its own evaluation dimensions.
+
+| UniEval Model | Task | Dimension(Metric) |
+| :------------: | :----------------- | :--- |
+| unieval-sum | summarization | coherence: whether the summary is coherent<br>consistency: whether the claim is consistent with the given document<br>fluency: whether the paragraph is fluent<br>relevance: whether the summary is relevant to the reference |
+| unieval-sum | data2text | naturalness: whether the utterance is fluent<br>informativeness: whether the utterance is informative according to the reference |
+| unieval-dialog | dialogue | naturalness: whether the response is natural in the dialogue<br>coherence: whether the response is coherent in the dialogue history<br>understandability: whether the response is understandable in the dialogue |
+
+> **NOTE 1:** Task "data2text" uses the same model as task "summarization".
+
+> **NOTE 2:** In the UniEval paper, the `unieval-sum` model demonstrates the best transfer ability, so you can evaluate your customized metric with this model. Details on adding customized metrics can be found in the [FAQ](#faq).
+
+> **NOTE 3:** We do not include all the metrics provided by UniEval in our pipeline, because the data structure and content of the instructions we want to evaluate are not suitable for direct use with some UniEval metrics.
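+
+If you want to try the UniEval evaluators outside the full pipeline, the sketch below shows how the `unieval` module added in this PR can be called directly; the model path and sample texts are placeholders:
+
+```python
+import unieval
+
+# "path/to/unieval-dialog" is a placeholder for a local unieval-dialog checkpoint
+evaluator = unieval.get_evaluator("dialogue", model_name_or_path="path/to/unieval-dialog")
+
+predictions = ["Sure, I can help you book a flight."]
+sources = ["Can you help me book a flight to Beijing?"]
+references = ["Of course, when would you like to depart?"]
+
+data = unieval.convert_data_to_unieval_format(predictions, sources, references)
+scores = evaluator.evaluate(data, "chat", dims=["naturalness", "coherence"], overall=False)
+print(unieval.calculate_average_score(scores))    # e.g. {'naturalness': ..., 'coherence': ...}
+```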
## Evaluation Process
@@ -215,19 +237,26 @@ The following is an example of a Chinese GPT evaluation prompt. In an evaluation
#### Configuration
-The following is an example of a Chinese config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics and automatic metrics in key `GPT` and `Metrics`. You can find an example Chinese config file in `config`.
+The following is an example of an English config file. The configuration file can control how the pipeline evaluates the model. You need to specify GPT evaluation metrics, automatic metrics and UniEval metrics in the keys `GPT`, `Metrics` and `UniEval` (English only). You can find an example English config file in `config`.
```json
{
- "language": "cn",
+ "language": "en",
+ "path_for_UniEval": {
+ "summarization": "path to unieval-sum model",
+ "dialogue": "path to unieval-dialog model",
+ "data2text": "path to unieval-sum model"
+ },
"category": {
"brainstorming": {
"GPT": ["relevance", "creativity", "practicality", "correctness"],
- "Metrics": ["Distinct"]
+ "Metrics": ["Distinct"],
+ "UniEval": ["summarization-fluency", "data2text-naturalness", "data2text-informativeness"]
},
"chat": {
"GPT": [ "relevance", "naturalness", "engagingness", "reasonableness"],
- "Metrics": ["Distinct"]
+ "Metrics": ["Distinct"],
+ "UniEval": ["dialogue-naturalness", "dialogue-coherence", "dialogue-understandability"]
}
}
}
@@ -235,27 +264,33 @@ The following is an example of a Chinese config file. The configuration file can
`"language"`: the language used to evaluate the model capability. We only support Chinese `"cn"` for now.
+`"path_for_UniEval"`: path to the UniEval model.
+
`"category"`: the category/categories needed to evaluate the model capability.
`"GPT"`: the metrics you want to use for GPT evaluation.
`"Metrics"`: the metrics you want to use for automatic metrics evaluation.
+`"UniEval"`: the metrics you want to use for UniEval metrics evaluation. The metric has to be in the `"{task}-{metric}"` format because different tasks have same metrics such as naturalness and coherence.
+
+You can remove a key such as `"Metrics"` to skip evaluating answers with its corresponding evaluation metrics.
+
You can create your config file based on the available settings listed in the following table.
-| "category" | "GPT" | "Metrics" |
-| :--------------: | :---------------------: | :---------: |
-| "brainstorming" | "language organization" | "BLEU" |
-| "chat" | "relevance" | "ROUGE" |
-| "classification" | "creativity" | "Distinct" |
-| "closed_qa" | "practicality" | "BERTScore" |
-| "extraction" | "correctness" | "Precision" |
-| "generation" | "naturalness" | "Recall" |
-| "open_qa" | "engagingness" | "F1 score" |
-| "rewriting" | "reasonableness" | |
-| "roleplay" | "diversity" | |
-| "summarization" | "fidelity" | |
-| | "conciseness" | |
+| "category" | "GPT" | "Metrics" | "UniEval" |
+| :--------------: | :---------------------: | :---------: | :--------------------------: |
+| "brainstorming" | "language organization" | "BLEU" | "dialogue-naturalness" |
+| "chat" | "relevance" | "ROUGE" | "dialogue-coherence" |
+| "classification" | "creativity" | "Distinct" | "dialogue-understandability" |
+| "closed_qa" | "practicality" | "BERTScore" | "data2text-naturalness" |
+| "extraction" | "correctness" | "Precision" | "data2text-informativeness" |
+| "generation" | "naturalness" | "Recall" | "summarization-coherence" |
+| "open_qa" | "engagingness" | "F1 score" | "summarization-consistency" |
+| "rewriting" | "reasonableness" | "CHRF" | "summarization-fluency" |
+| "roleplay" | "diversity" | | "summarization-relevance" |
+| "summarization" | "fidelity" | | |
+| | "conciseness" | | |
> **NOTE:** For categories that don't have standard answers, such as `brainstorming`, you should avoid using automatic metrics such as `BLEU` and `ROUGE`, which are based on similarity measures, and use `Distinct` instead in your config file.
@@ -290,23 +325,36 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
"id": 1,
"category": "brainstorming",
"metrics": {
- "persuasiveness": "说服力(1-5):XXX"
+        "persuasiveness": "persuasiveness (1-5): a short description for persuasiveness"
},
"CoT": {
- "persuasiveness": "XXX\n\n说服力:"
+ "persuasiveness": "CoT for persuasiveness\n\npersuasiveness:"
},
- "prompt": "你是一个好助手。请你为下面“头脑风暴”问题的答案打分。\n\n问题如下:\n\n{question}\n\n答案如下:\n\n{answer}\n\n评分的指标如下:\n\n{metric}\n\n请你遵照以下的评分步骤:\n\n{steps}"
+ "prompt": "You are a good assistant. Please rate the given answer to the \"brainstorming\" question below.\n\nThe question is as follows:\n\n{question}\n\nThe answer is as follows:\n\n{answer}\n\nThe metric for evaluation is as follows:\n\n{metric}\n\nYou should follow the following evaluation steps:\n\n{steps}"
}
}
```
+How can I add a new UniEval evaluation metric?
+
+For example, if you want to add a new metric `persuasiveness` to the task `data2text`, you should add a Boolean QA question about the metric in the function `add_question` in `unieval/utils.py`. Note that how effectively the model evaluates this metric is unknown, so you may need to run some experiments to test whether the model is capable of evaluating it.
+
+```python
+if task == 'data2text':
+ if dimension == 'persuasiveness':
+        cur_input = 'question: Is this a persuasive utterance? utterance: ' + output[i]
+```
+
+
+
## To Do
- [x] Add evaluation for English capability
-- [ ] Support UniEval
+- [x] Support UniEval
- [x] Support GPT-4 evaluation
+- [ ] Support GPT evaluation with reference in the prompt
## Citations
@@ -327,4 +375,13 @@ For example, if you want to add a new metric `persuasiveness` into category `bra
archivePrefix={arXiv},
primaryClass={cs.CL}
}
+
+@misc{zhong2022unified,
+ title={Towards a Unified Multi-Dimensional Evaluator for Text Generation},
+ author={Ming Zhong and Yang Liu and Da Yin and Yuning Mao and Yizhu Jiao and Pengfei Liu and Chenguang Zhu and Heng Ji and Jiawei Han},
+ year={2022},
+ eprint={2210.07197},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+}
```
diff --git a/applications/Chat/evaluate/config/config_cn.json b/applications/Chat/evaluate/config/config_cn.json
index a8c7ea8a3135..cf647f79bbf8 100644
--- a/applications/Chat/evaluate/config/config_cn.json
+++ b/applications/Chat/evaluate/config/config_cn.json
@@ -34,7 +34,8 @@
"Metrics": [
"Precision",
"Recall",
- "F1 score"
+ "F1 score",
+ "CHRF"
]
},
"closed_qa": {
@@ -46,7 +47,8 @@
"Metrics": [
"BLEU",
"ROUGE",
- "BERTScore"
+ "BERTScore",
+ "CHRF"
]
},
"extraction": {
@@ -58,7 +60,8 @@
"Metrics": [
"Precision",
"Recall",
- "F1 score"
+ "F1 score",
+ "CHRF"
]
},
"generation": {
@@ -116,7 +119,8 @@
"Metrics": [
"BLEU",
"ROUGE",
- "BERTScore"
+ "BERTScore",
+ "CHRF"
]
}
}
diff --git a/applications/Chat/evaluate/config/config_en.json b/applications/Chat/evaluate/config/config_en.json
index 5b6272b97084..014c61d93a54 100644
--- a/applications/Chat/evaluate/config/config_en.json
+++ b/applications/Chat/evaluate/config/config_en.json
@@ -1,5 +1,10 @@
{
"language": "en",
+ "path_for_UniEval": {
+ "summarization": "path to unieval-sum",
+ "dialogue": "path to unieval-dialog",
+ "data2text": "path to unieval-sum"
+ },
"category": {
"brainstorming": {
"GPT": [
@@ -11,6 +16,11 @@
],
"Metrics": [
"Distinct"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"chat": {
@@ -23,6 +33,14 @@
],
"Metrics": [
"Distinct"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "dialogue-naturalness",
+ "dialogue-coherence",
+ "dialogue-understandability",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"classification": {
@@ -34,7 +52,13 @@
"Metrics": [
"Precision",
"Recall",
- "F1 score"
+ "F1 score",
+ "CHRF"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"closed_qa": {
@@ -46,7 +70,13 @@
"Metrics": [
"BLEU",
"ROUGE",
- "BERTScore"
+ "BERTScore",
+ "CHRF"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"extraction": {
@@ -58,7 +88,13 @@
"Metrics": [
"Precision",
"Recall",
- "F1 score"
+ "F1 score",
+ "CHRF"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"generation": {
@@ -71,6 +107,11 @@
"BLEU",
"ROUGE",
"BERTScore"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"open_qa": {
@@ -81,6 +122,11 @@
],
"Metrics": [
"Distinct"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"rewriting": {
@@ -93,6 +139,11 @@
"BLEU",
"ROUGE",
"BERTScore"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"roleplay": {
@@ -104,6 +155,11 @@
],
"Metrics": [
"Distinct"
+ ],
+ "UniEval": [
+ "summarization-fluency",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
},
"summarization": {
@@ -116,7 +172,16 @@
"Metrics": [
"BLEU",
"ROUGE",
- "BERTScore"
+ "BERTScore",
+ "CHRF"
+ ],
+ "UniEval": [
+ "summarization-coherence",
+ "summarization-consistency",
+ "summarization-fluency",
+ "summarization-relevance",
+ "data2text-naturalness",
+ "data2text-informativeness"
]
}
}
diff --git a/applications/Chat/evaluate/eval.py b/applications/Chat/evaluate/eval.py
index 8388d95f748a..180ef438cc43 100644
--- a/applications/Chat/evaluate/eval.py
+++ b/applications/Chat/evaluate/eval.py
@@ -40,7 +40,7 @@ def main(args):
# initialize evaluator
evaluator = Evaluator(metrics_per_category, battle_prompt, gpt_evaluation_prompt, args.gpt_model,
- config["language"])
+ config["language"], config.get("path_for_UniEval", None))
if len(args.model_name_list) == 2:
answers1 = jload(args.answer_file_list[0])
answers2 = jload(args.answer_file_list[1])
diff --git a/applications/Chat/evaluate/evaluator.py b/applications/Chat/evaluate/evaluator.py
index 0bf55ca80d7c..6bb8cdb29431 100644
--- a/applications/Chat/evaluate/evaluator.py
+++ b/applications/Chat/evaluate/evaluator.py
@@ -4,6 +4,7 @@
import gpt_evaluate
import metrics
import pandas as pd
+import unieval
from utils import analyze_automatic_results, get_data_per_category, save_automatic_results
@@ -15,13 +16,15 @@ class Evaluator(object):
"""
def __init__(self, params: Dict[str, Any], battle_prompt: Dict[str, Any], gpt_evaluation_prompt: Dict[str, Any],
- gpt_model: str, language: str) -> None:
+ gpt_model: str, language: str, path_for_UniEval: Dict[str, str]) -> None:
self.params = params
self.battle_prompt = battle_prompt
self.gpt_evaluation_prompt = gpt_evaluation_prompt
self.gpt_model = gpt_model
self.language = language
+ self.path_for_UniEval = path_for_UniEval
self.automatic_metric_stats = dict()
+ self.unieval_metric_stats = dict()
self.gpt_evaluation_results = dict()
self.battle_results = []
@@ -47,16 +50,18 @@ def switch(metric, language):
return metrics.bleu_score(preds=predicts_list, targets=targets_list, language=language)
elif metric == "ROUGE":
return metrics.rouge_score(preds=predicts_list, targets=targets_list, language=language)
- elif (metric == "Distinct"):
+ elif metric == "Distinct":
return metrics.distinct_score(preds=predicts_list, language=language)
- elif (metric == "BERTScore"):
+ elif metric == "BERTScore":
return metrics.bert_score(preds=predicts_list, targets=targets_list, language=language)
- elif (metric == "Precision"):
+ elif metric == "Precision":
return metrics.precision(preds=predicts_list, targets=targets_list, language=language)
- elif (metric == "Recall"):
+ elif metric == "Recall":
return metrics.recall(preds=predicts_list, targets=targets_list, language=language)
- elif (metric == "F1 score"):
+ elif metric == "F1 score":
return metrics.F1_score(preds=predicts_list, targets=targets_list, language=language)
+ elif metric == "CHRF":
+ return metrics.chrf_score(preds=predicts_list, targets=targets_list, language=language)
else:
raise ValueError(f"Unexpected metric")
@@ -69,6 +74,9 @@ def switch(metric, language):
print(f"Category {category} specified in your config doesn't have corresponding answers!")
continue
+ if self.params[category].get("Metrics", None) is None:
+ continue
+
category_metrics = self.params[category]["Metrics"]
self.automatic_metric_stats[category] = {}
@@ -80,12 +88,68 @@ def switch(metric, language):
for metric in category_metrics:
self.automatic_metric_stats[category].update(switch(metric=metric, language=self.language))
+ # UniEval evaluation
+        # The keys of self.unieval_metric_stats are tasks rather than categories.
+        # Iterating over tasks first avoids loading the same model repeatedly, since each task corresponds to one UniEval model.
+        # If the keys were categories, the same model could be loaded multiple times, because one category may require several tasks (models).
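+        # For example (hypothetical values):
+        #   self.unieval_metric_stats = {
+        #       "dialogue": {"chat": {"naturalness": 0.92, "coherence": 0.88}},
+        #       "data2text": {"brainstorming": {"naturalness": 0.75}},
+        #   }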
+ for category in self.params:
+ if len(answers_per_category[category]) == 0:
+ print(f"Category {category} specified in your config doesn't have corresponding answers!")
+ continue
+
+ if self.params[category].get("UniEval", None) is None:
+ continue
+
+ if self.params[category]["UniEval"] and self.language == "cn":
+ raise Exception(
+                    "UniEval doesn't support Chinese! Please remove the UniEval configuration from your Chinese config file.")
+
+ category_metrics = self.params[category]["UniEval"]
+
+ for task, metric in [tuple(category_metric.split("-")) for category_metric in category_metrics]:
+ if self.unieval_metric_stats.get(task, None) is None:
+ self.unieval_metric_stats[task] = {category: {metric: 0}}
+ elif self.unieval_metric_stats[task].get(category, None) is None:
+ self.unieval_metric_stats[task][category] = {metric: 0}
+ else:
+ self.unieval_metric_stats[task][category][metric] = 0
+
+ for task in self.unieval_metric_stats:
+ if self.path_for_UniEval is None:
+                raise Exception("Please specify the paths for the UniEval models in the config file!")
+
+ if self.path_for_UniEval.get(task, None) is None:
+ raise Exception(f"Please specify the model path for task {task} in the config file!")
+
+ print(f"Load UniEval model for task {task}.")
+
+ uni_evaluator = unieval.get_evaluator(task, model_name_or_path=self.path_for_UniEval[task])
+ for category in self.unieval_metric_stats[task]:
+ targets_list = [
+ target["target"] if target["target"] else target["output"]
+ for target in targets_per_category[category]
+ ]
+ predicts_list = [answer["output"] for answer in answers_per_category[category]]
+ sources_list = [answer["instruction"] + answer["input"] for answer in answers_per_category[category]]
+
+ data = unieval.convert_data_to_unieval_format(predicts_list, sources_list, targets_list)
+ scores = uni_evaluator.evaluate(data,
+ category,
+ dims=list(self.unieval_metric_stats[task][category].keys()),
+ overall=False)
+ avg_scores = unieval.calculate_average_score(scores)
+
+ self.unieval_metric_stats[task][category].update(avg_scores)
+
# gpt evaluation
for category in self.params:
if len(answers_per_category[category]) == 0:
print(f"Category {category} specified in your config doesn't have corresponding answers!")
continue
+ if self.params[category].get("GPT", None) is None:
+ continue
+
category_metrics = self.params[category]["GPT"]
prompt = self.gpt_evaluation_prompt.get(category, None)
@@ -106,29 +170,43 @@ def save(self, path: str, model_name_list: List[str]) -> None:
save_path = os.path.join(path, "gpt_evaluate", "battle_results")
gpt_evaluate.save_battle_results(self.battle_results, model_name_list[0], model_name_list[1], save_path)
else:
- # Save evaluation results for automatic metrics
- automatic_base_save_path = os.path.join(path, "automatic_results")
- automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results")
-
- save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path)
-
- # Save charts and csv.
- automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses")
- analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path)
-
- # Save evaluation results for GPT evaluation metrics.
- gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
- gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
-
- all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0], self.gpt_evaluation_results,
- gpt_evaluation_results_save_path)
-
- # Start to calculate scores and save statistics.
- gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
- gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
- gpt_evaluation_statistics_save_path)
-
- # Save charts and csv.
- gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
- gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
- gpt_evaluation_analyses_save_path)
+ if self.automatic_metric_stats:
+ # Save evaluation results for automatic metrics
+ automatic_base_save_path = os.path.join(path, "automatic_results")
+ automatic_results_save_path = os.path.join(automatic_base_save_path, "evaluation_results")
+
+ save_automatic_results(model_name_list[0], self.automatic_metric_stats, automatic_results_save_path)
+
+ # Save charts and csv.
+ automatic_analyses_save_path = os.path.join(automatic_base_save_path, "evaluation_analyses")
+ analyze_automatic_results(automatic_results_save_path, automatic_analyses_save_path)
+
+ if self.unieval_metric_stats:
+ # Save evaluation results for UniEval metrics
+ unieval_base_save_path = os.path.join(path, "unieval_results")
+ unieval_results_save_path = os.path.join(unieval_base_save_path, "evaluation_results")
+
+ unieval.save_unieval_results(model_name_list[0], self.unieval_metric_stats, unieval_results_save_path)
+
+ # Save charts and csv.
+ unieval_analyses_save_path = os.path.join(unieval_base_save_path, "evaluation_analyses")
+ unieval.analyze_unieval_results(unieval_results_save_path, unieval_analyses_save_path)
+
+ if self.gpt_evaluation_results:
+ # Save evaluation results for GPT evaluation metrics.
+ gpt_base_save_path = os.path.join(path, "gpt_evaluate", "gpt_evaluate_results")
+ gpt_evaluation_results_save_path = os.path.join(gpt_base_save_path, "evaluation_results")
+
+ all_evaluations = gpt_evaluate.save_gpt_evaluation_results(model_name_list[0],
+ self.gpt_evaluation_results,
+ gpt_evaluation_results_save_path)
+
+ # Start to calculate scores and save statistics.
+ gpt_evaluation_statistics_save_path = os.path.join(gpt_base_save_path, "evaluation_statistics")
+ gpt_evaluate.save_gpt_evaluation_statistics(model_name_list[0], all_evaluations,
+ gpt_evaluation_statistics_save_path)
+
+ # Save charts and csv.
+ gpt_evaluation_analyses_save_path = os.path.join(gpt_base_save_path, "evaluation_analyses")
+ gpt_evaluate.analyze_gpt_evaluation_statistics(gpt_evaluation_statistics_save_path,
+ gpt_evaluation_analyses_save_path)
diff --git a/applications/Chat/evaluate/gpt_evaluate.py b/applications/Chat/evaluate/gpt_evaluate.py
index b433500dfa04..6702526ac5e6 100644
--- a/applications/Chat/evaluate/gpt_evaluate.py
+++ b/applications/Chat/evaluate/gpt_evaluate.py
@@ -599,7 +599,7 @@ def analyze_gpt_evaluation_statistics(statistics_path: str, save_path: str) -> N
for category in tqdm.tqdm(
frame_per_category.keys(),
- desc=f"category: ",
+            desc="GPT evaluation: ",
total=len(frame_per_category.keys()),
):
data = pd.DataFrame(frame_per_category[category])
diff --git a/applications/Chat/evaluate/metrics.py b/applications/Chat/evaluate/metrics.py
index 031f6fa83926..e220226ec041 100644
--- a/applications/Chat/evaluate/metrics.py
+++ b/applications/Chat/evaluate/metrics.py
@@ -4,6 +4,7 @@
import jieba
from bert_score import score
from nltk.translate.bleu_score import sentence_bleu
+from nltk.translate.chrf_score import sentence_chrf
from rouge_chinese import Rouge as Rouge_cn
from rouge_score import rouge_scorer as Rouge_en
from sklearn.metrics import f1_score, precision_score, recall_score
@@ -40,6 +41,27 @@ def bleu_score(preds: List[str], targets: List[str], language: str) -> Dict[str,
return bleu_scores
+def chrf_score(preds: List[str], targets: List[str], language: str) -> Dict[str, float]:
+    """Calculate the CHRF Score Metric at the sentence level.
+    """
+    chrf_scores = {"chrf": 0}
+ cumulative_chrf = []
+
+ for pred, target in zip(preds, targets):
+ if language == "cn":
+ pred_list = ' '.join(jieba.cut(preprocessing_text(pred))).split()
+ target_list = ' '.join(jieba.cut(preprocessing_text(target))).split()
+ elif language == "en":
+ pred_list = preprocessing_text(pred).split()
+ target_list = preprocessing_text(target).split()
+
+ cumulative_chrf.append(sentence_chrf(target_list, pred_list))
+
+    chrf_scores["chrf"] = statistics.mean(cumulative_chrf)
+
+    return chrf_scores
+
+
def rouge_cn_score(preds: List[str], targets: List[str]) -> Dict[str, float]:
"""Calculate Chinese ROUGE Score Metric
diff --git a/applications/Chat/evaluate/unieval/__init__.py b/applications/Chat/evaluate/unieval/__init__.py
new file mode 100644
index 000000000000..dad8d6ad09fa
--- /dev/null
+++ b/applications/Chat/evaluate/unieval/__init__.py
@@ -0,0 +1,12 @@
+from .evaluator import get_evaluator
+from .utils import (
+ analyze_unieval_results,
+ calculate_average_score,
+ convert_data_to_unieval_format,
+ save_unieval_results,
+)
+
+__all__ = [
+ 'get_evaluator', 'convert_data_to_unieval_format', 'calculate_average_score', 'save_unieval_results',
+ 'analyze_unieval_results'
+]
diff --git a/applications/Chat/evaluate/unieval/evaluator.py b/applications/Chat/evaluate/unieval/evaluator.py
new file mode 100644
index 000000000000..385425e4a576
--- /dev/null
+++ b/applications/Chat/evaluate/unieval/evaluator.py
@@ -0,0 +1,330 @@
+# MIT License
+
+# Copyright (c) 2022 Ming Zhong
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import numpy as np
+from nltk import sent_tokenize
+
+from .scorer import UniEvaluator
+from .utils import add_question
+
+
+class SumEvaluator:
+
+ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
+ """ Set up evaluator for text summarization """
+ self.scorer = UniEvaluator(
+ model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ self.task = 'summarization'
+ self.dimensions = ['coherence', 'consistency', 'fluency', 'relevance']
+
+ def evaluate(self, data, category, dims=None, overall=True):
+ """
+ Get the scores of all the given dimensions
+
+ category: The category to be evaluated.
+
+ dims: A list of dimensions to be evaluated. If dims is None, SumEvaluator will evaluate
+ four dimensions: coherence, consistency, fluency, relevance.
+
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
+ """
+ n_data = len(data)
+ eval_scores = [{} for _ in range(n_data)]
+
+        if dims is None:
+ eval_dims = self.dimensions
+ else:
+ assert isinstance(dims, list)
+ eval_dims = dims
+
+ for dim in eval_dims:
+ # Calculate average sentence-level scores for 'consistency' and 'fluency'
+ if dim == 'consistency' or dim == 'fluency':
+ src_list, output_list = [], []
+ n_sents = [] # the number of sentences in each generated summary
+ for i in range(n_data):
+ source = data[i]['source']
+ system_outputs = sent_tokenize(data[i]['system_output'])
+ n_sents.append(len(system_outputs))
+ for j in range(len(system_outputs)):
+ src_list.append(source)
+ output_list.append(system_outputs[j])
+ input_list = add_question(dimension=dim, output=output_list, src=src_list, task=self.task)
+ sent_score = self.scorer.score(input_list, self.task, category, dim)
+
+ # Get average score for each sample
+ start_idx = 0
+ score = []
+ for cur_n_sent in n_sents:
+ score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
+ start_idx += cur_n_sent
+
+ # Calculate summary-level score for 'coherence' and 'relevance'
+ elif dim == 'coherence' or dim == 'relevance':
+ src_list, output_list, ref_list = [], [], []
+ for i in range(n_data):
+ src_list.append(data[i]['source'])
+ output_list.append(data[i]['system_output'])
+ if dim == 'relevance':
+ ref_list.append(data[i]['reference'])
+ input_list = add_question(dimension=dim, output=output_list, src=src_list, ref=ref_list, task=self.task)
+ score = self.scorer.score(input_list, self.task, category, dim)
+
+ # Please customize other dimensions here for summarization
+ else:
+ raise NotImplementedError('The input format for this dimension is still undefined. \
+ Please customize it first.')
+
+ for i in range(n_data):
+ eval_scores[i][dim] = score[i]
+
+ # Customize your overall score here.
+        if overall:
+ for i in range(n_data):
+ eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+
+ return eval_scores
+
+
+class DialogEvaluator:
+
+ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
+ """ Set up evaluator for dialogues """
+ self.scorer = UniEvaluator(
+ model_name_or_path='MingZhong/unieval-dialog' if model_name_or_path == "" else model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ self.task = 'dialogue'
+ self.dimensions = ['naturalness', 'coherence', 'engagingness', 'groundedness', 'understandability']
+
+ def evaluate(self, data, category, dims=None, overall=True):
+ """
+ Get the scores of all the given dimensions
+
+ category: The category to be evaluated.
+
+ dims: A list of dimensions to be evaluated. If dims is None, DialogEvaluator will evaluate
+ five dimensions: naturalness, coherence, engagingness, groundedness and understandability.
+
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
+ """
+ n_data = len(data)
+ eval_scores = [{} for _ in range(n_data)]
+
+        if dims is None:
+ eval_dims = self.dimensions
+ else:
+ assert isinstance(dims, list)
+ eval_dims = dims
+
+ for dim in eval_dims:
+ # Calculate summation score for 'engagingness'
+ if dim == 'engagingness':
+ src_list, output_list, context_list = [], [], []
+ n_sents = [] # the number of sentences in each generated response
+ for i in range(n_data):
+ source = data[i]['source']
+ context = data[i]['context']
+ system_outputs = sent_tokenize(data[i]['system_output'])
+ n_sents.append(len(system_outputs))
+ for j in range(len(system_outputs)):
+ src_list.append(source)
+ context_list.append(context)
+ output_list.append(system_outputs[j])
+ input_list = add_question(dimension=dim,
+ output=output_list,
+ src=src_list,
+ context=context_list,
+ task=self.task)
+ sent_score = self.scorer.score(input_list, self.task, category, dim)
+
+ # Get the summation score for each sample
+ start_idx = 0
+ score = []
+ for cur_n_sent in n_sents:
+ score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]))
+ start_idx += cur_n_sent
+
+ # Calculate turn-level score for other dimensions
+ elif dim in ['naturalness', 'coherence', 'groundedness', 'understandability']:
+ src_list, output_list, context_list = [], [], []
+ for i in range(n_data):
+ src_list.append(data[i]['source'])
+ output_list.append(data[i]['system_output'])
+ context_list.append(data[i]['context'])
+ input_list = add_question(dimension=dim,
+ output=output_list,
+ src=src_list,
+ context=context_list,
+ task=self.task)
+ score = self.scorer.score(input_list, self.task, category, dim)
+
+            # Please customize other dimensions here for dialogues
+ else:
+ raise NotImplementedError('The input format for this dimension is still undefined. \
+ Please customize it first.')
+
+ for i in range(n_data):
+ eval_scores[i][dim] = score[i]
+
+ # Customize your overall score here.
+        if overall:
+ for i in range(n_data):
+ eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+
+ return eval_scores
+
+
+class D2tEvaluator:
+
+ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
+ """ Set up evaluator for data-to-text """
+ self.scorer = UniEvaluator(
+ model_name_or_path='MingZhong/unieval-sum' if model_name_or_path == "" else model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ self.task = 'data2text'
+ self.dimensions = ['naturalness', 'informativeness']
+
+ def evaluate(self, data, category, dims=None, overall=True):
+ """
+ Get the scores of all the given dimensions
+
+ category: The category to be evaluated.
+
+ dims: A list of dimensions to be evaluated. If dims is None, D2tEvaluator will evaluate
+ two dimensions: naturalness and informativeness.
+
+ overall: indicates whether the overall score is to be calculated.
+ Overall score can be customized to a combination of scores based on different
+ dimensions. The default here is the average score of all the given dimensions.
+ """
+ n_data = len(data)
+ eval_scores = [{} for _ in range(n_data)]
+
+        if dims is None:
+ eval_dims = self.dimensions
+ else:
+ assert isinstance(dims, list)
+ eval_dims = dims
+
+ for dim in eval_dims:
+ output_list, ref_list = [], []
+ for i in range(n_data):
+ output_list.append(data[i]['system_output'])
+ ref_list.append(data[i]['reference'])
+
+ input_list = add_question(dimension=dim, output=output_list, ref=ref_list, task=self.task)
+ score = self.scorer.score(input_list, self.task, category, dim)
+
+ for i in range(n_data):
+ eval_scores[i][dim] = score[i]
+
+ # Customize your overall score here.
+        if overall:
+ for i in range(n_data):
+ eval_scores[i]['overall'] = np.mean(list(eval_scores[i].values()))
+
+ return eval_scores
+
+
+class FactEvaluator:
+
+ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
+ """ Set up evaluator for factual consistency detection """
+ self.scorer = UniEvaluator(
+ model_name_or_path='MingZhong/unieval-fact' if model_name_or_path == "" else model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ self.task = 'fact'
+ self.dim = 'consistency'
+
+ def evaluate(self, data, category):
+ """
+ Get the factual consistency score (only 1 dimension for this task)
+
+ category: The category to be evaluated.
+ """
+ n_data = len(data)
+ eval_scores = [{} for _ in range(n_data)]
+
+        # Calculate average sentence-level scores for factual consistency
+ src_list, output_list = [], []
+ n_sents = [] # the number of sentences in the claim
+ for i in range(n_data):
+ source = data[i]['source']
+ system_outputs = sent_tokenize(data[i]['system_output'])
+ n_sents.append(len(system_outputs))
+ for j in range(len(system_outputs)):
+ src_list.append(source)
+ output_list.append(system_outputs[j])
+ input_list = add_question(dimension=self.dim, output=output_list, src=src_list, task=self.task)
+        sent_score = self.scorer.score(input_list, self.task, category, self.dim)
+
+ # Get average score for each sample
+ start_idx = 0
+ score = []
+ for cur_n_sent in n_sents:
+ score.append(sum(sent_score[start_idx:start_idx + cur_n_sent]) / cur_n_sent)
+ start_idx += cur_n_sent
+
+ for i in range(n_data):
+ eval_scores[i][self.dim] = score[i]
+
+ return eval_scores
+
+
+def get_evaluator(task, model_name_or_path="", max_length=1024, device='cuda:0', cache_dir=None):
+ assert task in ['summarization', 'dialogue', 'data2text', 'fact']
+ if task == 'summarization':
+ return SumEvaluator(model_name_or_path=model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ elif task == 'dialogue':
+ return DialogEvaluator(model_name_or_path=model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ elif task == 'data2text':
+ return D2tEvaluator(model_name_or_path=model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ elif task == 'fact':
+ return FactEvaluator(model_name_or_path=model_name_or_path,
+ max_length=max_length,
+ device=device,
+ cache_dir=cache_dir)
+ else:
+ raise NotImplementedError('Other tasks are not implemented, \
+ please customize specific tasks here.')
diff --git a/applications/Chat/evaluate/unieval/scorer.py b/applications/Chat/evaluate/unieval/scorer.py
new file mode 100644
index 000000000000..2c70bb9f6ded
--- /dev/null
+++ b/applications/Chat/evaluate/unieval/scorer.py
@@ -0,0 +1,101 @@
+# MIT License
+
+# Copyright (c) 2022 Ming Zhong
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
+
+
+class UniEvaluator:
+
+ def __init__(self, model_name_or_path, max_length=1024, device='cuda:0', cache_dir=None):
+ """ Set up model """
+ self.device = device
+ self.max_length = max_length
+
+ self.config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, config=self.config, cache_dir=cache_dir)
+
+ self.model.eval()
+ self.model.to(device)
+
+ self.softmax = nn.Softmax(dim=1)
+
+ self.pos_id = self.tokenizer("Yes")["input_ids"][0]
+ self.neg_id = self.tokenizer("No")["input_ids"][0]
+
+ def score(self, inputs, task, category, dim, batch_size=8):
+ """
+ Get scores for the given samples.
+        final_score = positive_score / (positive_score + negative_score)
+ """
+
+ # The implementation of "forward" in T5 still requires decoder_input_ids.
+ # Therefore, we construct a random one-word target sequence.
+ # The content of the target has no effect on the final scores.
+ tgts = ["No" for _ in range(len(inputs))]
+
+ pos_score_list, neg_score_list = [], []
+ for i in tqdm(range(0, len(inputs), batch_size), desc=f"{category}-({dim}-{task}): "):
+ src_list = inputs[i:i + batch_size]
+ tgt_list = tgts[i:i + batch_size]
+ try:
+ with torch.no_grad():
+ encoded_src = self.tokenizer(src_list,
+ max_length=self.max_length,
+ truncation=True,
+ padding=True,
+ return_tensors='pt')
+ encoded_tgt = self.tokenizer(tgt_list,
+ max_length=self.max_length,
+ truncation=True,
+ padding=True,
+ return_tensors='pt')
+
+ src_tokens = encoded_src['input_ids'].to(self.device)
+ src_mask = encoded_src['attention_mask'].to(self.device)
+
+ tgt_tokens = encoded_tgt['input_ids'].to(self.device)[:, 0].unsqueeze(-1)
+
+ output = self.model(input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens)
+ logits = output.logits.view(-1, self.model.config.vocab_size)
+
+ pos_score = self.softmax(logits)[:, self.pos_id] # Yes
+ neg_score = self.softmax(logits)[:, self.neg_id] # No
+
+ cur_pos_score = [x.item() for x in pos_score]
+ cur_neg_score = [x.item() for x in neg_score]
+ pos_score_list += cur_pos_score
+ neg_score_list += cur_neg_score
+
+ except RuntimeError:
+ print(f'source: {src_list}')
+ print(f'target: {tgt_list}')
+                exit(1)
+
+ score_list = []
+ for i in range(len(pos_score_list)):
+ score_list.append(pos_score_list[i] / (pos_score_list[i] + neg_score_list[i]))
+
+ return score_list
diff --git a/applications/Chat/evaluate/unieval/utils.py b/applications/Chat/evaluate/unieval/utils.py
new file mode 100644
index 000000000000..a77505faa0d2
--- /dev/null
+++ b/applications/Chat/evaluate/unieval/utils.py
@@ -0,0 +1,248 @@
+# MIT License
+
+# Copyright (c) 2022 Ming Zhong
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import os
+from typing import Dict
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+import tqdm
+
+
+def add_question(dimension, output, src=None, ref=None, context=None, task=None):
+ """
+ Add questions to generate input in Bool-QA format for UniEval.
+
+ dimension: specific dimension to be evaluated
+ src: source input for different NLG tasks. For example, source document for summarization
+ and dialogue history for dialogue response generation.
+ output: output text generated by the models
+    ref: human-annotated groundtruth
+    context: the context needed to evaluate several specific dimensions. For example,
+ additional factual information when evaluating engagingness and groundedness in dialogues.
+ """
+
+ input_with_question = []
+ for i in range(len(output)):
+ # For summarization
+ if task == 'summarization':
+ if dimension == 'fluency':
+ cur_input = 'question: Is this a fluent paragraph? paragraph: ' + output[i]
+ elif dimension == 'coherence':
+ cur_input = 'question: Is this a coherent summary to the document? summary: ' + output[
+ i] + ' document: ' + src[i]
+ elif dimension == 'consistency':
+ cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
+ i] + ' document: ' + src[i]
+ elif dimension == 'relevance':
+ cur_input = 'question: Is this summary relevant to the reference? summary: ' + output[
+ i] + ' reference: ' + ref[i]
+ else:
+ raise NotImplementedError(
+ 'The input format for this dimension is still undefined. Please customize it first.')
+ # For dialogues
+ elif task == 'dialogue':
+ if dimension == 'naturalness':
+ cur_input = 'question: Is this a natural response in the dialogue? response: ' + output[i]
+ elif dimension == 'coherence':
+ cur_input = 'question: Is this a coherent response given the dialogue history? response: '\
+ + output[i] + ' dialogue history: ' + src[i]
+ elif dimension == 'engagingness':
+ cur_input = 'question: Is this an engaging and informative response according to the dialogue history and fact? response: '\
+ + output[i] + ' dialogue history: ' + src[i] + ' fact: ' + context[i]
+ elif dimension == 'groundedness':
+ cur_input = 'question: Is this response consistent with knowledge in the fact? response: '\
+ + output[i] + ' fact: ' + context[i]
+ elif dimension == 'understandability':
+ cur_input = 'question: Is this an understandable response in the dialogue? response: ' + output[i]
+ else:
+ raise NotImplementedError(
+ 'The input format for this dimension is still undefined. Please customize it first.')
+ # For data-to-text
+ elif task == 'data2text':
+ if dimension == 'naturalness':
+ cur_input = 'question: Is this a fluent utterance? utterance: ' + output[i]
+ elif dimension == 'informativeness':
+ cur_input = 'question: Is this sentence informative according to the reference? sentence: '\
+ + output[i] + ' reference: ' + ref[i]
+ else:
+ raise NotImplementedError(
+ 'The input format for this dimension is still undefined. Please customize it first.')
+ # For factual consistency detection
+ elif task == 'fact':
+ if dimension == 'consistency':
+ cur_input = 'question: Is this claim consistent with the document? claim: ' + output[
+ i] + ' document: ' + src[i]
+ else:
+ raise NotImplementedError('No other dimensions for the factual consistency detection task.')
+ # For new customized tasks
+ else:
+ raise NotImplementedError('Other tasks are not implemented, please customize specific tasks here.')
+ input_with_question.append(cur_input)
+ return input_with_question
+
+
+def convert_data_to_unieval_format(output_list, src_list=None, ref_list=None):
+ """
+ Convert the data into the unieval's format.
+
+    output_list: a list of model outputs
+
+ src_list: source input for different NLG tasks. For example, source document for summarization
+ and dialogue history for dialogue response generation
+ ref_list: human-annotated groundtruth
+ """
+ json_data = []
+ for i in range(len(output_list)):
+ cur = {}
+ cur['system_output'] = output_list[i]
+ if src_list is not None:
+ cur['source'] = src_list[i]
+ if ref_list is not None:
+ cur['reference'] = ref_list[i]
+ cur['context'] = ""
+ json_data.append(cur)
+ return json_data
+
+
+def calculate_average_score(scores):
+ """
+ Calculate average scores for different metrics
+
+ scores: a list of scores for different metrics for each answer
+
+ """
+ metrics = {metric: 0 for metric in scores[0]}
+
+ for score in scores:
+ for metric in score:
+ metrics[metric] += score[metric]
+
+ for metric in metrics:
+ metrics[metric] /= len(scores)
+
+ return metrics
+
+
+def save_unieval_results(model_name: str, unieval_metric_stats: Dict[str, Dict], save_path: str) -> None:
+ """
+ Save UniEval evaluation results of different categories for one model.
+
+ """
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ unieval_metric_stats_per_category = {}
+ for task, category_stat in unieval_metric_stats.items():
+ for category, metric_stat in category_stat.items():
+ if unieval_metric_stats_per_category.get(category, None) is None:
+ unieval_metric_stats_per_category[category] = {}
+ for metric, score in metric_stat.items():
+ unieval_metric_stats_per_category[category][f"{metric}-{task}"] = score
+
+ automatic_df = pd.DataFrame(unieval_metric_stats_per_category)
+ automatic_df.to_csv(os.path.join(save_path, f"{model_name}_results.csv"), index=True)
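+
+# For example (illustrative): passing {'summarization': {'chat': {'coherence': 0.9}}}
+# produces a CSV whose row index is 'coherence-summarization' and whose column
+# is 'chat', i.e. one metric-task score per evaluation category for the model.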
+
+
+def read_unieval_results(results_path: str, file_name: str) -> Dict[str, Dict]:
+ """
+ Read a CSV file and return a dictionary which stores the scores per metric.
+
+ """
+
+ results = pd.read_csv(os.path.join(results_path, file_name), index_col=0)
+
+ results_dict = {metric: {} for metric in list(results.index)}
+ for i, metric in enumerate(results_dict.keys()):
+ for j, category in enumerate(list(results.columns)):
+ if pd.isnull(results.iloc[i, j]):
+ continue
+ results_dict[metric][category] = results.iloc[i, j]
+
+ return results_dict
+
+
+def analyze_unieval_results(results_path: str, save_path: str) -> None:
+ """
+ Analyze and visualize all CSV files in the given folder.
+
+ """
+
+ if not os.path.exists(results_path):
+ raise Exception(f'The given directory "{results_path}" doesn\'t exist! No results found!')
+
+ all_statistics = {}
+
+ for file_name in os.listdir(results_path):
+ if file_name.endswith("_results.csv"):
+ model_name = file_name.split("_results.csv")[0]
+ all_statistics[model_name] = read_unieval_results(results_path, file_name)
+
+ if len(all_statistics) == 0:
+ raise Exception(f'There are no CSV files in the given directory "{results_path}"!')
+
+ frame_all = {"model": [], "category": [], "metric": [], "score": []}
+ frame_per_metric = {}
+ for model_name, model_statistics in all_statistics.items():
+ for metric, metric_statistics in model_statistics.items():
+ if frame_per_metric.get(metric) is None:
+ frame_per_metric[metric] = {"model": [], "category": [], "score": []}
+
+ for category, category_score in metric_statistics.items():
+ frame_all["model"].append(model_name)
+ frame_all["category"].append(category)
+ frame_all["metric"].append(metric)
+ frame_all["score"].append(category_score)
+
+ frame_per_metric[metric]["model"].append(model_name)
+ frame_per_metric[metric]["category"].append(category)
+ frame_per_metric[metric]["score"].append(category_score)
+
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+ frame_all = pd.DataFrame(frame_all)
+ frame_all.to_csv(os.path.join(save_path, "unieval_statistics.csv"))
+
+ for metric in tqdm.tqdm(
+ frame_per_metric.keys(),
+ desc=f"UniEval metrics: ",
+ total=len(frame_per_metric.keys()),
+ ):
+ data = pd.DataFrame(frame_per_metric[metric])
+
+ sns.set()
+ plt.figure(figsize=(16, 10))
+
+ ax = sns.barplot(x="category", y="score", hue="model", data=data, dodge=True)
+ ax.set_title(
+ f"Comparison between Different Models for Metric {metric.split('-')[0].title()} in Task {metric.split('-')[1].title()}"
+ )
+ plt.xlabel("Evaluation Category")
+ plt.ylabel("Score")
+
+ figure = ax.get_figure()
+ figure.savefig(os.path.join(save_path, f"{metric}.png"), dpi=400)
+
+ plt.close()
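+
+
+# Example driver (a sketch; the directory names are assumptions): aggregate every
+# "<model>_results.csv" under results/ into unieval_statistics.csv and one bar
+# plot per metric-task pair under figures/.
+# analyze_unieval_results("results", "figures")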
diff --git a/applications/Chat/evaluate/utils.py b/applications/Chat/evaluate/utils.py
index 1f4069386fcd..fefe25f5e764 100644
--- a/applications/Chat/evaluate/utils.py
+++ b/applications/Chat/evaluate/utils.py
@@ -199,7 +199,7 @@ def analyze_automatic_results(results_path: str, save_path: str) -> None:
for metric in tqdm.tqdm(
frame_per_metric.keys(),
- desc=f"metric: ",
+ desc=f"automatic metrics: ",
total=len(frame_per_metric.keys()),
):
data = pd.DataFrame(frame_per_metric[metric])
diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 46714fe1c679..4a7efc165cbd 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -99,8 +99,11 @@ def save_sharded_model(self,
save_state_dict(shard, checkpoint_file_path, use_safetensors)
index_file.append_meta_data("total_size", total_size)
- index_file.write_index_file(save_index_file)
- logging.info(f"The model is going to be split to checkpoint shards. "
+
+ # only save the index file on the master rank
+ if self.coordinator.is_master():
+ index_file.write_index_file(save_index_file)
+ logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
diff --git a/colossalai/checkpoint_io/index_file.py b/colossalai/checkpoint_io/index_file.py
index 334ecbc04738..a41cc482e054 100644
--- a/colossalai/checkpoint_io/index_file.py
+++ b/colossalai/checkpoint_io/index_file.py
@@ -1,8 +1,8 @@
import json
-from pathlib import Path
-from typing import Any, List, Union
import os
-import json
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Dict, List, Union
from .utils import is_dtensor_checkpoint
@@ -22,8 +22,10 @@ class CheckpointIndexFile:
def __init__(self, root_path=None) -> None:
self.root_path = root_path
- self.metadata: dict = dict()
- self.weight_map: dict = dict()
+
+ # use ordered dict to preserve the tensor checkpoint order
+ self.metadata: Dict = OrderedDict()
+ self.weight_map: Dict = OrderedDict()
@staticmethod
def from_file(index_path: Union[str, Path]):
@@ -150,13 +152,13 @@ def get_checkpoint_file(self, param_name: str) -> str:
"""
ckpt_path = self.weight_map[param_name]
return ckpt_path
-
+
def get_all_param_names(self):
"""
Get all the weight keys.
"""
return list(self.weight_map.keys())
-
+
def write_index_file(self, save_index_file):
"""
Write index file.
@@ -164,5 +166,5 @@ def write_index_file(self, save_index_file):
save_index_file = os.path.join(self.root_path, save_index_file)
index = {"metadata": self.metadata, "weight_map": self.weight_map}
with open(save_index_file, "w", encoding="utf-8") as f:
- content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ content = json.dumps(index, indent=2) + "\n"
f.write(content)
diff --git a/colossalai/device/README.md b/colossalai/device/README.md
deleted file mode 100644
index 8f835735bef4..000000000000
--- a/colossalai/device/README.md
+++ /dev/null
@@ -1,73 +0,0 @@
-# 🗄 Device
-
-## 📚 Table of Contents
-
-- [🗄 Device](#-device)
- - [📚 Table of Contents](#-table-of-contents)
- - [🔗 Introduction](#-introduction)
- - [📝 Design](#-design)
- - [🔨 Usage](#-usage)
-
-## 🔗 Introduction
-
-This module contains the implementation of the abstraction of the device topology. It is used to represent the device topology and manage the distributed information related to the network.
-
-## 📝 Design
-
-
-This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa) and the device array can be represented as a 1D or 2D mesh. We will be extending the device mesh to support 3D mesh in the future.
-
-
-## 🔨 Usage
-
-- Create a device mesh
-
-```python
-# this is the list of global ranks involved in the device mesh
-# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
-physical_mesh_id = torch.arange(4)
-mesh_shape = [2, 2]
-device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-```
-
-- View the mesh
-
-
-```python
-# view the mesh shape
-# expect output
-# [2, 2]
-print(device_mesh.shape)
-
-
-# view the logical mesh with global ranks
-# expect output
-# [
-# [0, 1],
-# [2, 3]
-# ]
-print(device_mesh.logical_mesh_id)
-
-# view the number of devices in the mesh
-# expect output
-# 4
-print(device_mesh.num_devices)
-
-```
-
-- Initialize the process group
-
-```python
-# intialize process group
-device_mesh.init_logical_process_group()
-
-
-# get the process group for a rank with respect to an axis
-# this is the process group involving global ranks 0 and 2
-print(device_mesh.get_process_group(axis=0, global_rank=0))
-
-# get the ranks in the process with respect to an axis
-# expect output
-# [0, 2]
-print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
-```
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 0490a440153e..2a5f747fbc23 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -3,19 +3,11 @@
with some changes. """
import operator
-from dataclasses import dataclass
from functools import reduce
-from typing import Dict, List, Union
+from typing import List, Tuple
import torch
import torch.distributed as dist
-from torch.distributed import ProcessGroup
-
-
-@dataclass
-class ProcessGroupContainer:
- process_group: ProcessGroup
- ranks: List[int]
# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
@@ -35,11 +27,9 @@ class DeviceMesh:
during the initialization of the DeviceMesh instance if init_process_group is set to True.
Otherwise, users need to call create_process_groups_for_logical_mesh manually to initialize the logical process groups.
(default: False)
- device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
+ need_flatten (bool, optional): initialize flatten_device_mesh during the initialization of the DeviceMesh instance if need_flatten is set to True.
"""
- _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
-
def __init__(self,
physical_mesh_id: torch.Tensor,
mesh_shape: torch.Size = None,
@@ -47,140 +37,48 @@ def __init__(self,
mesh_alpha: List[float] = None,
mesh_beta: List[float] = None,
init_process_group: bool = False,
- device: str = 'cuda'):
- # ============================
- # Physical & Logical Mesh IDs
- # ============================
- self._physical_mesh_id = physical_mesh_id
- assert physical_mesh_id.dim() == 1, "physical_mesh_id should be a 1D tensor."
-
- # logical mesh ids can be obtained via two ways
- # 1. provide physical mesh id and provide mesh shape
- # 2. directly supply the logical mesh id
- assert mesh_shape is None or logical_mesh_id is None, \
- "Only one of mesh_shape and logical_mesh_id can be specified." \
- "Logical mesh IDs are obtained from either mesh_shape + phyiscal_mesh_id or directly from the user-supplied logical_mesh_id"
-
+ need_flatten: bool = True):
+ self.physical_mesh_id = physical_mesh_id
if logical_mesh_id is None:
self.mesh_shape = mesh_shape
- self._logical_mesh_id = self._physical_mesh_id.reshape(self.mesh_shape)
+ self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
else:
self._logical_mesh_id = logical_mesh_id
self.mesh_shape = self._logical_mesh_id.shape
- # ensure two things:
- # 1. logical and physical mesh IDs should contain the same elements
- # 2. there is no duplicate IDs in each mesh, e.g. [2, 2] is not allowed
- assert torch.equal(torch.unique(self._physical_mesh_id), torch.unique(self.logical_mesh_id)), \
- "physical and logical mesh IDs should contain the same elements, please check if you have consistent physical_mesh_id and logical_mesh_id."
- assert torch.unique(self._physical_mesh_id).numel() == self._physical_mesh_id.numel(), \
- "Found duplicate IDs in the phyiscal_mesh_id and this is not allowed, please check your physical_mesh_id again."
- assert torch.unique(self.logical_mesh_id).numel() == self.logical_mesh_id.numel(), \
- "Found duplicate IDs in the logical_mesh_id and this is not allowed, please check your logical_mesh_id again."
-
- # ===============================================
+ # map global rank into logical rank
+ self.convert_map = {}
+ self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
# coefficient for alpha-beta communication model
- # alpha is latency and beta is bandwidth
- # ===============================================
- # if the values are not provided, we assume they are 1 for simplicity
if mesh_alpha is None:
mesh_alpha = [1] * len(self.mesh_shape)
if mesh_beta is None:
mesh_beta = [1] * len(self.mesh_shape)
-
self.mesh_alpha = tuple(mesh_alpha)
self.mesh_beta = tuple(mesh_beta)
-
- # ensure the alpha and beta have the same shape
- assert len(self.mesh_alpha) == len(self.mesh_beta), \
- "mesh_alpha and mesh_beta should have the same length, please check your mesh_alpha and mesh_beta again."
-
- # =========================
- # Device for Process Group
- # =========================
- self._device = device
- self._dist_backend = self._DIST_BACKEND[device]
-
- # =========================
- # Process Group Management
- # =========================
- # the _global_to_local_rank_mapping is structured as follows
- # {
- # : [ , , , ...]
- # }
- self._global_to_local_rank_mapping = dict()
- self._init_global_to_logical_rank_mapping(mapping=self._global_to_local_rank_mapping,
- tensor=self.logical_mesh_id)
-
- # create process group
- self._process_group_dict = {}
- self._ranks_in_the_process_group = {}
- self._global_rank_of_current_process = None
- self._is_initialized = False
-
- # initialize process group if specified
- self._init_ranks_in_the_same_group()
- self._init_process_group = init_process_group
- if init_process_group:
- self.init_logical_process_group()
+ self.init_process_group = init_process_group
+ self.need_flatten = need_flatten
+ if self.init_process_group:
+ self.process_groups_dict = self.create_process_groups_for_logical_mesh()
+ if self.need_flatten and self._logical_mesh_id.dim() > 1:
+ self.flatten_device_mesh = self.flatten()
+ # Create a new member `flatten_device_meshes` to distinguish it from the original
+ # flatten() method, since some functions may still rely on self.flatten().
+ # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+ # self.mesh_beta)
@property
- def shape(self) -> torch.Size:
- """
- Return the shape of the logical mesh.
- """
+ def shape(self):
return self.mesh_shape
@property
- def num_devices(self) -> int:
- """
- Return the number of devices contained in the device mesh.
- """
- return reduce(operator.mul, self._physical_mesh_id.shape, 1)
+ def num_devices(self):
+ return reduce(operator.mul, self.physical_mesh_id.shape, 1)
@property
- def logical_mesh_id(self) -> torch.Tensor:
- """
- Return the logical mesh id.
- """
+ def logical_mesh_id(self):
return self._logical_mesh_id
- def get_process_group(self, axis: int, global_rank: int = None) -> ProcessGroup:
- """
- Return the process group on the specified axis.
-
- Args:
- axis (int): the axis of the process group.
- global_rank (int, optional): the global rank of the process group. If not specified, the current process is used. (default: None)
- """
- if global_rank is None:
- global_rank = self._global_rank_of_current_process
- return self._process_group_dict[global_rank][axis]
-
- def get_process_group_for_all_axes(self, global_rank: int = None) -> Dict[int, ProcessGroup]:
- """
- Return the process groups for all axes.
-
- Args:
- global_rank (int, optional): the global rank of the process
- """
- if global_rank is None:
- global_rank = self._global_rank_of_current_process
- return self._process_group_dict[global_rank]
-
- def get_ranks_in_process_group(self, axis: int, global_rank: int = None) -> List[int]:
- """
- Return the ranks in the process group on the specified axis.
-
- Args:
- axis (int): the axis of the process group.
- global_rank (int, optional): the global rank of the process
- """
- if global_rank is None:
- global_rank = self._global_rank_of_current_process
- return self._ranks_in_the_process_group[global_rank][axis]
-
- def __deepcopy__(self, memo) -> "DeviceMesh":
+ def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
@@ -188,206 +86,111 @@ def __deepcopy__(self, memo) -> "DeviceMesh":
if k != 'process_groups_dict':
setattr(result, k, __import__("copy").deepcopy(v, memo))
else:
- # process group cannot be copied
- # thus, we share them directly
setattr(result, k, v)
+
return result
- def _init_global_to_logical_rank_mapping(self,
- mapping: Dict,
- tensor: torch.Tensor,
- index_list: List[int] = []) -> Dict[int, List[int]]:
+ def flatten(self):
"""
- Build a global rank to local rank mapping for each process group in different axis in the logical device mesh.
-
- Args:
- mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
- tensor (torch.Tensor): the tensor that contains the logical mesh ids.
- index_list (List[int])
-
- Returns:
- mapping (Dict): a dictionary that maps the global rank to the local rank in the logical device mesh.
- The value is a list of integers and each integer represents the local rank in the indexed axis.
+ Flatten the logical mesh into an effective 1d logical mesh.
"""
- for index, inner_tensor in enumerate(tensor):
- # index means the local rank in the current axis
- # inner_tensor refers to the processes with the same local rank
+ flatten_mesh_shape_size = len(self.mesh_shape)
+ flatten_mesh_shape = [self.num_devices]
+ return DeviceMesh(self.physical_mesh_id,
+ tuple(flatten_mesh_shape),
+ mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
+ mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+ init_process_group=self.init_process_group,
+ need_flatten=False)
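+ # For example (illustrative): a (2, 2) mesh over ranks [0, 1, 2, 3] flattens
+ # into a 1d mesh of shape (4,) that keeps the alpha-beta communication model,
+ # taking the max of mesh_alpha and mesh_beta over the original axes.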
+ def _global_rank_to_logical_rank_map(self, tensor, index_list):
+ '''
+ This method is a helper function to build convert_map recursively.
+ '''
+ for index, inner_tensor in enumerate(tensor):
if inner_tensor.numel() == 1:
- # if the inner_tensor only has one element, it means that
- # it already reaches the last axis
- # we append its local_rank in the last axis to the index_list
- # and assign to the mapping
- # the value of the mapping is the the local rank at the indexed axis of the device mesh
- mapping[int(inner_tensor)] = index_list + [index]
+ self.convert_map[int(inner_tensor)] = index_list + [index]
else:
- # we recursively go into the function until we reach the last axis
- # meanwhile, we should add the local rank in the current axis in the index_list
- self._init_global_to_logical_rank_mapping(mapping, inner_tensor, index_list + [index])
+ self._global_rank_to_logical_rank_map(inner_tensor, index_list + [index])
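+ # For example (illustrative): with a logical mesh [[0, 1], [2, 3]], the recursion
+ # fills convert_map as {0: [0, 0], 1: [0, 1], 2: [1, 0], 3: [1, 1]}, mapping each
+ # global rank to its coordinates in the logical mesh.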
- def init_logical_process_group(self):
+ def create_process_groups_for_logical_mesh(self):
'''
This method is used to initialize the logical process groups which will be used for communication
across the logical device mesh.
Note: if init_process_group is set to False, you have to call this method manually. Otherwise,
communication-related functions, such as ShapeConsistencyManager.apply, will raise errors.
'''
- # sanity check
- assert dist.is_initialized, "The torch.distributed should be initialized before calling init_logical_process_group"
- assert not self._is_initialized, "The logical process group has been initialized, do not call init_logical_process_group twice"
-
- # update the global rank of the current process
- self._global_rank_of_current_process = dist.get_rank()
- duplicate_check_list = []
-
- # flatten the global ranks to 1D list
- global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
-
- for global_rank in global_rank_flatten_list:
- # find the other ranks which are in the same process group as global_rank
- ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
-
- for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
- # skip duplicated process group creation
- if ranks_in_same_group in duplicate_check_list:
- continue
-
- # create the process group
- pg_handler = dist.new_group(ranks=ranks_in_same_group, backend=self._dist_backend)
-
- # keep this process group in the process_groups_dict
- for rank in ranks_in_same_group:
- if rank not in self._process_group_dict:
- self._process_group_dict[rank] = dict()
- self._process_group_dict[rank][axis] = pg_handler
-
- # update the init flag
- # we only allow init for once
- self._is_initialized = True
-
- def _init_ranks_in_the_same_group(self):
- """
- This method is used to initialize the ranks_in_the_same_group dictionary.
- """
- # flatten the global ranks to 1D list
- global_rank_flatten_list = self._physical_mesh_id.view(-1).tolist()
-
+ process_groups_dict = {}
+ check_duplicate_list = []
+ global_rank_flatten_list = self.physical_mesh_id.view(-1).tolist()
for global_rank in global_rank_flatten_list:
- # find the other ranks which are in the same process group as global_rank
- ranks_in_same_group_by_axis = self._collate_global_ranks_in_same_process_group(global_rank)
-
- for axis, ranks_in_same_group in ranks_in_same_group_by_axis.items():
- # create dict for each rank
- if global_rank not in self._process_group_dict:
- self._ranks_in_the_process_group[global_rank] = dict()
+ process_groups = self.global_rank_to_process_groups_with_global_rank(global_rank)
+ for axis, process_group in process_groups.items():
+ if axis not in process_groups_dict:
+ process_groups_dict[axis] = []
+ if process_group not in check_duplicate_list:
+ check_duplicate_list.append(process_group)
+ process_group_handler = dist.new_group(process_group)
+ process_groups_dict[axis].append((process_group, process_group_handler))
- # keep this process group in the process_groups_dict
- self._ranks_in_the_process_group[global_rank][axis] = ranks_in_same_group
+ return process_groups_dict
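+ # For example (illustrative): on a (2, 2) mesh over ranks [0, 1, 2, 3], the result
+ # maps axis 0 to [([0, 2], pg), ([1, 3], pg)] and axis 1 to [([0, 1], pg),
+ # ([2, 3], pg)], where each pg is the dist.ProcessGroup built for that rank list.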
- def global_rank_to_local_rank(self, rank: int, axis: int = None) -> Union[List[int], int]:
- """
- Return the local rank of the given global rank in the logical device mesh.
+ def global_rank_to_logical_rank(self, rank):
+ return self.convert_map[rank]
- Args:
- rank (int): the global rank in the logical device mesh.
- axis (int): the axis of the logical device mesh.
- """
- local_ranks = self._global_to_local_rank_mapping[rank]
- if axis:
- return local_ranks[axis]
- else:
- return local_ranks
-
- def _collate_global_ranks_in_same_process_group(self, global_rank):
+ def global_rank_to_process_groups_with_logical_rank(self, rank):
'''
- Give a global rank and return all global ranks involved in its associated process group in each axis.
-
- Example:
-
- ```python
- sphysical_mesh_id = torch.arange(0, 16)
- mesh_shape = (4, 4)
-
- # logical mesh will look like
- # [[0, 1, 2, 3],
- # [4, 5, 6, 7],
- # [8, 9, 10,11],
- # [12,13,14,15]]
-
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- print(device_mesh.collate_global_ranks_in_same_process_group(0))
-
- # key is axis name
- # value is a list of global ranks in same axis with rank 0
- # output will look like
- # {
- 0: [0, 4, 8, 12],
- 1: [0, 1, 2, 3]
- # }
+ Given a global rank, return all logical process groups of this rank.
+ For example:
+ physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+ mesh_shape = (4, 4)
+ # [[0, 1, 2, 3],
+ # [4, 5, 6, 7],
+ # [8, 9, 10,11],
+ # [12,13,14,15]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+ print(device_mesh.global_rank_to_process_groups_with_logical_rank(0))
+ output:
+ # key is axis name
+ # value is a list of logical ranks in the same axis as rank 0
+ {0: [[0, 0], [1, 0], [2, 0], [3, 0]], 1: [[0, 0], [0, 1], [0, 2], [0, 3]]}
'''
- # We have init the global rank to local rank by calling _init_global_to_logical_rank_mapping
- # for self._global_to_local_rank_mapping
- # the key is the global rank
- # the value is the list of local ranks corresponding to the global rank with respect of different axes
- # we can see the list of local ranks as the process coordinates for simplicity
- # the key and value are all unique, therefore,
- # we can also to use the coordinates to find the global rank
-
- # =========================================================================
- # Step 1
- # find all the process_coordinates for processes in the same process group
- # as the given global rank
- # =========================================================================
-
- # each
- processes_in_the_same_process_group = {}
-
- for dim in range(self.logical_mesh_id.dim()):
- # iterate over the dimension size so that we can include all processes
- # in the same process group in the given axis
- # the _local_rank refers to the local rank of the current process
- for _local_rank in range(self.logical_mesh_id.shape[dim]):
-
- # if this dimension is not initailized yet,
- # initialize it with an empty array
- if dim not in processes_in_the_same_process_group:
- processes_in_the_same_process_group[dim] = []
-
- # get the local rank corresponding to the global rank
- process_coordinates = self._global_to_local_rank_mapping[global_rank].copy()
-
- # replace the local rank in the given dimension with the
- # lcoal rank of the current process iterated
- process_coordinates[dim] = _local_rank
- processes_in_the_same_process_group[dim].append(process_coordinates)
-
- # =================================================================
- # Step 2
- # Use local rank combination to find its corresponding global rank
- # =================================================================
- # the key of the dict is the axis
- # the value is the list of global ranks which are in the same process group as the given global rank
- global_pg_ranks = {}
- for dim, coordinates_of_all_processes in processes_in_the_same_process_group.items():
- global_pg_ranks[dim] = []
- for process_coordinates in coordinates_of_all_processes:
- # find the global rank by local rank combination
- for _global_rank, _process_coordinates in self._global_to_local_rank_mapping.items():
- if process_coordinates == _process_coordinates:
- global_pg_ranks[dim].append(_global_rank)
- return global_pg_ranks
-
- def flatten(self):
- """
- Flatten the logical mesh into an effective 1d logical mesh,
- """
- flatten_mesh_shape_size = len(self.mesh_shape)
- flatten_mesh_shape = [self.num_devices]
- return DeviceMesh(self._physical_mesh_id,
- tuple(flatten_mesh_shape),
- mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
- mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
- init_process_group=self._init_process_group)
+ process_groups = {}
+ for d in range(self.logical_mesh_id.dim()):
+ for replacer in range(self.logical_mesh_id.shape[d]):
+ if d not in process_groups:
+ process_groups[d] = []
+ process_group_member = self.convert_map[rank].copy()
+ process_group_member[d] = replacer
+ process_groups[d].append(process_group_member)
+ return process_groups
+
+ def global_rank_to_process_groups_with_global_rank(self, rank):
+ '''
+ Given a global rank, return all process groups of this rank.
+ For example:
+ physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+ mesh_shape = (4, 4)
+ # [[0, 1, 2, 3],
+ # [4, 5, 6, 7],
+ # [8, 9, 10,11],
+ # [12,13,14,15]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+ print(device_mesh.global_rank_to_process_groups_with_global_rank(0))
+ output:
+ # key is axis name
+ # value is a list of global ranks in the same axis as rank 0
+ {0: [0, 4, 8, 12], 1: [0, 1, 2, 3]}
+ '''
+ logical_process_groups = self.global_rank_to_process_groups_with_logical_rank(rank)
+ process_groups = {}
+ for dim, logical_ranks in logical_process_groups.items():
+ process_groups[dim] = []
+ for logical_rank in logical_ranks:
+ for g_rank, l_rank in self.convert_map.items():
+ if l_rank == logical_rank:
+ process_groups[dim].append(g_rank)
+ return process_groups
def all_gather_cost(self, num_bytes, mesh_dim):
num_devices = self.logical_mesh_id.shape[mesh_dim]
@@ -409,3 +212,38 @@ def all_to_all_cost(self, num_bytes, mesh_dim):
penalty_factor = num_devices / 2.0
return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *
(num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001)
+
+
+class FlattenDeviceMesh(DeviceMesh):
+
+ def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None):
+ super().__init__(physical_mesh_id,
+ mesh_shape,
+ mesh_alpha,
+ mesh_beta,
+ init_process_group=False,
+ need_flatten=False)
+ # Different from flatten(), mesh_shape is left unchanged while mesh_alpha and mesh_beta become scalars
+ self.mesh_alpha = max(self.mesh_alpha)
+ self.mesh_beta = min(self.mesh_beta)
+ # Different from the original process_groups_dict, the rank lists are not stored
+ self.process_number_dict = self.create_process_numbers_for_logical_mesh()
+
+ def create_process_numbers_for_logical_mesh(self):
+ '''
+ Build 1d process number lists in column-major (key 0) and row-major (key 1) order.
+ For example:
+ mesh_shape = (2,4)
+ # [[0, 1, 2, 3],
+ # [4, 5, 6, 7]]
+ # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
+ '''
+ num_devices = reduce(operator.mul, self.mesh_shape, 1)
+ process_numbers_dict = {}
+ process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist()
+ process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist()
+ return process_numbers_dict
+
+ def mix_gather_cost(self, num_bytes):
+ num_devices = reduce(operator.mul, self.mesh_shape, 1)
+ return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1)
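+
+
+# Worked example (illustrative): for a flattened 2x4 mesh (8 devices) with
+# mesh_alpha = 1 and mesh_beta = 1, mix-gathering 1024 bytes costs roughly
+# 1 + 1 * (7 / 8) * 1024 + 0.1 = 897.1 in the model's abstract time units.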
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 5d3f3e5530cb..dc0df0517508 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -238,7 +238,7 @@ def initialize(model: nn.Module,
loaded into gpc.config.
Args:
- model (:class:`torch.nn.Module` or Callbale): Your model instance or a function to build the model.
+ model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
Your optimizer instance.
criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py
index ca8914362cd6..76f550dc4392 100644
--- a/colossalai/lazy/lazy_init.py
+++ b/colossalai/lazy/lazy_init.py
@@ -1,5 +1,5 @@
from types import MethodType
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@@ -8,9 +8,8 @@
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
-from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor.layout import Layout
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@@ -173,7 +172,7 @@ def materialize(self) -> torch.Tensor:
self.clean()
return _convert_cls(self, target)
- def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
+ def distribute(self, layout: Layout) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
@@ -184,7 +183,7 @@ def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> to
"""
target = self._materialize_data()
self.clean()
- local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
+ local_tensor = DTensor(target, layout).local_tensor
return _convert_cls(self, local_tensor)
def clean(self) -> None:
@@ -537,10 +536,7 @@ def apply_fn(name: str, p: LazyTensor):
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
- def distribute(module: nn.Module,
- device_mesh: DeviceMesh,
- sharding_spec_dict: Dict[str, ShardingSpec],
- verbose: bool = False) -> nn.Module:
+ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
@@ -550,7 +546,7 @@ def distribute(module: nn.Module,
"""
def apply_fn(name: str, p: LazyTensor):
- p.distribute(device_mesh, sharding_spec_dict[name])
+ p.distribute(layout_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)
diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index dd873c852936..af38d2a502c2 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -16,66 +16,69 @@ def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
- process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
- process_group = process_groups[comm_spec.logical_process_axis]
-
- tensor_list = [
- torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
- for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
- ]
- # without this contiguous operation, the all gather may get some unexpected results.
- tensor = tensor.contiguous()
- dist.all_gather(tensor_list, tensor, group=process_group)
- output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
- return output
+ process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ tensor_list = [
+ torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+ for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
+ ]
+ # without this contiguous operation, the all gather may get some unexpected results.
+ tensor = tensor.contiguous()
+ dist.all_gather(tensor_list, tensor, group=process_group)
+ output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+ return output
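+# Example (illustrative): with _all_gather above, a (2, 8) local shard gathered
+# along gather_dim=0 over a 4-rank mesh axis becomes an (8, 8) tensor on every
+# rank in that process group.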
def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
- process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
- process_group = process_groups[comm_spec.logical_process_axis]
-
- dim = comm_spec.shard_dim
- length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
- start = length * dist.get_rank(process_group)
- output = torch.narrow(tensor, dim, start, length).contiguous()
- return output
+ process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, _ in process_groups_list:
+ if dist.get_rank() in rank_list:
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ start = length * rank_list.index(dist.get_rank())
+ output = torch.narrow(tensor, dim, start, length).contiguous()
+ return output
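+# Example (illustrative): with _split above, rank_list [0, 1, 2, 3] and shard_dim=0,
+# rank 2 keeps rows [2 * length, 3 * length) of the tensor, where
+# length = tensor.shape[0] // 4.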
def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
- process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
- process_group = process_groups[comm_spec.logical_process_axis]
- world_size = dist.get_world_size(process_group)
-
- new_shape = list(tensor.shape)
- new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
- new_shape = torch.Size(new_shape)
- output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
- dim = comm_spec.shard_dim
- length = tensor.shape[comm_spec.shard_dim] // world_size
- input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
- group = process_group
- dist.all_to_all(output_tensor_list, input_tensor_list, group)
- output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
- return output
+ process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ new_shape = list(tensor.shape)
+ new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
+ new_shape = torch.Size(new_shape)
+ output_tensor_list = [
+ torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+ ]
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ input_tensor_list = [
+ torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
+ ]
+ group = process_group
+ dist.all_to_all(output_tensor_list, input_tensor_list, group)
+ output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+ return output
def _all_reduce(tensor, comm_spec, async_op=False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
- process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
- process_group = process_groups[comm_spec.logical_process_axis]
-
- if not tensor.is_contiguous():
- tensor = tensor.contiguous()
- dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
- return tensor
+ process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ if not tensor.is_contiguous():
+ tensor = tensor.contiguous()
+ dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+ return tensor
def _mix_gather(tensor, comm_spec):
@@ -411,7 +414,7 @@ def __init__(self,
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
if not mix_gather:
- self.device_mesh = self.sharding_spec.device_mesh.flatten()
+ self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
diff --git a/colossalai/tensor/d_tensor/RAEDME.md b/colossalai/tensor/d_tensor/RAEDME.md
deleted file mode 100644
index 95d866388364..000000000000
--- a/colossalai/tensor/d_tensor/RAEDME.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# 🔢 Distributed Tensor
-
-## 📚 Table of Contents
-
-- [🔢 Distributed Tensor](#-distributed-tensor)
- - [📚 Table of Contents](#-table-of-contents)
- - [🔗 Introduction](#-introduction)
- - [📝 Design](#-design)
- - [🔨 Usage](#-usage)
- - [🎈 Progress Log](#-progress-log)
-
-## 🔗 Introduction
-
-Distributed tensor is a type of tensor that is distributed across multiple devices. It is a wrapper of PyTorch tensor, and it is used to support distributed training.
-It can represent the device topology and tensor placement over the devices in the topology. It also provides a set of APIs to manipulate the distributed tensor.
-
-## 📝 Design
-
-Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses notations `S` to represent the sharded dimension and `R` to represent the replicated dimension. For example, given a 2D matrix, `[S, R]` represents the tensor is sharded over the first dimension.
-
-Each sharded dimension will have a subscript to represent its placement over the devices. Assuming we have 4 GPUs and the GPUs are arranged in a 2 x 2 manner. Let's say we have a 2D matrix like below:
-
-
-```text
- [1, 2, 3, 4 ]
-A = [4, 5, 6, 7 ]
- [8, 9, 10, 11]
- [12, 13, 14, 15]
-```
-
-`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology.
-
-```text
-| --------------------—————————————————————-|
-| | |
-| [1, 2, 3, 4 ] | [1, 2, 3, 4 ] |
-| [4, 5, 6, 7 ] | [4, 5, 6, 7 ] |
-| | |
-| --------------------——————————————————-----
-| | |
-| [8, 9, 10, 11] | [8, 9, 10, 11] |
-| [12, 13, 14, 15] | [12, 13, 14, 15] |
-| | |
-| --------------------——————————————————-----
-```
-
-`[S01, R]` would mean that the first dimension is sharded over both the row and column in the device topology.
-
-```text
-| --------------------—————————————————————-|
-| | |
-| [1, 2, 3, 4 ] | [4, 5, 6, 7 ] |
-| | |
-| --------------------——————————————————-----
-| | |
-| [8, 9, 10, 11] | [12, 13, 14, 15] |
-| | |
-| --------------------——————————————————-----
-```
-
-## 🔨 Usage
-
-A sample API usage is given below.
-
-```python
-import torch
-
-import colossalai
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.d_tensor import DTensor, ShardingSpec
-
-colossalai.launch_from_torch(config={})
-
-# define your device mesh
-# assume you have 4 GPUs
-physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
-mesh_shape = (2, 2)
-device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
-# define a tensor
-a = torch.rand(16, 32).cuda()
-
-# create sharding spec for the tensor
-# assume the sharding spec is [S0, R]
-dim_partition_dict = {0: [0]}
-sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)
-
-# create a distributed tensor
-d_tensor = DTensor(a, device_mesh, sharding_spec)
-print(d_tensor)
-
-global_tensor = d_tensor.to_global()
-print(global_tensor)
-```
-
-
-## 🎈 Progress Log
-
-- [x] Support layout conversion
-- [x] Support sharding on 2D device mesh
-- [ ] Support sharding on 3D device mesh
-- [ ] Support sharding 4D device mesh
-- [ ] Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.)
diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py
index af77f4f0edfc..e69de29bb2d1 100644
--- a/colossalai/tensor/d_tensor/__init__.py
+++ b/colossalai/tensor/d_tensor/__init__.py
@@ -1,4 +0,0 @@
-from .d_tensor import DTensor
-from .sharding_spec import ShardingSpec
-
-__all__ = ['DTensor', 'ShardingSpec']
diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py
index 79b2e3ef936a..159125fa16db 100644
--- a/colossalai/tensor/d_tensor/comm_spec.py
+++ b/colossalai/tensor/d_tensor/comm_spec.py
@@ -24,12 +24,12 @@ class CommSpec:
'''
Communication spec is used to record the communication action. It converts the communication spec
to real action which will be used in runtime. It contains comm_pattern to determine the
- communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
+ communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis to determine the mesh dimension along which the communication is performed.
Argument:
- comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
- process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
+ comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
+ process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The dimension of the tensor along which it will be gathered.
shard_dim(int, Optional): The dimension of the tensor along which it will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@@ -37,7 +37,7 @@ class CommSpec:
def __init__(self,
comm_pattern: CollectiveCommPattern,
- process_group_dict: Dict,
+ process_groups_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
@@ -45,7 +45,7 @@ def __init__(self,
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
- self.process_group_dict = process_group_dict
+ self.process_groups_dict = process_groups_dict
def __repr__(self):
res_list = ["CommSpec:("]
@@ -92,56 +92,68 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
- process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
- world_size = dist.get_world_size(process_group)
- tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
- # without this contiguous operation, the all gather may get some unexpected results.
- tensor = tensor.contiguous()
- dist.all_gather(tensor_list, tensor, group=process_group)
- output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
- return output
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ tensor_list = [
+ torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+ ]
+ # without this contiguous operation, the all gather may get some unexpected results.
+ tensor = tensor.contiguous()
+ dist.all_gather(tensor_list, tensor, group=process_group)
+ output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+ return output
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
- process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
- dim = comm_spec.shard_dim
- length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
- start = length * dist.get_rank(process_group)
- output = torch.narrow(tensor, dim, start, length).contiguous()
- return output
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, _ in process_groups_list:
+ if dist.get_rank() in rank_list:
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ start = length * rank_list.index(dist.get_rank())
+ output = torch.narrow(tensor, dim, start, length).contiguous()
+ return output
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
- process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
- world_size = dist.get_world_size(process_group)
- new_shape = list(tensor.shape)
- new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
- new_shape = torch.Size(new_shape)
- output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
- dim = comm_spec.shard_dim
- length = tensor.shape[comm_spec.shard_dim] // world_size
- input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
- group = process_group
- dist.all_to_all(output_tensor_list, input_tensor_list, group)
- output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
- return output
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ new_shape = list(tensor.shape)
+ new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
+ new_shape = torch.Size(new_shape)
+ output_tensor_list = [
+ torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+ ]
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ input_tensor_list = [
+ torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
+ ]
+ group = process_group
+ dist.all_to_all(output_tensor_list, input_tensor_list, group)
+ output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+ return output
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
- process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
- if not tensor.is_contiguous():
- tensor = tensor.contiguous()
- dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
- return tensor
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ if not tensor.is_contiguous():
+ tensor = tensor.contiguous()
+ dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+ return tensor
class _ReduceGrad(torch.autograd.Function):
@@ -257,7 +269,7 @@ def symbolic(graph, input_):
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
- process_group_dict=comm_spec.process_group_dict,
+ process_groups_dict=comm_spec.process_groups_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py
index 6bda0f4e579c..c1fe9d50a048 100644
--- a/colossalai/tensor/d_tensor/d_tensor.py
+++ b/colossalai/tensor/d_tensor/d_tensor.py
@@ -3,119 +3,55 @@
import torch
from torch.utils._pytree import tree_map
-from colossalai.device.device_mesh import DeviceMesh
-
from .layout import Layout
from .layout_converter import LayoutConverter, to_global
from .sharding_spec import ShardingSpec
-__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec']
-
layout_converter = LayoutConverter()
class DTensor(torch.Tensor):
- """
- DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information
- about the tensor distribution. The meta information includes the device mesh, the sharding specification,
- and the entire shape of the tensor.
-
- During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the
- `DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank.
- In this way, all tensors involved in computation will only be native PyTorch tensors.
-
- Example:
- ```python
- from colossalai.device import DeviceMesh
-
- # define your device mesh
- # assume you have 4 GPUs
- physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-
- # define a tensor
- x = torch.rand(16, 32)
-
- # create sharding spec for the tensor
- # assume the sharding spec is [S, R]
- dim_partition_dict = {
- 0: 1
- }
- sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)
-
- # create a distributed tensor
- d_tensor = DTensor(x, device_mesh, sharding_spec)
- ```
- Args:
- tensor (`torch.Tensor`): the unsharded tensor.
- device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
- sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
- """
-
- def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec):
- # ensure this tensor is not a DTensor
- assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.'
-
- # store meta info
- self.local_tensor = tensor
- self.data_type = tensor.dtype
- self.global_shape = tensor.shape
-
- # create distributed layout
- dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
+ def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
+ self.local_tensor = local_tensor
+ self.data_type = local_tensor.dtype
+ self.entire_shape = local_tensor.shape
self.dist_layout = dist_layout
-
- # shard the tensor
self._apply_layout()
@staticmethod
- def __new__(cls, tensor, *args, **kwargs):
- return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad)
+ def __new__(cls, local_tensor, layout):
+ return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
def __repr__(self):
- return f"DTensor(\n{self.to_global()}\n{self.dist_layout}"
+ return f"DTensor({self.to_global()}, {self.dist_layout})"
def __str__(self):
return self.__repr__()
- def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
+ def layout_convert(self, target_layout):
'''
Convert the layout of the tensor from its current layout to the given target_layout.
- This will update the `local_tensor` and `dist_layout` in place.
-
- Args:
- target_layout (Layout): the target layout specification.
'''
- target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
- self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
- source_layout=self.dist_layout,
- target_layout=target_layout)
+ self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
self.dist_layout = target_layout
def _apply_layout(self):
'''
Apply the layout to the local tensor during initializing process.
'''
- # layout converter requires a source and target laytout
- # we construct the source layer for an unsharded tensor
- # and use self.dist_layer as the targer layout for the sharded tensor
source_spec = construct_default_sharding_spec(self.local_tensor)
source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+ device_type=self.dist_layout.device_type,
sharding_spec=source_spec,
- global_shape=self.global_shape)
- self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
- source_layout=source_layout,
- target_layout=self.dist_layout)
+ entire_shape=self.entire_shape)
+ self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
- # convert all DTensors to native pytorch tensors
- # so that operations will be conducted on native tensors
def filter_arg(arg):
if isinstance(arg, DTensor):
return arg.local_tensor
@@ -124,9 +60,9 @@ def filter_arg(arg):
args = tree_map(filter_arg, args)
kwargs = tree_map(filter_arg, kwargs)
-
- # NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
+ # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
# and op type.
+
return func(*args, **kwargs)
@property
@@ -149,6 +85,7 @@ def to(self, *args, **kwargs):
'''
self.local_tensor = self.local_tensor.to(*args, **kwargs)
self.data_type = self.local_tensor.dtype
+ self.dist_layout.device_type = self.local_tensor.device
# TODO: update the device mesh process groups or we should just cache
# both the cpu process groups and the cuda process groups?
return self
@@ -161,7 +98,7 @@ def to_local(self):
def to_global(self):
'''
- Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object.
+ Recover the global tensor from the distributed tensor.
Note: This function will all_gather the local tensor to the global tensor and it
will not change the layout of the DTensor. This function is mainly used for debugging or
@@ -170,29 +107,24 @@ def to_global(self):
return to_global(self.local_tensor, self.dist_layout)
-def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor:
+def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
'''
Distribute the local tensor to the distributed tensor according to the dist_layout specified.
Args:
- tensor (`torch.Tensor`): tensor to be distributed.
- device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
- sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
+ local_tensor: tensor to be distributed.
+ dist_layout: the layout specification of the distributed tensor.
Returns:
A 'DTensor' object.
'''
- return DTensor(tensor, device_mesh, sharding_spec)
+ return DTensor(local_tensor, dist_layout)
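+
+
+# A minimal usage sketch under the new Layout-based API (mesh shape, sharding
+# spec and device values are illustrative):
+# device_mesh = DeviceMesh(torch.arange(4), (2, 2), init_process_group=True)
+# spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})
+# layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'),
+#                 sharding_spec=spec, entire_shape=torch.Size((16, 32)))
+# d_tensor = distribute_tensor(torch.rand(16, 32).cuda(), layout)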
def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
'''
This function converts all the parameters in the module to DTensor(DParam).
- Args:
- module (`torch.nn.Module`): the module to be distributed.
- partition_fn (callable): the partition function which will be used to partition the parameters.
-
Note: This function is subject to future change as the DParam has not been implemented yet.
'''
for name, param in module.named_parameters():
@@ -206,11 +138,5 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
Construct the default sharding specification for the tensor.
-
- Args:
- tensor (`torch.Tensor`): the tensor to be sharded.
-
- Returns:
- A `ShardingSpec` object without any sharding specified.
'''
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py
index 2946611b4b79..ee7ef74a99ae 100644
--- a/colossalai/tensor/d_tensor/layout.py
+++ b/colossalai/tensor/d_tensor/layout.py
@@ -11,32 +11,28 @@
class Layout:
- """
- Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices.
+ """Layout of a tensor.
- Args:
- device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed.
- sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded.
- global_shape (`torch.Size`): the entire shape of the global tensor.
+ Attributes:
+ device_mesh: the device mesh to store the tensor distributed.
+ device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
+ sharding_spec: the sharding specification to describe how the tensor is sharded.
+ entire_shape: the entire shape of the global tensor.
"""
- def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
+ def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
+ entire_shape: torch.Size):
self.device_mesh = device_mesh
+ self.device_type = device_type
self.sharding_spec = sharding_spec
- self.global_shape = global_shape
+ self.entire_shape = entire_shape
self._sanity_check()
def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
- def get_sharded_shape_per_device(self) -> torch.Size:
- """
- Compute the shape of the sharded tensor on each device.
-
- Returns:
- `torch.Size`: the shape of the sharded tensor on each device.
- """
- sharded_shape = list(self.global_shape)
+ def get_sharded_shape_per_device(self):
+ sharded_shape = list(self.entire_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -60,7 +56,7 @@ def _sanity_check(self):
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
- tensor_dim_size = self.global_shape[dim]
+ tensor_dim_size = self.entire_shape[dim]
num_devices = 1
for element in shard_list:
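`get_sharded_shape_per_device` divides each sharded dimension of `entire_shape` by the product of the sizes of the mesh axes that shard it. The arithmetic in isolation, as a self-contained sketch (mesh shape and sharding dict are illustrative):

```python
from functools import reduce
import operator

entire_shape = [4, 4, 4]
mesh_shape = (2, 2)                      # a 2x2 device mesh
dim_partition_dict = {0: [0], 1: [1]}    # [S0, S1, R]

sharded_shape = list(entire_shape)
for dim, shard_list in dim_partition_dict.items():
    mesh_list = [mesh_shape[axis] for axis in shard_list]
    shard_partitions = reduce(operator.mul, mesh_list, 1)
    assert sharded_shape[dim] % shard_partitions == 0, 'dim not divisible'
    sharded_shape[dim] //= shard_partitions

print(sharded_shape)    # [2, 2, 4]: each device holds a 2x2x4 shard
```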
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
index 6eff92ea6b13..cf02aac309f4 100644
--- a/colossalai/tensor/d_tensor/layout_converter.py
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -3,8 +3,10 @@
from dataclasses import dataclass
from typing import Dict, List, Tuple
+import numpy as np
import torch
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
@@ -26,21 +28,13 @@ class LayoutConverterOptions:
pass
-def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor:
- """
- Convert a distributed tensor to the global tensor with the given layout.
- This function returns a native `torch.Tensor` object.
-
-
- Args:
- distributed_tensor (`DTensor`): the distributed tensor to be converted.
- layout (`Layout`): the target layout specification.
- """
+def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
global_layout = Layout(device_mesh=layout.device_mesh,
+ device_type=layout.device_type,
sharding_spec=global_sharding_spec,
- global_shape=layout.global_shape)
+ entire_shape=layout.entire_shape)
with torch.no_grad():
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
return global_tensor
@@ -55,9 +49,6 @@ def set_layout_converting_options(options: LayoutConverterOptions):
class LayoutConverter(metaclass=SingletonMeta):
- """
- LayoutConverter is a singleton class which converts the layout of a distributed tensor.
- """
def __init__(self):
self._options = None
@@ -100,14 +91,15 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- global_shape = (4, 4, 4)
+ entire_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
- global_shape=global_shape)
+ entire_shape=entire_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
for layout, comm_spec in rst_dict.items():
@@ -120,12 +112,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
-
- # the key of the dict is the axis
- # the value is the process group
- current_rank = source_layout.device_mesh._global_rank_of_current_process
- process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
-
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
for target_pair in source_spec.dim_partition_dict.items():
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]
@@ -143,7 +130,7 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(
comm_pattern,
- process_group_dict=process_group_dict,
+ process_groups_dict=process_groups_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
shard_dim=gather_dim,
@@ -154,7 +141,8 @@ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, Co
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
- global_shape=source_layout.global_shape)
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
@@ -179,14 +167,15 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- global_shape = (4, 4, 4)
+ entire_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
- global_shape=global_shape)
+ entire_shape=entire_shape)
rst_dict = layout_converter.all_to_all_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -199,12 +188,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
-
- # the key of the dict is the axis
- # the value is the process group
- current_rank = source_layout.device_mesh._global_rank_of_current_process
- process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
-
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
source_spec = source_layout.sharding_spec
tensor_dims = source_spec.dims
for f_index in range(tensor_dims - 1):
@@ -245,7 +229,7 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
- process_group_dict=process_group_dict,
+ process_groups_dict=process_groups_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -268,7 +252,8 @@ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, Com
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
- global_shape=source_layout.global_shape)
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -293,15 +278,16 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- global_shape = (4, 4, 4)
+ entire_shape = (4, 4, 4)
dim_partition_dict = {0: [0]}
# [S0,R,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
- global_shape=global_shape)
+ entire_shape=entire_shape)
rst_dict = layout_converter.shard_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -315,11 +301,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
-
- # the key of the dict is the axis
- # the value is the process group
- current_rank = source_layout.device_mesh._global_rank_of_current_process
- process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
@@ -347,7 +329,7 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
- process_group_dict=process_group_dict,
+ process_groups_dict=process_groups_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -358,7 +340,8 @@ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
- global_shape=source_layout.global_shape)
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -416,7 +399,7 @@ def layout_converting(self, source_layout: Layout,
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- global_shape = (4, 4, 4)
+ entire_shape = (4, 4, 4)
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
@@ -424,14 +407,16 @@ def layout_converting(self, source_layout: Layout,
# [R,S01,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
- global_shape=global_shape)
+ entire_shape=entire_shape)
# [S01,R,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
- global_shape=global_shape)
+ entire_shape=entire_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -520,19 +505,21 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- global_shape = (4, 4, 4)
+ entire_shape = (4, 4, 4)
# [S0,R,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
- global_shape=global_shape)
+ entire_shape=entire_shape)
# [R,S0,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
- global_shape=global_shape)
+ entire_shape=entire_shape)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
@@ -567,4 +554,3 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
return tensor
- return tensor
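`apply` is the composition of the two pieces shown above: ask `layout_converting` for a path of intermediate layouts plus the matching communication actions, then fold those actions over the tensor. A condensed sketch of that control flow (names follow this file; treat it as an illustration, not the exact implementation):

```python
def convert_layout(layout_converter, tensor, source_layout, target_layout):
    # plan: intermediate layouts and the comm ops that link them
    transform_path, comm_action_sequence = layout_converter.layout_converting(
        source_layout, target_layout)
    # execute: each CommSpec becomes an all-gather, all-to-all, or split
    for comm_spec in comm_action_sequence:
        tensor = comm_spec.covert_spec_to_action(tensor)    # sic: method name in this codebase
    return tensor
```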
diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py
index 45b05e10e297..565012b58a03 100644
--- a/colossalai/tensor/d_tensor/sharding_spec.py
+++ b/colossalai/tensor/d_tensor/sharding_spec.py
@@ -116,21 +116,21 @@ def build_difference_2d_dict(self):
def dim_diff(self, other):
'''
- The difference between two DimSpec.
+ The difference between two _DimSpec.
Argument:
- other(DimSpec): the dim spec to compare with.
+ other(_DimSpec): the dim spec to compare with.
Return:
difference(int): the difference between two _DimSpec.
Example:
- ```python
- dim_spec = DimSpec([0])
- other_dim_spec = DimSpec([0, 1])
+ dim_spec = _DimSpec([0])
+ other_dim_spec = _DimSpec([0, 1])
print(dim_spec.dim_diff(other_dim_spec))
- # output: 5
- ```
+
+ Output:
+ 5
'''
difference = self.difference_dict[(str(self), str(other))]
return difference
@@ -142,13 +142,9 @@ class ShardingSpec:
[R, R, S0, S1], which means
Argument:
- dim_size (int): The number of dimensions of the tensor to be sharded.
- dim_partition_dict (Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
- and the value of the key describe which logical axis will be sharded in that dimension. Defaults to None.
- E.g. {0: [0, 1]} means the first dimension of the tensor will be sharded in logical axis 0 and 1.
- sharding_sequence (List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
- Generally, users should specify either dim_partition_dict or sharding_sequence.
- If both are given, users must ensure that they are consistent with each other. Defaults to None.
+ dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
+ and the value of the key describe which logical axis will be sharded in that dimension.
+ sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
def __init__(self,
@@ -212,7 +208,6 @@ def spec_diff(self, other):
pair of sharding sequence.
Example:
- ```python
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
@@ -224,8 +219,10 @@ def spec_diff(self, other):
# device_mesh_shape: (4, 4)
sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
print(sharding_spec.spec_diff(sharding_spec_to_compare))
- # output: 25
- ```
+
+ Output:
+ 25
+
Argument:
other(ShardingSpec): The ShardingSpec to compare with.
diff --git a/colossalai/trainer/_trainer.py b/colossalai/trainer/_trainer.py
index 60bbc4eeee32..bfe1c403fd48 100644
--- a/colossalai/trainer/_trainer.py
+++ b/colossalai/trainer/_trainer.py
@@ -31,9 +31,9 @@ class Trainer:
>>> # Initialize your engine, train_dataloader, test_dataloader, lr_scheduler
>>> engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion)
>>> # Beginning training progress
- >>> timier = ...
+ >>> timer = ...
>>> logger = ...
- >>> trainer = Trainer(engine=engine, logger=logger, timer=timier)
+ >>> trainer = Trainer(engine=engine, logger=logger, timer=timer)
>>> # add hooks you would like to use here.
>>> hook_list = []
>>> trainer.fit(
@@ -56,7 +56,7 @@ def __init__(
timer: MultiTimer = None,
logger: DistributedLogger = None,
):
- # training-ralated params
+ # training-related params
self._engine = engine
self._max_epochs = 0
self._cur_epoch = 0
@@ -118,7 +118,7 @@ def _set_current_step(self, epoch: int):
self._cur_step = epoch * self._steps_per_epoch
def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
- """Call timer funciton with a given timer name.
+ """Call timer function with a given timer name.
Args:
action (str): Function to be called on timer.
diff --git a/colossalai/utils/data_sampler/data_parallel_sampler.py b/colossalai/utils/data_sampler/data_parallel_sampler.py
index 945dc54b397a..2318e07a7f8d 100644
--- a/colossalai/utils/data_sampler/data_parallel_sampler.py
+++ b/colossalai/utils/data_sampler/data_parallel_sampler.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
-# adpated from torch.utils.data.DistributedSampler
+# adapted from torch.utils.data.DistributedSampler
import math
import random
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index f49607376439..21bc530934d3 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -70,7 +70,7 @@ def _init_subclass(cls, **kwargs):
cls.__init__ = preprocess_after(cls.__init__)
# Replace .__init__() for all existing subclasses of torch.nn.Module
- # Excution self._post_init_method after the default init function.
+ # Execution self._post_init_method after the default init function.
substitute_init_recursively(torch.nn.modules.module.Module, _enable_class, set())
# holding on to the current __init__subclass__ for exit
diff --git a/colossalai/utils/profiler/legacy/comm_profiler.py b/colossalai/utils/profiler/legacy/comm_profiler.py
index a4f5729c97ec..334f0113ee90 100644
--- a/colossalai/utils/profiler/legacy/comm_profiler.py
+++ b/colossalai/utils/profiler/legacy/comm_profiler.py
@@ -111,7 +111,7 @@ def append(s: str = None):
res.append(sep)
if self.warn_flag:
- append("Warnning: there exists multiple communication operations in the same time. As a result, "
+ append("Warning: there exists multiple communication operations in the same time. As a result, "
"the profiling result is not accurate.")
if self.total_cuda_time == 0:
@@ -123,12 +123,12 @@ def append(s: str = None):
append("total number of calls: {}".format(self.total_count))
append("All events:")
- seperation = '-' * 74
+ separation = '-' * 74
row_format = '{:^10}' + '{:^12}' * 2 + '{:^16}' + '{:^12}' * 2
- append(seperation)
+ append(separation)
append(row_format.format('Location', 'GPU time', 'Percentage', 'Comm volume', 'Bandwidth', 'Num of calls'))
- append(seperation)
+ append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].self_cuda_time)
for location, event in show_list:
diff --git a/colossalai/utils/profiler/legacy/pcie_profiler.py b/colossalai/utils/profiler/legacy/pcie_profiler.py
index 526222941ef9..8f812f5cfc7b 100644
--- a/colossalai/utils/profiler/legacy/pcie_profiler.py
+++ b/colossalai/utils/profiler/legacy/pcie_profiler.py
@@ -130,12 +130,12 @@ def append(s: str = None):
append("Possible data transmission events in PCIE:")
- seperation = '-' * 62
+ separation = '-' * 62
row_format = '{:^10}' + '{:^12}' + '{:^16}' + '{:^12}' * 2
- append(seperation)
+ append(separation)
append(row_format.format('Location', 'GPU time', 'Trans volume', 'Bandwidth', 'Num of calls'))
- append(seperation)
+ append(separation)
show_list = sorted(self.ops_record.items(), key=lambda kv: -kv[1].cuda_time)
for location, event in show_list:
diff --git a/colossalai/utils/profiler/legacy/prof_utils.py b/colossalai/utils/profiler/legacy/prof_utils.py
index 87ad644a7ecc..2f7eee827651 100644
--- a/colossalai/utils/profiler/legacy/prof_utils.py
+++ b/colossalai/utils/profiler/legacy/prof_utils.py
@@ -32,9 +32,9 @@ def _format_memory(nbytes):
return str(nbytes) + ' B'
-def _format_bandwidth(volme: float or int, time_us: int):
+def _format_bandwidth(volume: float or int, time_us: int):
sec_div_mb = (1000.0 / 1024.0)**2
- mb_per_sec = volme / time_us * sec_div_mb
+ mb_per_sec = volume / time_us * sec_div_mb
if mb_per_sec >= 1024.0:
return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
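The corrected `_format_bandwidth` turns a transfer volume in bytes and a duration in microseconds into MB/s: bytes per microsecond equal 10^6 bytes per second, and dividing by 1024^2 bytes per MB yields the `(1000/1024)**2` factor. A quick standalone check of the arithmetic:

```python
def format_bandwidth(volume_bytes: float, time_us: float) -> str:
    sec_div_mb = (1000.0 / 1024.0) ** 2    # bytes/us -> MB/s
    mb_per_sec = volume_bytes / time_us * sec_div_mb
    if mb_per_sec >= 1024.0:
        return '{:.3f} GB/s'.format(mb_per_sec / 1024.0)
    return '{:.3f} MB/s'.format(mb_per_sec)

# 1 GiB moved in one second (1e6 us) is exactly 1024 MB/s -> 1.000 GB/s
print(format_bandwidth(1024 ** 3, 1e6))
```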
diff --git a/colossalai/utils/rank_recorder/README.md b/colossalai/utils/rank_recorder/README.md
index e30a925d2a92..da8a6039d543 100644
--- a/colossalai/utils/rank_recorder/README.md
+++ b/colossalai/utils/rank_recorder/README.md
@@ -1,5 +1,5 @@
# Rank Recorder
-This is a useful tool to get the records of certain functions in each rank. The records of each rank will dump into a json file after the end of multiple process program. You can parse and visualise the json file easily.
+This is a useful tool to get the records of certain functions in each rank. The records of each rank will be dumped into a JSON file after the multi-process program ends. You can parse and visualize the JSON file easily.
Before using the tool, you should ensure that dist.is_initialized() returns true before the program exits.
@@ -20,7 +20,7 @@ with recorder(record_name, current_rank) as r:
```
## Example
-This is a demo to display kernel select in cuda and visualise the cost of several procedures in each rank.
+This is a demo to display kernel selection in CUDA and visualize the cost of several procedures in each rank.
```python
import time
diff --git a/colossalai/utils/rank_recorder/rank_recorder.py b/colossalai/utils/rank_recorder/rank_recorder.py
index c088ceeb2e87..40bb7e184a12 100644
--- a/colossalai/utils/rank_recorder/rank_recorder.py
+++ b/colossalai/utils/rank_recorder/rank_recorder.py
@@ -133,7 +133,7 @@ def merge_recode(self):
with open(self.export_name + '.json', 'w', encoding='utf-8') as f:
json.dump(recoders, f, ensure_ascii=False)
- def visualise_record(self):
+ def visualize_record(self):
with open(self.export_name + '.json', 'r', encoding='utf-8') as f:
records = json.load(f)
records = dict(records)
@@ -171,7 +171,7 @@ def exit_worker(self):
if rank == 1:
# take the base time of rank 0 as standard
self.merge_recode()
- self.visualise_record()
+ self.visualize_record()
recorder = Recorder()
diff --git a/colossalai/zero/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py
index a7682eaf62e9..51da9be2b1f8 100644
--- a/colossalai/zero/gemini/chunk/chunk.py
+++ b/colossalai/zero/gemini/chunk/chunk.py
@@ -416,7 +416,7 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten
Copy data slice to the memory space indexed by the input tensor in the chunk.
Args:
- tensor (torch.Tensor): the tensor used to retrive meta information
+ tensor (torch.Tensor): the tensor used to retrieve meta information
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
# sanity check
diff --git a/colossalai/zero/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py
index 77368d06d255..38d34f14863e 100644
--- a/colossalai/zero/gemini/chunk/manager.py
+++ b/colossalai/zero/gemini/chunk/manager.py
@@ -157,7 +157,7 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -
Copy data to the chunk.
Args:
- tensor (torch.Tensor): the tensor used to retrive meta information
+ tensor (torch.Tensor): the tensor used to retrieve meta information
data (torch.Tensor): the tensor to be copied to the chunk
"""
chunk = self.tensor_chunk_map[tensor]
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 7e23fdb425f8..094320c4aff4 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -716,7 +716,10 @@ def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict]
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
- if self.current_block_size + tensor_size > self.max_shard_size:
+
+ # before we return the current block and create a new block,
+ # we need to ensure that the current block is not empty
+ if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
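The guard added here keeps `append` from emitting an empty shard when a single tensor alone exceeds `max_shard_size`. A minimal sketch of the accumulate-and-flush pattern in isolation (tensors mocked by their byte sizes; names are illustrative):

```python
from collections import OrderedDict

def shard_state_dict(named_sizes, max_shard_size):
    """Greedily pack (name, size) pairs into shards of at most max_shard_size,
    never emitting an empty shard even for oversized entries."""
    blocks, current, current_size = [], OrderedDict(), 0
    for name, size in named_sizes:
        # flush only if the current block already holds something
        if current_size + size > max_shard_size and current_size > 0:
            blocks.append(current)
            current, current_size = OrderedDict(), 0
        current[name] = size
        current_size += size
    if current:
        blocks.append(current)
    return blocks

print(shard_state_dict([('w', 7), ('b', 2), ('big', 20)], max_shard_size=10))
# [OrderedDict([('w', 7), ('b', 2)]), OrderedDict([('big', 20)])]
```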
diff --git a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
index f5eb05b4f22a..83903bbf4023 100644
--- a/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
+++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py
@@ -25,7 +25,7 @@ def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = N
# override
def record_model_data_volume(self) -> None:
"""
- record model data volumn on cuda and cpu.
+ record model data volume on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
diff --git a/colossalai/zero/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
index f8d99dbce7a4..4bb585677d5b 100644
--- a/colossalai/zero/gemini/memory_tracer/memory_monitor.py
+++ b/colossalai/zero/gemini/memory_tracer/memory_monitor.py
@@ -45,7 +45,7 @@ def clear(self):
class AsyncMemoryMonitor(MemoryMonitor):
"""
- An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
+ An Async Memory Monitor running during computing. Sampling memory usage of the current GPU
at interval of `1/(10**power)` sec.
The idea comes from Runtime Memory Tracer of PatrickStar
@@ -67,7 +67,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
async_mem_monitor.save('log.pkl')
Args:
- power (int, optional): the power of time interva. Defaults to 10.
+ power (int, optional): the power of time interval. Defaults to 10.
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
diff --git a/colossalai/zero/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py
index 6962c058110e..65f6ba775139 100644
--- a/colossalai/zero/gemini/memory_tracer/utils.py
+++ b/colossalai/zero/gemini/memory_tracer/utils.py
@@ -7,7 +7,7 @@ def colo_model_optimizer_usage(optim) -> Tuple[int, int]:
"""Trace the optimizer memory usage
Args:
- optim (ShardedOptimV2): an instance of ShardedOptimver
+ optim (ShardedOptimV2): an instance of ShardedOptimizer
Returns:
Tuple[int, int]: cuda/cpu memory usage in Byte
diff --git a/colossalai/zero/gemini/utils.py b/colossalai/zero/gemini/utils.py
index e52b5b836b0b..6f4a253b504b 100644
--- a/colossalai/zero/gemini/utils.py
+++ b/colossalai/zero/gemini/utils.py
@@ -73,7 +73,7 @@ def get_static_torch_model(zero_ddp_model,
zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
- only_rank_0 (bool): if True, only rank0 has the coverted torch model
+ only_rank_0 (bool): if True, only rank0 has the converted torch model
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
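`get_static_torch_model` is typically called right before checkpointing: it gathers the sharded ZeroDDP parameters into a plain `torch.nn.Module` copy. A hedged usage sketch (arguments as documented above; assumes `zero_ddp_model` was produced earlier in a distributed run):

```python
import torch
from colossalai.zero.gemini.utils import get_static_torch_model

static_model = get_static_torch_model(zero_ddp_model,
                                      device=torch.device('cpu'),
                                      dtype=torch.float32,
                                      only_rank_0=True)
# with only_rank_0=True, only rank 0 holds the converted model
if torch.distributed.get_rank() == 0:
    torch.save(static_model.state_dict(), 'model.pt')
```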
diff --git a/colossalai/zero/legacy/gemini/ophooks/utils.py b/colossalai/zero/legacy/gemini/ophooks/utils.py
index 84e8298c1d51..f88ad2b00e9e 100644
--- a/colossalai/zero/legacy/gemini/ophooks/utils.py
+++ b/colossalai/zero/legacy/gemini/ophooks/utils.py
@@ -88,7 +88,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
ophook_list: List[BaseOpHook],
name: str = "",
filter_fn: Optional[Callable] = None):
- r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
+ r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
assert isinstance(module, torch.nn.Module)
assert isinstance(ophook_list, (list, tuple))
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
@@ -103,7 +103,7 @@ def register_ophooks_recursively(module: torch.nn.Module,
if len(list(module.parameters(recurse=False))) == 0:
return
- # return from flitered module
+ # return from filtered module
if filter_fn is not None and filter_fn(module):
return
diff --git a/colossalai/zero/legacy/gemini/tensor_utils.py b/colossalai/zero/legacy/gemini/tensor_utils.py
index b7f23e0253fd..843e330ee2c6 100644
--- a/colossalai/zero/legacy/gemini/tensor_utils.py
+++ b/colossalai/zero/legacy/gemini/tensor_utils.py
@@ -77,7 +77,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
move a tensor to the target_device
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
- target_device: a traget device, if type is int, it the index of cuda card.
+ target_device: a target device; if the type is int, it is the index of the cuda card.
"""
if not isinstance(target_device, torch.device):
target_device = torch.device(f'cuda:{target_device}')
diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py
index a3fa46b38b5a..84e2d2f4f8e1 100644
--- a/colossalai/zero/legacy/init_ctx/init_context.py
+++ b/colossalai/zero/legacy/init_ctx/init_context.py
@@ -46,7 +46,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""A context to initialize model.
1. Convert the model to fp16.
- 2. The paramaters of the module are adapted to type ShardedParameter.
+ 2. The parameters of the module are adapted to type ShardedParameter.
3. Shard the param and grad according to flags.
Args:
diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py
index be3842beb208..e7064277fb3c 100644
--- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py
@@ -69,7 +69,7 @@ class ShardedModelV2(nn.Module):
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
Defaults to 'cuda'.
- gradient_predivide_factor (Optional[float], optional): Gradient is divived by this value before reduce-scatter. Defaults to 1.0.
+ gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
@@ -205,7 +205,7 @@ def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> N
exit(0)
"""
if self._use_memory_tracer:
- self.logger.error(f'dump memort tracer collected information to a {filename}', ranks=[0])
+ self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, 'w+') as f:
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
@@ -385,7 +385,7 @@ def _save_grad(self, param: Parameter, grad: torch.Tensor):
# make parameters point to gradient
assert param.colo_attr.saved_grad.is_null(
- ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
+ ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py
index afc98e7a7f54..218f7603bc54 100644
--- a/colossalai/zero/low_level/_utils.py
+++ b/colossalai/zero/low_level/_utils.py
@@ -261,7 +261,7 @@ def sync_param(flat_tensor, tensor_list):
share the same memory space. This function will update the tensor list so that
they point to the same value.
- :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit
+:param flat_tensor: A flat tensor obtained by calling `torch._utils._flatten_dense_tensors` on a tensor list
:param tensor_list: A list of tensors corresponding to the flattened tensor
:type flat_tensor: torch.Tensor
:type tensor_list: List[torch.Tensor]
diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py
index d4d03e5b5fcd..ee03c0f0ae15 100644
--- a/colossalai/zero/low_level/low_level_optim.py
+++ b/colossalai/zero/low_level/low_level_optim.py
@@ -207,8 +207,8 @@ def __init__(
for param in self._working_param_groups[group_id]:
self._param_store.set_param_reduction_state(param, False)
- # intialize communication stream for
- # communication-compuation overlapping
+ # initialize communication stream for
+ # communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
@@ -269,7 +269,7 @@ def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
- # partititon the parameters in a greedy fashion
+ # partition the parameters in a greedy fashion
sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
for param in sorted_params:
# allocate this parameter to the rank with
@@ -297,7 +297,7 @@ def _attach_reduction_hook(self):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad:
- # determines the reduction destionation rank
+ # determines the reduction destination rank
# this is only valid for stage 2
# dst_rank = None means using all-reduce
# else using reduce
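The partitioning walk above is a classic greedy balancing pass: sort parameters by element count, then hand each one to the rank that currently holds the fewest elements. A self-contained sketch of the same idea (sizes and world size are illustrative):

```python
def partition_params(param_numels, world_size):
    """Greedy partition: largest params first, each to the least-loaded rank."""
    params_per_rank = [[] for _ in range(world_size)]
    numel_per_rank = [0] * world_size
    for numel in sorted(param_numels, reverse=True):
        rank = numel_per_rank.index(min(numel_per_rank))    # least-loaded rank
        params_per_rank[rank].append(numel)
        numel_per_rank[rank] += numel
    return params_per_rank, numel_per_rank

parts, loads = partition_params([100, 60, 40, 30, 10], world_size=2)
print(parts, loads)    # [[100, 30], [60, 40, 10]] [130, 110]
```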
diff --git a/docs/sidebars.json b/docs/sidebars.json
index c3cfbbeef689..8be40e4512f9 100644
--- a/docs/sidebars.json
+++ b/docs/sidebars.json
@@ -64,7 +64,6 @@
},
"features/pipeline_parallel",
"features/nvme_offload",
- "features/lazy_init",
"features/cluster_utils"
]
},
diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
index 22d52fb3cd1a..978ac32fc78e 100644
--- a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -141,16 +141,16 @@ for mn, module in model.named_modules():
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
elif 'wte' in mn or 'wpe' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
```
The modified model is illustrated below.
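For intuition on the column/row pairing above: column-slicing a weight splits its output features across ranks, so each rank produces a slice of the activation; row-slicing the next weight lets each rank consume its own slice, and the partial outputs are summed (an all-reduce in the distributed case). A shape-level sketch in plain PyTorch (dimensions illustrative, no ColossalAI APIs):

```python
import torch

x = torch.rand(8, 16)        # (batch, hidden)
w_fc = torch.rand(16, 64)    # c_fc weight, column-parallel
w_proj = torch.rand(64, 16)  # c_proj weight, row-parallel

# column slice: each of 2 ranks keeps half of the output features
y0 = x @ w_fc[:, :32]        # rank 0 activation: (8, 32)
y1 = x @ w_fc[:, 32:]        # rank 1 activation: (8, 32)

# row slice: each rank consumes its half of the intermediate features;
# the partial results are summed (all-reduce in the distributed case)
out = y0 @ w_proj[:32, :] + y1 @ w_proj[32:, :]

print(torch.allclose(out, (x @ w_fc) @ w_proj, atol=1e-5))    # True
```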
diff --git a/docs/source/en/features/lazy_init.md b/docs/source/en/features/lazy_init.md
deleted file mode 100644
index 40f5da1cb84d..000000000000
--- a/docs/source/en/features/lazy_init.md
+++ /dev/null
@@ -1,71 +0,0 @@
-# Lazy initialization
-
-Author: Hongxin Liu
-
-**Prerequisite**
-- [Booster API](../basics/booster_api.md)
-- [Booster Plugins](../basics/booster_plugins.md)
-- [Booster Checkpoint](../basics/booster_checkpoint.md)
-
-**Related discussion**
-- [Lazy initialization of model](https://github.com/hpcaitech/ColossalAI/discussions/3124)
-
-## Introduction
-
-LazyTensor allows DL framework (PyTorch) to execute operations lazily, by storing all operations related to it and reruning them when it's required to be materialized.
-
-LazyInit defers model initialization and it's based on LazyTensor.
-
-This is especially useful when we use model parallelism to train large models, in which case the model cannot fit in GPU memory. Through this, we can initialize model tensors using meta tensor and do static analysis to get shard strategy. And then materialize each tensor and apply the shard strategy. The static analysis can be omitted if the shard strategy is known in advance.
-
-## Usage
-
-You may use lazy initialization when using Gemini, tensor parallelism, pipeline parallelism, and auto-parallelism. In other cases, you may not need to use lazy initialization.
-
-Gemini is compatible with lazy initialization. You can use them together directly.
-
-```python
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.optimizer import HybridAdam
-from torch.nn import Linear
-import colossalai
-
-colossalai.launch_from_torch({})
-
-plugin = GeminiPlugin()
-booster = Booster(plugin=plugin)
-
-with LazyInitContext():
- model = Linear(10, 10)
-
-optimizer = HybridAdam(model.parameters())
-model, optimizer, *_ = booster.boost(model, optimizer)
-```
-
-Note that using lazy initialization when using Gemini is not necessary but recommended. If you don't use lazy initialization, you may get OOM error when initializing the model. If you use lazy initialization, you can avoid this error.
-
-> ⚠ Lazy initialization support for tensor parallelism, pipeline parallelism, and auto-parallelism is still under development.
-
-### Load from pretrained model
-
-We should not load pretrained weight in `LazyInitContext`. If so, lazy initialization is meaningless, as the checkpoint is loaded and it takes much GPU memory. A recommended way is to initialize model from scratch in `LazyInitContext` and load pretrained weight outside `LazyInitContext` after calling `Booster.boost()`.
-
-
-```python
-with LazyInitContext():
- model = GPT2LMHeadModel(config)
-
-optimizer = ...
-lr_scheduler = ...
-dataloader = ...
-model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-
-booster.load_model(model, pretrained_path)
-```
-
-
-As booster supports both pytorch-fashion checkpoint and huggingface/transformers-fashion pretrained weight, the `pretrained_path` of the above pseudo-code can be either a checkpoint file path or a pretrained weight path. Note that it does not support loading pretrained weights from network. You should download the pretrained weight first and then use a local path.
-
-
diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
index c4131e593437..b4e0d18a2647 100644
--- a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
+++ b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -126,16 +126,16 @@ for mn, module in model.named_modules():
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
elif 'wte' in mn or 'wpe' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
```
修改后的模型如下图所示。
diff --git a/docs/source/zh-Hans/features/lazy_init.md b/docs/source/zh-Hans/features/lazy_init.md
deleted file mode 100644
index 9a3cd90caa8d..000000000000
--- a/docs/source/zh-Hans/features/lazy_init.md
+++ /dev/null
@@ -1,71 +0,0 @@
-# 惰性初始化
-
-作者: Hongxin Liu
-
-**前置教程**
-- [Booster API](../basics/booster_api.md)
-- [Booster 插件](../basics/booster_plugins.md)
-- [Booster Checkpoint](../basics/booster_checkpoint.md)
-
-**相关讨论**
-- [模型的惰性初始化](https://github.com/hpcaitech/ColossalAI/discussions/3124)
-
-## 引言
-
-LazyTensor 允许深度学习框架 (PyTorch) 延迟执行操作,方法是存储与其相关的所有操作并在需要具体化时重新运行它们。
-
-LazyInit 基于 LazyTensor,并支持延迟模型初始化。
-
-这在我们使用模型并行来训练大型模型时特别有用,在这种情况下模型无法容纳在 GPU 内存中。通过这个,我们可以使用 Meta 张量初始化模型张量并进行静态分析以获得分片策略。然后具体化每个张量并应用分片策略。如果事先知道分片策略,则可以省略静态分析。
-
-## 用法
-
-您可以在使用 Gemini、张量并行、流水线并行和自动并行时使用惰性初始化。在其他情况下,您可能不需要使用惰性初始化。
-
-Gemini 与惰性初始化兼容。您可以直接将它们一起使用。
-
-```python
-from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin
-from colossalai.lazy import LazyInitContext
-from colossalai.nn.optimizer import HybridAdam
-from torch.nn import Linear
-import colossalai
-
-colossalai.launch_from_torch({})
-
-plugin = GeminiPlugin()
-booster = Booster(plugin=plugin)
-
-with LazyInitContext():
- model = Linear(10, 10)
-
-optimizer = HybridAdam(model.parameters())
-model, optimizer, *_ = booster.boost(model, optimizer)
-```
-
-请注意,在使用 Gemini 时使用惰性初始化不是必需的,但建议使用。如果不使用惰性初始化,在初始化模型时可能会出现 OOM 错误。如果使用惰性初始化,则可以避免此错误。
-
-> ⚠ 对张量并行、流水线并行和自动并行的惰性初始化支持仍在开发中。
-
-### 从预训练模型加载
-
-我们不应该在 `LazyInitContext` 中加载预训练权重。如果这样,惰性初始化就没有意义,因为检查点已加载并且需要大量 GPU 内存。推荐的方法是在 `LazyInitContext` 中初始化模型,并在调用 `Booster.boost()` 后在 `LazyInitContext` 之外加载预训练权重。
-
-
-```python
-with LazyInitContext():
- model = GPT2LMHeadModel(config)
-
-optimizer = ...
-lr_scheduler = ...
-dataloader = ...
-model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-
-booster.load_model(model, pretrained_path)
-```
-
-
-由于 booster 同时支持 pytorch 风格的 checkpoint 和 huggingface/transformers 风格的预训练权重,上述伪代码的 `pretrained_path` 可以是 checkpoint 文件路径或预训练权重路径。请注意,它不支持从网络加载预训练权重。您应该先下载预训练的权重,然后使用本地路径。
-
-
diff --git a/examples/community/roberta/README.md b/examples/community/roberta/README.md
index 8aefa327a4b4..000fce63f35f 100644
--- a/examples/community/roberta/README.md
+++ b/examples/community/roberta/README.md
@@ -44,7 +44,7 @@ following the `README.md`, load the h5py generated by preprocess of step 1 to pr
## 3. Finetune
-The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transfomers from Hugging Face to finetune downstream application.
+The checkpoint produced by this repo can replace `pytorch_model.bin` from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main) directly. Then use transformers from Hugging Face to finetune downstream applications.
## Contributors
The example is contributed by AI team from [Moore Threads](https://www.mthreads.com/). If you find any problems for pretraining, please file an issue or send an email to yehua.zhang@mthreads.com. At last, welcome any form of contribution!
diff --git a/examples/community/roberta/preprocessing/README.md b/examples/community/roberta/preprocessing/README.md
index 17cc2f4dc22c..2ed747541280 100644
--- a/examples/community/roberta/preprocessing/README.md
+++ b/examples/community/roberta/preprocessing/README.md
@@ -25,10 +25,10 @@ Firstly, each file has multiple documents, and each document contains multiple s
In this example, we split the 200G corpus into 100 shards, each about 2G. The shard size is memory-dependent, taking into account the number of servers, the memory used by the tokenizer, and the memory used by the multi-process training to read the shards (n-way data parallelism requires n\*shard_size memory). **To sum up, data preprocessing and model pretraining require fighting with hardware, not just the GPU.**
```python
-python sentence_split.py --input_path /orginal_corpus --output_path /shard --shard 100
+python sentence_split.py --input_path /original_corpus --output_path /shard --shard 100
# This step takes a short time
```
-* `--input_path`: all original corpus, e.g., /orginal_corpus/0.json /orginal_corpus/1.json ...
+* `--input_path`: all original corpus, e.g., /original_corpus/0.json /original_corpus/1.json ...
* `--output_path`: all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--shard`: Number of shards, e.g., 10, 50, or 100
@@ -76,7 +76,7 @@ make
* `--input_path`: location of all shard with split sentences, e.g., /shard/0.txt, /shard/1.txt ...
* `--output_path`: location of all h5 with token_id, input_mask, segment_ids and masked_lm_positions, e.g., /h5/0.h5, /h5/1.h5 ...
-* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenzier.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
+* `--tokenizer_path`: tokenizer path contains huggingface tokenizer.json. Download config.json, special_tokens_map.json, vocab.txt and tokenizer.json from [hfl/chinese-roberta-wwm-ext-large](https://huggingface.co/hfl/chinese-roberta-wwm-ext-large/tree/main)
* `--backend`: python or c++, **specifying c++ can obtain faster preprocessing speed**
* `--dupe_factor`: specifies how many times the preprocessor repeats to create the input from the same article/document
* `--worker`: number of processes
diff --git a/examples/community/roberta/pretraining/README.md b/examples/community/roberta/pretraining/README.md
index c248fc1f5708..8abe48aa6c0e 100644
--- a/examples/community/roberta/pretraining/README.md
+++ b/examples/community/roberta/pretraining/README.md
@@ -13,7 +13,7 @@ bash run_pretrain.sh
* `--bert_config`: config.json which represent model
* `--mlm`: model type of backbone, bert or deberta_v2
-2. if resume training from earylier checkpoint, run the script below.
+2. if resuming training from an earlier checkpoint, run the script below.
```shell
bash run_pretrain_resume.sh
diff --git a/examples/community/roberta/pretraining/arguments.py b/examples/community/roberta/pretraining/arguments.py
index 40210c4b1be7..e0702ceb59b0 100644
--- a/examples/community/roberta/pretraining/arguments.py
+++ b/examples/community/roberta/pretraining/arguments.py
@@ -46,7 +46,7 @@ def parse_args():
type=int,
default=1,
help="This param makes sure that a certain task is repeated for this time steps to \
- optimise on the back propogation speed with APEX's DistributedDataParallel")
+ optimize on the back propagation speed with APEX's DistributedDataParallel")
parser.add_argument("--max_predictions_per_seq",
"--max_pred",
default=80,
@@ -73,12 +73,12 @@ def parse_args():
help="location of saving checkpoint, which contains model and optimizer")
parser.add_argument('--seed', type=int, default=42, help="random seed for initialization")
parser.add_argument('--vscode_debug', action='store_true', help="use vscode to debug")
- parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoin")
+ parser.add_argument('--load_pretrain_model', default='', type=str, help="location of model's checkpoint")
parser.add_argument(
'--load_optimizer_lr',
default='',
type=str,
- help="location of checkpoint, which contains optimerzier, learning rate, epoch, shard and global_step")
+ help="location of checkpoint, which contains optimizer, learning rate, epoch, shard and global_step")
parser.add_argument('--resume_train', action='store_true', help="whether to resume training from an earlier checkpoint")
parser.add_argument('--mlm', default='bert', type=str, help="model type, bert or deberta")
parser.add_argument('--checkpoint_activations', action='store_true', help="whether to use gradient checkpointing")
diff --git a/examples/community/roberta/pretraining/model/bert.py b/examples/community/roberta/pretraining/model/bert.py
index a5da1bea6f65..abdf925d0540 100644
--- a/examples/community/roberta/pretraining/model/bert.py
+++ b/examples/community/roberta/pretraining/model/bert.py
@@ -327,7 +327,7 @@ def forward(
attention_scores = attention_scores + relative_position_scores
elif self.position_embedding_type == "relative_key_query":
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
- relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhld,lrd->bhlr", key_layer, positional_embedding)
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
diff --git a/examples/community/roberta/pretraining/run_pretraining.py b/examples/community/roberta/pretraining/run_pretraining.py
index 9a6ffc1c5661..a72bdf775644 100644
--- a/examples/community/roberta/pretraining/run_pretraining.py
+++ b/examples/community/roberta/pretraining/run_pretraining.py
@@ -78,7 +78,7 @@ def main():
default_pg=shard_pg):
config, model, numel = get_model(args, logger)
- # asign running configurations
+ # assign running configurations
gemini_config = None
if args.distplan.startswith("CAI_ZeRO"):
optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
diff --git a/examples/community/roberta/pretraining/utils/exp_util.py b/examples/community/roberta/pretraining/utils/exp_util.py
index 0cdb56bad031..4a2c9d8a47ad 100644
--- a/examples/community/roberta/pretraining/utils/exp_util.py
+++ b/examples/community/roberta/pretraining/utils/exp_util.py
@@ -97,7 +97,7 @@ def throughput_calculator(numel, args, config, iteration_time, total_iterations,
def synchronize():
if not torch.distributed.is_available():
return
- if not torch.distributed.is_intialized():
+ if not torch.distributed.is_initialized():
return
world_size = torch.distributed.get_world_size()
if world_size == 1:
diff --git a/examples/community/roberta/pretraining/utils/global_vars.py b/examples/community/roberta/pretraining/utils/global_vars.py
index 7b0c5a2be73d..9eef19e71614 100644
--- a/examples/community/roberta/pretraining/utils/global_vars.py
+++ b/examples/community/roberta/pretraining/utils/global_vars.py
@@ -110,7 +110,7 @@ def write(self, names, writer, iteration, normalizer=1.0, reset=False):
"""Write timers to a tensorboard writer"""
# currently when using add_scalars,
# torch.utils.add_scalars makes each timer its own run, which
- # polutes the runs list, so we just add each as a scalar
+ # pollutes the runs list, so we just add each as a scalar
assert normalizer > 0.0
for name in names:
value = self.timers[name].elapsed(reset=reset) / normalizer
diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md
index 7c117d841e24..ba4c1a71034a 100644
--- a/examples/images/dreambooth/README.md
+++ b/examples/images/dreambooth/README.md
@@ -37,7 +37,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode
## Training
-We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provided traditional training process of dreambooth, `dreambooth.sh`, for possible comparation. For instance, the script of training process for [stable-diffusion-v1-4] model can be modified into:
+We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provide the traditional dreambooth training script, `dreambooth.sh`, for comparison. For instance, the training script for the [stable-diffusion-v1-4] model can be modified into:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
@@ -92,6 +92,29 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
--placement="cuda"
```
+## New API
+We have modified our previous implementation of Dreambooth with our new Booster API, which offers a more flexible and efficient way to train your model. The new API is more user-friendly and easy to use. You can find it in `train_dreambooth_colossalai.py`.
+We also offer a shell script `test_ci.sh` for you to go through all of our plugins for the booster.
+For more information about the Booster API, you can refer to https://colossalai.org/docs/basics/booster_api/.
+
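+The plugin switch in `train_dreambooth_colossalai.py` boils down to a few lines. A minimal sketch of the Booster wiring, mirroring the `--plugin` choices accepted by the script (model and optimizer are stand-ins created elsewhere):
+
+```python
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+
+def build_booster(plugin_name: str) -> Booster:
+    booster_kwargs = {}
+    if plugin_name == 'torch_ddp_fp16':
+        booster_kwargs['mixed_precision'] = 'fp16'
+    if plugin_name.startswith('torch_ddp'):
+        plugin = TorchDDPPlugin()
+    elif plugin_name == 'gemini':
+        plugin = GeminiPlugin(placement_policy='auto', strict_ddp_mode=True, initial_scale=2 ** 5)
+    elif plugin_name == 'low_level_zero':
+        plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+    return Booster(plugin=plugin, **booster_kwargs)
+
+# inside a process started via torchrun + colossalai.launch_from_torch(...):
+# booster = build_booster('gemini')
+# model, optimizer, *_ = booster.boost(model, optimizer)
+```
+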
+## Performance
+
+| Strategy | #GPUs | Batch Size | GPU RAM (GB) | Speedup |
+|:--------------:|:----:|:----------:|:-----------:|:-------:|
+| Traditional | 1 | 16 | oom | \ |
+| Traditional | 1 | 8 | 61.81 | 1 |
+| torch_ddp | 4 | 16 | oom | \ |
+| torch_ddp | 4 | 8 | 41.97 | 0.97 |
+| gemini | 4 | 16 | 53.29 | \ |
+| gemini | 4 | 8 | 29.36 | 2.00 |
+| low_level_zero | 4 | 16 | 52.80 | \ |
+| low_level_zero | 4 | 8 | 28.87 | 2.02 |
+
+The evaluation is performed on 4 Nvidia A100 GPUs with 80GB memory each; GPUs 0 & 1 and GPUs 2 & 3 are connected via NVLink.
+We finetuned the [stable-diffusion-v1-4](https://huggingface.co/stabilityai/stable-diffusion-v1-4) model with 512x512 resolution on the [Teyvat](https://huggingface.co/datasets/Fazzie/Teyvat) dataset and compared
+the memory cost and the throughput for the plugins.
+
+
## Inference
Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier` (e.g. `--instance_prompt="a photo of sks dog" ` in the above example) in your prompt.
diff --git a/examples/images/dreambooth/colossalai.sh b/examples/images/dreambooth/colossalai.sh
index 227d8b8bdb04..db4562dbc921 100755
--- a/examples/images/dreambooth/colossalai.sh
+++ b/examples/images/dreambooth/colossalai.sh
@@ -1,22 +1,18 @@
-export MODEL_NAME=
-export INSTANCE_DIR=
-export CLASS_DIR="path-to-class-images"
-export OUTPUT_DIR="path-to-save-model"
-
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
DIFFUSERS_OFFLINE=1
-torchrun --nproc_per_node 2 --master_port=25641 train_dreambooth_colossalai.py \
- --pretrained_model_name_or_path=$MODEL_NAME \
- --instance_data_dir=$INSTANCE_DIR \
- --output_dir=$OUTPUT_DIR \
- --instance_prompt="a photo of a dog" \
+torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+ --instance_data_dir="/data/dreambooth/Teyvat/data" \
+ --output_dir="./weight_output" \
+ --instance_prompt="a picture of a dog" \
--resolution=512 \
+ --plugin="gemini" \
--train_batch_size=1 \
- --gradient_accumulation_steps=1 \
--learning_rate=5e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
- --placement="cuda" \
+ --test_run=True \
+ --placement="auto" \
diff --git a/examples/images/dreambooth/dreambooth.sh b/examples/images/dreambooth/dreambooth.sh
index e063bc8279c5..f6b8f5e1b87e 100644
--- a/examples/images/dreambooth/dreambooth.sh
+++ b/examples/images/dreambooth/dreambooth.sh
@@ -1,7 +1,7 @@
python train_dreambooth.py \
- --pretrained_model_name_or_path= ## Your Model Path \
- --instance_data_dir= ## Your Training Input Pics Path \
- --output_dir="path-to-save-model" \
+ --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+ --instance_data_dir="/data/dreambooth/Teyvat/data" \
+ --output_dir="./weight_output" \
--instance_prompt="a photo of a dog" \
--resolution=512 \
--train_batch_size=1 \
diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
index e69de29bb2d1..21f45adae2a0 100644
--- a/examples/images/dreambooth/test_ci.sh
+++ b/examples/images/dreambooth/test_ci.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
+
+# "torch_ddp" "torch_ddp_fp16" "low_level_zero"
+for plugin in "gemini"; do
+ torchrun --nproc_per_node 4 --standalone train_dreambooth_colossalai.py \
+ --pretrained_model_name_or_path="/data/dreambooth/diffuser/stable-diffusion-v1-4" \
+ --instance_data_dir="/data/dreambooth/Teyvat/data" \
+ --output_dir="./weight_output" \
+ --instance_prompt="a picture of a dog" \
+ --resolution=512 \
+ --plugin=$plugin \
+ --train_batch_size=1 \
+ --learning_rate=5e-6 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --test_run=True \
+ --num_class_images=200 \
+ --placement="auto" # "cuda"
+done
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index d07febea0a84..888b28de8306 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -4,6 +4,7 @@
import os
from pathlib import Path
from typing import Optional
+import shutil
import torch
import torch.nn.functional as F
@@ -21,9 +22,12 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
+from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers()
logger = get_dist_logger()
@@ -58,6 +62,13 @@ def parse_args(input_args=None):
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
+ parser.add_argument(
+ "--externel_unet_path",
+ type=str,
+ default=None,
+ required=False,
+ help="Path to the externel unet model.",
+ )
parser.add_argument(
"--revision",
type=str,
@@ -187,12 +198,19 @@ def parse_args(input_args=None):
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument("--test_run", default=False, help="Whether to use a smaller dataset for test run.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ help="plugin to use")
parser.add_argument(
"--logging_dir",
type=str,
@@ -250,6 +268,7 @@ def __init__(
class_prompt=None,
size=512,
center_crop=False,
+ test=False,
):
self.size = size
self.center_crop = center_crop
@@ -260,6 +279,8 @@ def __init__(
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
+ if test:
+ self.instance_images_path = self.instance_images_path[:10]
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
@@ -339,18 +360,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
- from colossalai.nn.parallel import GeminiDDP
-
- model = GeminiDDP(model,
- device=get_current_device(),
- placement_policy=placement_policy,
- pin_memory=True,
- search_range_mb=64)
- return model
-
-
def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
@@ -392,7 +401,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha256(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -452,12 +461,18 @@ def main(args):
revision=args.revision,
)
- logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
- with ColoInitContext(device=get_current_device()):
+
+    if args.external_unet_path is None:
+ logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
- subfolder="unet",
- revision=args.revision,
- low_cpu_mem_usage=False)
+ subfolder="unet",
+ revision=args.revision,
+ low_cpu_mem_usage=False)
+ else:
+ logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+ unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+ revision=args.revision,
+ low_cpu_mem_usage=False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -468,10 +483,22 @@ def main(args):
if args.scale_lr:
args.learning_rate = args.learning_rate * args.train_batch_size * world_size
- unet = gemini_zero_dpp(unet, args.placement)
+    # Use the Booster API to apply Gemini/ZeRO (or DDP) through a ColossalAI plugin
+
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
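+    # booster.boost() below wraps the unet, optimizer and lr_scheduler
+    # according to the strategy of the chosen plugin.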
# config optimizer for colossalai zero
- optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+ optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -486,6 +513,7 @@ def main(args):
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
+ test=args.test_run
)
def collate_fn(examples):
@@ -554,6 +582,8 @@ def collate_fn(examples):
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
# Train!
total_batch_size = args.train_batch_size * world_size
@@ -642,36 +672,24 @@ def collate_fn(examples):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
- torch_unet = get_static_torch_model(unet)
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
if local_rank == 0:
- pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=torch_unet,
- revision=args.revision,
- )
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- pipeline.save_pretrained(save_path)
+ if not os.path.exists(os.path.join(save_path, "config.json")):
+ shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
if global_step >= args.max_train_steps:
break
-
torch.cuda.synchronize()
- unet = get_static_torch_model(unet)
+ booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+ logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
if local_rank == 0:
- pipeline = DiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- revision=args.revision,
- )
-
- pipeline.save_pretrained(args.output_dir)
- logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
-
+ if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+ shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-
if __name__ == "__main__":
args = parse_args()
main(args)
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
index 6715b473a567..dce65ff514b7 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -4,6 +4,7 @@
import os
from pathlib import Path
from typing import Optional
+import shutil
import torch
import torch.nn.functional as F
@@ -23,9 +24,12 @@
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers()
logger = get_dist_logger()
@@ -60,6 +64,13 @@ def parse_args(input_args=None):
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
+ parser.add_argument(
+ "--externel_unet_path",
+ type=str,
+ default=None,
+ required=False,
+ help="Path to the externel unet model.",
+ )
parser.add_argument(
"--revision",
type=str,
@@ -195,6 +206,12 @@ def parse_args(input_args=None):
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ help="plugin to use")
parser.add_argument(
"--logging_dir",
type=str,
@@ -341,18 +358,6 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, placement_policy: str = "auto"):
- from colossalai.nn.parallel import GeminiDDP
-
- model = GeminiDDP(model,
- device=get_current_device(),
- placement_policy=placement_policy,
- pin_memory=True,
- search_range_mb=64)
- return model
-
-
def main(args):
if args.seed is None:
colossalai.launch_from_torch(config={})
@@ -394,7 +399,7 @@ def main(args):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
- hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+ hash_image = hashlib.sha256(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
@@ -454,32 +459,42 @@ def main(args):
revision=args.revision,
)
- logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
- with ColoInitContext(device=get_current_device()):
+
+    if args.external_unet_path is None:
+ logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
- subfolder="unet",
- revision=args.revision,
- low_cpu_mem_usage=False)
- unet.requires_grad_(False)
-
- # Set correct lora layers
- lora_attn_procs = {}
- for name in unet.attn_processors.keys():
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
- if name.startswith("mid_block"):
- hidden_size = unet.config.block_out_channels[-1]
- elif name.startswith("up_blocks"):
- block_id = int(name[len("up_blocks.")])
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
- elif name.startswith("down_blocks"):
- block_id = int(name[len("down_blocks.")])
- hidden_size = unet.config.block_out_channels[block_id]
-
- lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim)
-
- unet.set_attn_processor(lora_attn_procs)
- lora_layers = AttnProcsLayers(unet.attn_processors)
+ subfolder="unet",
+ revision=args.revision,
+ low_cpu_mem_usage=False)
+ else:
+ logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
+ unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
+ revision=args.revision,
+ low_cpu_mem_usage=False)
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ low_cpu_mem_usage=False)
+ unet.requires_grad_(False)
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(unet.attn_processors)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
@@ -490,10 +505,22 @@ def main(args):
if args.scale_lr:
args.learning_rate = args.learning_rate * args.train_batch_size * world_size
- unet = gemini_zero_dpp(unet, args.placement)
+    # Use the Booster API to apply Gemini/ZeRO (or DDP) through a ColossalAI plugin
+
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
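+    # booster.boost() below wraps the unet (with its LoRA attention processors),
+    # optimizer and lr_scheduler according to the strategy of the chosen plugin.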
# config optimizer for colossalai zero
- optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+ optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
# load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -576,6 +603,8 @@ def collate_fn(examples):
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+ unet, optimizer, _, _, lr_scheduler = booster.boost(unet, optimizer, lr_scheduler=lr_scheduler)
+
# Train!
total_batch_size = args.train_batch_size * world_size
@@ -664,27 +693,24 @@ def collate_fn(examples):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
- torch_unet = get_static_torch_model(unet)
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
if local_rank == 0:
- save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
- torch_unet = torch_unet.to(torch.float32)
- torch_unet.save_attn_procs(save_path)
+ if not os.path.exists(os.path.join(save_path, "config.json")):
+ shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
if global_step >= args.max_train_steps:
break
-
torch.cuda.synchronize()
- torch_unet = get_static_torch_model(unet)
+ booster.save_model(unet, os.path.join(args.output_dir, "diffusion_pytorch_model.bin"))
+ logger.info(f"Saving model checkpoint to {args.output_dir} on rank {local_rank}")
if local_rank == 0:
- torch_unet = torch_unet.to(torch.float32)
- torch_unet.save_attn_procs(save_path)
- logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
-
+ if not os.path.exists(os.path.join(args.output_dir, "config.json")):
+ shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
-
if __name__ == "__main__":
args = parse_args()
main(args)
diff --git a/examples/images/vit/README.md b/examples/images/vit/README.md
index 4423d85d19e0..7c4147b76457 100644
--- a/examples/images/vit/README.md
+++ b/examples/images/vit/README.md
@@ -1,61 +1,28 @@
-# Vision Transformer with ColoTensor
+## Overview
-# Overview
+Vision Transformer is a class of Transformer model tailored for computer vision tasks. It was first proposed in the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) and achieved SOTA results on various tasks at that time.
-In this example, we will run Vision Transformer with ColoTensor.
+In our example, we use pretrained ViT weights loaded from HuggingFace.
+We adapt the ViT training code to ColossalAI by leveraging the [Booster API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific training strategy. This example supports TorchDDPPlugin, LowLevelZeroPlugin and GeminiPlugin.
-We use model **ViTForImageClassification** from Hugging Face [Link](https://huggingface.co/docs/transformers/model_doc/vit) for unit test.
-You can change world size or decide whether use DDP in our code.
+## Run Demo
-We use model **vision_transformer** from timm [Link](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) for training example.
-
-(2022/6/28) The default configuration now supports 2DP+2TP with gradient accumulation and checkpoint support. Zero is not supported at present.
-
-# Requirement
-
-Install colossalai version >= 0.1.11
-
-## Unit test
-To run unit test, you should install pytest, transformers with:
-```shell
-pip install pytest transformers
+By running the following script:
+```bash
+bash run_demo.sh
```
+You will finetune a [ViT-base](https://huggingface.co/google/vit-base-patch16-224) model on this [dataset](https://huggingface.co/datasets/beans), which contains more than 8000 images of bean leaves. The dataset is used for an image classification task with 3 labels: ['angular_leaf_spot', 'bean_rust', 'healthy'].
-## Training example
-To run training example with ViT-S, you should install **NVIDIA DALI** from [Link](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html) for dataloader support.
-You also need to install timm and titans for model/dataloader support with:
-```shell
-pip install timm titans
-```
+The script can be modified if you want to try another set of hyperparameters or switch to a ViT model of a different size.
-### Data preparation
-You can download the ImageNet dataset from the [ImageNet official website](https://www.image-net.org/download.php). You should get the raw images after downloading the dataset. As we use **NVIDIA DALI** to read data, we use the TFRecords dataset instead of raw Imagenet dataset. This offers better speedup to IO. If you don't have TFRecords dataset, follow [imagenet-tools](https://github.com/ver217/imagenet-tools) to build one.
+The demo code refers to this [blog](https://huggingface.co/blog/fine-tune-vit).
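+
+If you prefer to invoke the training script directly, a minimal sketch (assuming 2 GPUs; the batch size and epoch count here are illustrative, and all flags are defined in `args.py`) looks like:
+```bash
+torchrun --standalone --nproc_per_node 2 vit_train_demo.py \
+    --model_name_or_path "google/vit-base-patch16-224" \
+    --plugin "low_level_zero" \
+    --batch_size 8 \
+    --num_epoch 1 \
+    --learning_rate 2e-4
+```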
-Before you start training, you need to set the environment variable `DATA` so that the script knows where to fetch the data for DALI dataloader.
-```shell
-export DATA=/path/to/ILSVRC2012
-```
-# How to run
+## Run Benchmark
-## Unit test
-In your terminal
-```shell
-pytest test_vit.py
+You can run benchmark for ViT model by running the following script:
+```bash
+bash run_benchmark.sh
```
-
-This will evaluate models with different **world_size** and **use_ddp**.
-
-## Training example
-Modify the settings in run.sh according to your environment.
-For example, if you set `--nproc_per_node=8` in `run.sh` and `TP_WORLD_SIZE=2` in your config file,
-data parallel size will be automatically calculated as 4.
-Thus, the parallel strategy is set to 4DP+2TP.
-
-Then in your terminal
-```shell
-sh run.sh
-```
-
-This will start ViT-S training with ImageNet.
+The script benchmarks performance (throughput & peak memory usage) for each combination of hyperparameters. You can also modify it to test your own set of hyperparameters, as in the single-configuration example below.
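+
+```bash
+torchrun --standalone --nproc_per_node 4 vit_benchmark.py \
+    --model_name_or_path "google/vit-base-patch16-224" \
+    --plugin "gemini" \
+    --batch_size 32
+```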
\ No newline at end of file
diff --git a/examples/images/vit/args.py b/examples/images/vit/args.py
new file mode 100644
index 000000000000..e4a873a9eb52
--- /dev/null
+++ b/examples/images/vit/args.py
@@ -0,0 +1,124 @@
+from colossalai import get_default_parser
+
+def parse_demo_args():
+
+ parser = get_default_parser()
+ parser.add_argument(
+ "--model_name_or_path",
+ type=str,
+ default="google/vit-base-patch16-224",
+ help="Path to pretrained model or model identifier from huggingface.co/models."
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./output_model.bin",
+ help="The path of your saved model after finetuning."
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+ )
+ parser.add_argument(
+ "--num_epoch",
+ type=int,
+ default=3,
+ help="Number of epochs."
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader."
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=3e-4,
+ help="Initial learning rate (after the potential warmup period) to use."
+ )
+ parser.add_argument(
+ "--warmup_ratio",
+ type=float,
+ default=0.3,
+ help="Ratio of warmup steps against total training steps."
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.1,
+ help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="A seed for reproducible training."
+ )
+
+ args = parser.parse_args()
+ return args
+
+def parse_benchmark_args():
+
+ parser = get_default_parser()
+
+ parser.add_argument(
+ "--model_name_or_path",
+ type=str,
+ default="google/vit-base-patch16-224",
+ help="Path to a pretrained model or model identifier from huggingface.co/models."
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=8,
+ help="Batch size (per dp group) for the training dataloader."
+ )
+ parser.add_argument(
+ "--num_labels",
+ type=int,
+ default=10,
+ help="Number of labels for classification."
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use."
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.0,
+ help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=20,
+ help="Total number of training steps to perform."
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="A seed for reproducible training."
+ )
+ parser.add_argument(
+ "--mem_cap",
+ type=int,
+ default=0,
+ help="Limit on the usage of space for each GPU (in GB)."
+ )
+ args = parser.parse_args()
+
+ return args
\ No newline at end of file
diff --git a/examples/images/vit/configs/vit_1d_tp2.py b/examples/images/vit/configs/vit_1d_tp2.py
deleted file mode 100644
index fbf399f2e50d..000000000000
--- a/examples/images/vit/configs/vit_1d_tp2.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from colossalai.amp import AMP_TYPE
-
-# hyperparameters
-# BATCH_SIZE is as per GPU
-# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 256
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 300
-WARMUP_EPOCHS = 32
-
-# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 384
-DEPTH = 12
-NUM_HEADS = 6
-MLP_RATIO = 4
-NUM_CLASSES = 1000
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
-
-USE_DDP = True
-TP_WORLD_SIZE = 2
-TP_TYPE = 'row'
-parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
-
-fp16 = dict(mode=AMP_TYPE.NAIVE)
-clip_grad_norm = 1.0
-gradient_accumulation = 8
-
-LOG_PATH = "./log"
diff --git a/examples/images/vit/configs/vit_1d_tp2_ci.py b/examples/images/vit/configs/vit_1d_tp2_ci.py
deleted file mode 100644
index e491e4ada45e..000000000000
--- a/examples/images/vit/configs/vit_1d_tp2_ci.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from colossalai.amp import AMP_TYPE
-
-# hyperparameters
-# BATCH_SIZE is as per GPU
-# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 8
-LEARNING_RATE = 3e-3
-WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 3
-WARMUP_EPOCHS = 1
-
-# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 32
-DEPTH = 2
-NUM_HEADS = 4
-MLP_RATIO = 4
-NUM_CLASSES = 10
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
-
-USE_DDP = True
-TP_WORLD_SIZE = 2
-TP_TYPE = 'row'
-parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
-
-fp16 = dict(mode=AMP_TYPE.NAIVE)
-clip_grad_norm = 1.0
-gradient_accumulation = 2
-
-LOG_PATH = "./log_ci"
diff --git a/examples/images/vit/data.py b/examples/images/vit/data.py
new file mode 100644
index 000000000000..00fde707b173
--- /dev/null
+++ b/examples/images/vit/data.py
@@ -0,0 +1,32 @@
+import torch
+from torch.utils.data import Dataset
+from datasets import load_dataset
+
+class BeansDataset(Dataset):
+
+ def __init__(self, image_processor, split='train'):
+
+ super().__init__()
+ self.image_processor = image_processor
+ self.ds = load_dataset('beans')[split]
+ self.label_names = self.ds.features['labels'].names
+ self.num_labels = len(self.label_names)
+ self.inputs = []
+ for example in self.ds:
+ self.inputs.append(self.process_example(example))
+
+ def __len__(self):
+ return len(self.inputs)
+
+ def __getitem__(self, idx):
+ return self.inputs[idx]
+
+ def process_example(self, example):
+        # avoid shadowing the built-in `input`
+        processed = self.image_processor(example['image'], return_tensors='pt')
+        processed['labels'] = example['labels']
+        return processed
+
+
+def beans_collator(batch):
+ return {'pixel_values': torch.cat([data['pixel_values'] for data in batch], dim=0),
+ 'labels': torch.tensor([data['labels'] for data in batch], dtype=torch.int64)}
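+
+
+# Example usage (a minimal sketch; the checkpoint name is illustrative):
+#   from torch.utils.data import DataLoader
+#   from transformers import ViTImageProcessor
+#   processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
+#   train_set = BeansDataset(processor, split='train')
+#   loader = DataLoader(train_set, batch_size=8, collate_fn=beans_collator)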
diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt
index 1f69794ebe70..edad87ca380f 100644
--- a/examples/images/vit/requirements.txt
+++ b/examples/images/vit/requirements.txt
@@ -1,8 +1,6 @@
colossalai >= 0.1.12
torch >= 1.8.1
numpy>=1.24.1
-timm>=0.6.12
-titans>=0.0.7
tqdm>=4.61.2
-transformers>=4.25.1
-nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
+transformers>=4.20.0
+datasets
\ No newline at end of file
diff --git a/examples/images/vit/run.sh b/examples/images/vit/run.sh
deleted file mode 100644
index 84fe58f11a6a..000000000000
--- a/examples/images/vit/run.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-export DATA=/data/scratch/imagenet/tf_records
-export OMP_NUM_THREADS=4
-
-# resume
-# CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
-# --nproc_per_node 4 train.py \
-# --config configs/vit_1d_tp2.py \
-# --resume_from checkpoint/epoch_10 \
-# --master_port 29598 | tee ./out 2>&1
-
-# train
-CUDA_VISIBLE_DEVICES=4,5,6,7 colossalai run \
---nproc_per_node 4 train.py \
---config configs/vit_1d_tp2.py \
---master_port 29598 | tee ./out 2>&1
diff --git a/examples/images/vit/run_benchmark.sh b/examples/images/vit/run_benchmark.sh
new file mode 100644
index 000000000000..2487bf81ee2b
--- /dev/null
+++ b/examples/images/vit/run_benchmark.sh
@@ -0,0 +1,27 @@
+set -xe
+pip install -r requirements.txt
+
+export BS=8
+export MEMCAP=0
+export GPUNUM=1
+
+for BS in 8 32 128
+do
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
+do
+for GPUNUM in 1 4
+do
+
+MODEL_PATH="google/vit-base-patch16-224"
+torchrun \
+ --standalone \
+ --nproc_per_node ${GPUNUM} \
+ vit_benchmark.py \
+ --model_name_or_path ${MODEL_PATH} \
+ --mem_cap ${MEMCAP} \
+ --plugin ${PLUGIN} \
+ --batch_size ${BS}
+
+done
+done
+done
diff --git a/examples/images/vit/run_demo.sh b/examples/images/vit/run_demo.sh
new file mode 100644
index 000000000000..2d140dd6e423
--- /dev/null
+++ b/examples/images/vit/run_demo.sh
@@ -0,0 +1,44 @@
+set -xe
+pip install -r requirements.txt
+
+# model name or path
+MODEL="google/vit-base-patch16-224"
+
+# path for saving model
+OUTPUT_PATH="./output_model.bin"
+
+# plugin(training strategy)
+# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
+PLUGIN="gemini"
+
+# number of gpus to use
+GPUNUM=4
+
+# batch size per gpu
+BS=16
+
+# learning rate
+LR="2e-4"
+
+# number of epoch
+EPOCH=3
+
+# weight decay
+WEIGHT_DECAY=0.05
+
+# ratio of warmup steps
+WARMUP_RATIO=0.3
+
+# run the script for demo
+torchrun \
+ --standalone \
+ --nproc_per_node ${GPUNUM} \
+ vit_train_demo.py \
+ --model_name_or_path ${MODEL} \
+ --output_path ${OUTPUT_PATH} \
+ --plugin ${PLUGIN} \
+ --batch_size ${BS} \
+ --num_epoch ${EPOCH} \
+ --learning_rate ${LR} \
+ --weight_decay ${WEIGHT_DECAY} \
+ --warmup_ratio ${WARMUP_RATIO}
diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh
index 41d25ee23521..8606015c0397 100644
--- a/examples/images/vit/test_ci.sh
+++ b/examples/images/vit/test_ci.sh
@@ -1,9 +1,19 @@
-export OMP_NUM_THREADS=4
-
+set -xe
pip install -r requirements.txt
-# train
-colossalai run \
---nproc_per_node 4 train.py \
---config configs/vit_1d_tp2_ci.py \
---dummy_data
+BS=8
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
+do
+for GPUNUM in 1 4
+do
+
+torchrun \
+ --standalone \
+ --nproc_per_node ${GPUNUM} \
+ vit_benchmark.py \
+ --model_name_or_path "google/vit-base-patch16-224" \
+ --plugin ${PLUGIN} \
+ --batch_size ${BS}
+
+done
+done
diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py
deleted file mode 100644
index c0ae35bca871..000000000000
--- a/examples/images/vit/test_vit.py
+++ /dev/null
@@ -1,160 +0,0 @@
-import os
-import random
-
-import numpy as np
-import pytest
-import torch
-from torch.nn.parallel import DistributedDataParallel as DDP
-from vit import get_training_components
-
-import colossalai
-from colossalai.context import ParallelMode
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ColoInitContext
-
-
-def set_seed(seed):
- random.seed(seed)
- os.environ['PYTHONHASHSEED'] = str(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.backends.cudnn.deterministic = True
-
-
-def tensor_equal(A, B):
- return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
-
-
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
- assert tensor.ndim == shard.ndim
- if tensor.shape == shard.shape:
- return tensor_equal(tensor, shard)
- else:
- dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
- if dims_not_eq.numel() == 1:
- # 1D shard
- dim = dims_not_eq.item()
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
- else:
- raise
-
-
-# Only for all Linear, it's 1d_row split because Linear will be transposed when calculating.
-# But for other layers, it's 1d_col split.
-# Layernorm is not supported for now.
-# patch_embeddings.projection has nn.Conv2d
-# https://github.com/huggingface/transformers/blob/dcb08b99f44919425f8ba9be9ddcc041af8ec25e/src/transformers/models/vit/modeling_vit.py#L182
-def init_1d_row_for_linear_weight_spec(model, world_size: int):
- pg = ProcessGroup(tp_degree=world_size)
- spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- with DistSpecManager.no_grad():
- for n, p in model.named_parameters():
- if 'weight' in n and 'layernorm' not in n and 'embeddings.patch_embeddings.projection.weight' not in n:
- p.set_process_group(pg)
- p.set_tensor_spec(*spec)
-
-
-# Similarly, it's col split for Linear but row split for others.
-def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
- pg = ProcessGroup(tp_degree=world_size)
- spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- with DistSpecManager.no_grad():
- for n, p in model.named_parameters():
- if ('weight' in n
- or 'bias' in n) and 'layernorm' not in n and 'embeddings.patch_embeddings.projection' not in n:
- p.set_process_group(pg)
- p.set_tensor_spec(*spec)
-
-
-def check_param_equal(model, torch_model):
- for p, torch_p in zip(model.parameters(), torch_model.parameters()):
- assert tensor_shard_equal(torch_p, p)
-
-
-def check_grad_equal(model, torch_model):
- for p, torch_p in zip(model.parameters(), torch_model.parameters()):
- if (torch_p.grad.shape == p.grad.shape):
- assert torch.allclose(torch_p.grad, p.grad, rtol=1e-3, atol=2.0) == True
- else:
- dims_not_eq = torch.nonzero(torch.tensor(torch_p.grad.shape) != torch.tensor(p.grad.shape))
- dim = dims_not_eq.item()
- world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
- rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- assert torch.allclose(torch_p.grad.chunk(world_size, dim)[rank], p.grad, rtol=1e-3, atol=2.0) == True
-
-
-def run_vit(init_spec_func, use_ddp):
- model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_training_components()
- with ColoInitContext(device=get_current_device()):
- model = model_builder()
- model = model.cuda()
- torch_model = model_builder().cuda()
- if use_ddp:
- model = ColoDDP(model)
- torch_model = DDP(torch_model,
- device_ids=[gpc.get_global_rank()],
- process_group=gpc.get_group(ParallelMode.DATA))
- for torch_p, p in zip(torch_model.parameters(), model.parameters()):
- torch_p.data.copy_(p)
-
- world_size = torch.distributed.get_world_size()
- init_spec_func(model, world_size)
-
- check_param_equal(model, torch_model)
- model.train()
- torch_model.train()
- set_seed(gpc.get_local_rank(ParallelMode.DATA))
-
- optimizer = optimizer_class(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
- torch_optimizer = optimizer_class(torch_model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
-
- for i, image_dict in enumerate(train_dataloader):
- if use_ddp:
- model.zero_grad()
- else:
- optimizer.zero_grad()
- logits = model(image_dict['pixel_values'])
- torch_logits = torch_model(image_dict['pixel_values'])
- assert tensor_equal(torch_logits.logits, logits.logits)
- loss = criterion(logits.logits, image_dict['label'])
- torch_loss = criterion(torch_logits.logits, image_dict['label'])
- if use_ddp:
- model.backward(loss)
- else:
- loss.backward()
- torch_loss.backward()
- check_grad_equal(model, torch_model)
- optimizer.step()
- torch_optimizer.step()
- check_param_equal(model, torch_model)
- break
-
-
-def run_dist(rank, world_size, port, use_ddp):
- if use_ddp and world_size == 1:
- return
- tp_world_size = world_size // 2 if use_ddp else world_size
- config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
- colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- run_vit(init_1d_row_for_linear_weight_spec, use_ddp)
- run_vit(init_1d_col_for_linear_weight_bias_spec, use_ddp)
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.parametrize('use_ddp', [False, True])
-@rerun_if_address_is_in_use()
-def test_vit(world_size, use_ddp):
- spawn(run_dist, world_size, use_ddp=use_ddp)
-
-
-if __name__ == '__main__':
- test_vit(1, False)
diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py
deleted file mode 100644
index b42cf2bedc6b..000000000000
--- a/examples/images/vit/train.py
+++ /dev/null
@@ -1,174 +0,0 @@
-import os
-
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-from timm.models.vision_transformer import _create_vision_transformer
-from titans.dataloader.imagenet import build_dali_imagenet
-from tqdm import tqdm
-from vit import DummyDataLoader
-
-import colossalai
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn import CrossEntropyLoss
-from colossalai.nn._ops import *
-from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel.data_parallel import ColoDDP
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext
-
-
-def init_1d_row_for_linear_weight_spec(model, world_size: int):
- pg = ProcessGroup(tp_degree=world_size)
- spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- with DistSpecManager.no_grad():
- for n, p in model.named_parameters():
- if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n:
- p.set_process_group(pg)
- p.set_tensor_spec(*spec)
-
-
-# Similarly, it's col split for Linear but row split for others.
-def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
- pg = ProcessGroup(tp_degree=world_size)
- spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- with DistSpecManager.no_grad():
- for n, p in model.named_parameters():
- if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n
- and 'patch_embed.proj.bias' not in n):
- p.set_process_group(pg)
- p.set_tensor_spec(*spec)
-
-
-def init_spec_func(model, tp_type):
- world_size = torch.distributed.get_world_size()
- if tp_type == 'row':
- init_1d_row_for_linear_weight_spec(model, world_size)
- elif tp_type == 'col':
- init_1d_col_for_linear_weight_bias_spec(model, world_size)
- else:
- raise NotImplemented
-
-
-def train_imagenet():
-
- parser = colossalai.get_default_parser()
- parser.add_argument('--resume_from', default=False, action='store_true')
- parser.add_argument('--dummy_data', default=False, action='store_true')
-
- args = parser.parse_args()
- colossalai.launch_from_torch(config=args.config)
- use_ddp = gpc.config.USE_DDP
-
- disable_existing_loggers()
-
- logger = get_dist_logger()
- if hasattr(gpc.config, 'LOG_PATH'):
- if gpc.get_global_rank() == 0:
- log_path = gpc.config.LOG_PATH
- if not os.path.exists(log_path):
- os.mkdir(log_path)
- logger.log_to_file(log_path)
-
- logger.info('Build data loader', ranks=[0])
- if not args.dummy_data:
- root = os.environ['DATA']
- train_dataloader, test_dataloader = build_dali_imagenet(root,
- train_batch_size=gpc.config.BATCH_SIZE,
- test_batch_size=gpc.config.BATCH_SIZE)
- else:
- train_dataloader = DummyDataLoader(length=10,
- batch_size=gpc.config.BATCH_SIZE,
- category=gpc.config.NUM_CLASSES,
- image_size=gpc.config.IMG_SIZE,
- return_dict=False)
- test_dataloader = DummyDataLoader(length=5,
- batch_size=gpc.config.BATCH_SIZE,
- category=gpc.config.NUM_CLASSES,
- image_size=gpc.config.IMG_SIZE,
- return_dict=False)
-
- logger.info('Build model', ranks=[0])
-
- model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
- patch_size=gpc.config.PATCH_SIZE,
- embed_dim=gpc.config.HIDDEN_SIZE,
- depth=gpc.config.DEPTH,
- num_heads=gpc.config.NUM_HEADS,
- mlp_ratio=gpc.config.MLP_RATIO,
- num_classes=gpc.config.NUM_CLASSES,
- drop_rate=0.1,
- attn_drop_rate=0.1,
- weight_init='jax')
-
- with ColoInitContext(device=get_current_device()):
- model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs)
- init_spec_func(model, gpc.config.TP_TYPE)
-
- world_size = torch.distributed.get_world_size()
- model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size))
- logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0])
- optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
-
- criterion = CrossEntropyLoss()
- lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
- total_steps=gpc.config.NUM_EPOCHS,
- warmup_steps=gpc.config.WARMUP_EPOCHS)
-
- start_epoch = 0
- if args.resume_from:
- load_model = torch.load(args.resume_from + '_model.pth')
- start_epoch = load_model['epoch']
- model.load_state_dict(load_model['model'])
- load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank()))
- optimizer.load_state_dict(load_optim['optim'])
-
- for epoch in range(start_epoch, gpc.config.NUM_EPOCHS):
- model.train()
- for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
- x, y = x.cuda(), y.cuda()
- output = model(x)
- loss = criterion(output, y)
- loss = loss / gpc.config.gradient_accumulation
- if use_ddp:
- model.backward(loss)
- else:
- loss.backward()
- if (index + 1) % gpc.config.gradient_accumulation == 0:
- optimizer.step()
- if use_ddp:
- model.zero_grad()
- else:
- optimizer.zero_grad()
-
- logger.info(
- f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}",
- ranks=[0])
-
- model.eval()
- test_loss = 0
- correct = 0
- test_sum = 0
- with torch.no_grad():
- for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False):
- x, y = x.cuda(), y.cuda()
- output = model(x)
- test_loss += F.cross_entropy(output, y, reduction='sum').item()
- pred = output.argmax(dim=1, keepdim=True)
- correct += pred.eq(y.view_as(pred)).sum().item()
- test_sum += y.size(0)
-
- test_loss /= test_sum
- logger.info(
- f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})",
- ranks=[0])
-
- lr_scheduler.step()
-
-
-if __name__ == '__main__':
- train_imagenet()
diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py
deleted file mode 100644
index f22e8ea90cec..000000000000
--- a/examples/images/vit/vit.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from abc import ABC, abstractmethod
-
-import torch
-import torch.nn as nn
-from transformers import ViTConfig, ViTForImageClassification
-
-from colossalai.utils.cuda import get_current_device
-
-
-class DummyDataGenerator(ABC):
-
- def __init__(self, length=10):
- self.length = length
-
- @abstractmethod
- def generate(self):
- pass
-
- def __iter__(self):
- self.step = 0
- return self
-
- def __next__(self):
- if self.step < self.length:
- self.step += 1
- return self.generate()
- else:
- raise StopIteration
-
- def __len__(self):
- return self.length
-
-
-class DummyDataLoader(DummyDataGenerator):
-
- def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
- super().__init__(length)
- self.batch_size = batch_size
- self.channel = channel
- self.category = category
- self.image_size = image_size
- self.return_dict = return_dict
-
- def generate(self):
- image_dict = {}
- image_dict['pixel_values'] = torch.rand(
- self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
- image_dict['label'] = torch.randint(self.category, (self.batch_size,),
- dtype=torch.int64,
- device=get_current_device())
- if not self.return_dict:
- return image_dict['pixel_values'], image_dict['label']
- return image_dict
-
-
-class ViTCVModel(nn.Module):
-
- def __init__(self,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- image_size=224,
- patch_size=16,
- num_channels=3,
- num_labels=8,
- checkpoint=False):
- super().__init__()
- self.checkpoint = checkpoint
- self.model = ViTForImageClassification(
- ViTConfig(hidden_size=hidden_size,
- num_hidden_layers=num_hidden_layers,
- num_attention_heads=num_attention_heads,
- image_size=image_size,
- patch_size=patch_size,
- num_channels=num_channels,
- num_labels=num_labels))
- if checkpoint:
- self.model.gradient_checkpointing_enable()
-
- def forward(self, pixel_values):
- return self.model(pixel_values=pixel_values)
-
-
-def vit_base_s(checkpoint=True):
- return ViTCVModel(checkpoint=checkpoint)
-
-
-def vit_base_micro(checkpoint=True):
- return ViTCVModel(hidden_size=32, num_hidden_layers=2, num_attention_heads=4, checkpoint=checkpoint)
-
-
-def get_training_components():
- trainloader = DummyDataLoader()
- testloader = DummyDataLoader()
- return vit_base_micro, trainloader, testloader, torch.optim.Adam, torch.nn.functional.cross_entropy
diff --git a/examples/images/vit/vit_benchmark.py b/examples/images/vit/vit_benchmark.py
new file mode 100644
index 000000000000..11d480bba65f
--- /dev/null
+++ b/examples/images/vit/vit_benchmark.py
@@ -0,0 +1,129 @@
+import time
+
+import torch
+import transformers
+from transformers import ViTConfig, ViTForImageClassification
+import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_benchmark_args
+
+def format_num(num: int, bytes=False):
+ """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
+ factor = 1024 if bytes else 1000
+ suffix = "B" if bytes else ""
+ for unit in ["", " K", " M", " G", " T", " P"]:
+ if num < factor:
+ return f"{num:.2f}{unit}{suffix}"
+ num /= factor
+
+
+def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
+ pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float)
+ labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64)
+ return pixel_values, labels
+
+
+def colo_memory_cap(size_in_GB):
+ from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+ cuda_capacity = colo_device_memory_capacity(get_current_device())
+ if size_in_GB * (1024**3) < cuda_capacity:
+ colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+ print(f"Limiting GPU memory usage to {size_in_GB} GB")
+
+
+def main():
+
+ args = parse_benchmark_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+ world_size = coordinator.world_size
+
+ # Manage loggers
+ disable_existing_loggers()
+ logger = get_dist_logger()
+ if coordinator.is_master():
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ # Whether to set limit on memory capacity
+ if args.mem_cap > 0:
+ colo_memory_cap(args.mem_cap)
+
+ # Build ViT model
+ config = ViTConfig.from_pretrained(args.model_name_or_path)
+ model = ViTForImageClassification(config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Set plugin
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(device=get_current_device(),
+ placement_policy='cpu',
+ pin_memory=True,
+ strict_ddp_mode=True,
+ initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+
+ # Start training.
+ logger.info(f"Start testing", ranks=[0])
+ progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
+
+ torch.cuda.synchronize()
+ model.train()
+ start_time = time.time()
+
+ for _ in range(args.max_train_steps):
+
+ pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
+ optimizer.zero_grad()
+ outputs = model(pixel_values=pixel_values, labels=labels)
+ loss = outputs['loss']
+ booster.backward(loss, optimizer)
+ optimizer.step()
+
+ torch.cuda.synchronize()
+ progress_bar.update(1)
+
+ # Compute Statistics
+ end_time = time.time()
+ throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
+ max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
+
+ logger.info(f"Testing finished, "
+ f"batch size per gpu: {args.batch_size}, "
+ f"plugin: {args.plugin}, "
+ f"throughput: {throughput}, "
+ f"maximum memory usage per gpu: {max_mem}.",
+ ranks=[0])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/images/vit/vit_train_demo.py b/examples/images/vit/vit_train_demo.py
new file mode 100644
index 000000000000..3a739f10b5d0
--- /dev/null
+++ b/examples/images/vit/vit_train_demo.py
@@ -0,0 +1,177 @@
+import torch
+import torch.distributed as dist
+import transformers
+from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
+from tqdm import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_demo_args
+from data import BeansDataset, beans_collator
+
+
+def move_to_cuda(batch, device):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+
+ torch.cuda.synchronize()
+ model.train()
+
+ with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
+
+ for batch in pbar:
+
+            # Forward
+ optimizer.zero_grad()
+ batch = move_to_cuda(batch, torch.cuda.current_device())
+ outputs = model(**batch)
+ loss = outputs['loss']
+
+ # Backward
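+            # (booster.backward() routes the backward pass through the chosen plugin,
+            # e.g. applying loss scaling under fp16/ZeRO)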
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+
+ # Print batch loss
+ pbar.set_postfix({'loss': loss.item()})
+
+
+@torch.no_grad()
+def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
+
+ model.eval()
+ accum_loss = torch.zeros(1, device=get_current_device())
+ total_num = torch.zeros(1, device=get_current_device())
+ accum_correct = torch.zeros(1, device=get_current_device())
+
+ for batch in eval_dataloader:
+ batch = move_to_cuda(batch, torch.cuda.current_device())
+ outputs = model(**batch)
+ val_loss, logits = outputs[:2]
+ accum_loss += (val_loss / len(eval_dataloader))
+ if num_labels > 1:
+ preds = torch.argmax(logits, dim=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ labels = batch["labels"]
+ total_num += batch["labels"].shape[0]
+ accum_correct += (torch.sum(preds == labels))
+
+ dist.all_reduce(accum_loss)
+ dist.all_reduce(total_num)
+ dist.all_reduce(accum_correct)
+ avg_loss = "{:.4f}".format(accum_loss.item())
+ accuracy = "{:.4f}".format(accum_correct.item() / total_num.item())
+ if coordinator.is_master():
+ print(f"Evaluation result for epoch {epoch + 1}: \
+ average_loss={avg_loss}, \
+ accuracy={accuracy}.")
+
+
+def main():
+
+ args = parse_demo_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+ world_size = coordinator.world_size
+
+ # Manage loggers
+ disable_existing_loggers()
+ logger = get_dist_logger()
+ if coordinator.is_master():
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+ # Prepare Dataset
+ image_processor = ViTImageProcessor.from_pretrained(args.model_name_or_path)
+ train_dataset = BeansDataset(image_processor, split='train')
+ eval_dataset = BeansDataset(image_processor, split='validation')
+
+
+ # Load pretrained ViT model
+ config = ViTConfig.from_pretrained(args.model_name_or_path)
+ config.num_labels = train_dataset.num_labels
+ config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
+ config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
+ model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
+ config=config,
+ ignore_mismatched_sizes=True)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Set plugin
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(device=get_current_device(),
+ placement_policy='cpu',
+ pin_memory=True,
+ strict_ddp_mode=True,
+ initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+ # Prepare dataloader
+ train_dataloader = plugin.prepare_dataloader(train_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=beans_collator)
+ eval_dataloader = plugin.prepare_dataloader(eval_dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=beans_collator)
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
+
+ # Set lr scheduler
+ total_steps = len(train_dataloader) * args.num_epoch
+ num_warmup_steps = int(args.warmup_ratio * total_steps)
+ lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
+                                           total_steps=total_steps,
+ warmup_steps=num_warmup_steps)
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
+ optimizer=optimizer,
+ dataloader=train_dataloader,
+ lr_scheduler=lr_scheduler)
+
+ # Finetuning
+ logger.info(f"Start finetuning", ranks=[0])
+ for epoch in range(args.num_epoch):
+ train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+ evaluate_model(epoch, model, eval_dataloader, eval_dataset.num_labels, coordinator)
+ logger.info(f"Finish finetuning", ranks=[0])
+
+ # Save the finetuned model
+ booster.save_model(model, args.output_path)
+ logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/examples/language/bert/README.md b/examples/language/bert/README.md
new file mode 100644
index 000000000000..81c3f03fffca
--- /dev/null
+++ b/examples/language/bert/README.md
@@ -0,0 +1,34 @@
+## Overview
+
+This directory includes two parts: using the Booster API to finetune HuggingFace Bert and Albert models, and benchmarking Bert and Albert models with different Booster plugins.
+
+## Finetune
+```
+bash test_ci.sh
+```
+
+## Benchmark
+```
+bash benchmark.sh
+```
+
+The benchmark currently reports these metrics: peak CUDA memory usage, throughput, and the number of model parameters. If you need custom metrics, you can add them to `benchmark_utils.py`.
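+
+A single configuration can also be benchmarked directly, mirroring one iteration of the loop in `benchmark.sh`:
+```
+torchrun --standalone --nproc_per_node 2 benchmark.py --plugin "gemini" --model_type "bert"
+```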
+
+## Results
+
+### Bert
+
+|                | max CUDA memory | throughput (sample/s) | params |
+| :------------- | --------------: | :-------------------: | :----: |
+| ddp            | 21.44 GB        | 3.0                   | 82M    |
+| ddp_fp16       | 16.26 GB        | 11.3                  | 82M    |
+| gemini         | 11.0 GB         | 12.9                  | 82M    |
+| low_level_zero | 11.29 GB        | 14.7                  | 82M    |
+
+### AlBert
+|                | max CUDA memory | throughput (sample/s) | params |
+| :------------- | --------------: | :-------------------: | :----: |
+| ddp            | OOM             |                       |        |
+| ddp_fp16       | OOM             |                       |        |
+| gemini         | 69.39 GB        | 1.3                   | 208M   |
+| low_level_zero | 56.89 GB        | 1.4                   | 208M   |
\ No newline at end of file
diff --git a/examples/language/bert/benchmark.py b/examples/language/bert/benchmark.py
new file mode 100644
index 000000000000..ae8b2269a534
--- /dev/null
+++ b/examples/language/bert/benchmark.py
@@ -0,0 +1,174 @@
+import argparse
+
+import torch
+from benchmark_utils import benchmark
+from torch.utils.data import DataLoader, Dataset
+from transformers import (
+ AlbertConfig,
+ AlbertForSequenceClassification,
+ BertConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
+
+# ==============================
+# Prepare Hyperparameters
+# ==============================
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+SEQ_LEN = 512
+VOCAB_SIZE = 1000
+NUM_LABELS = 10
+DATASET_LEN = 1000
+
+
+class RandintDataset(Dataset):
+
+ def __init__(self, dataset_length: int, sequence_length: int, vocab_size: int, n_class: int):
+
+ self._sequence_length = sequence_length
+ self._vocab_size = vocab_size
+ self._n_class = n_class
+ self._dataset_length = dataset_length
+ self._datas = torch.randint(
+ low=0,
+ high=self._vocab_size,
+ size=(self._dataset_length, self._sequence_length,),
+ dtype=torch.long,
+ )
+ self._labels = torch.randint(low=0, high=self._n_class, size=(self._dataset_length, 1), dtype=torch.long)
+
+ def __len__(self):
+ return self._dataset_length
+
+ def __getitem__(self, idx):
+ return self._datas[idx], self._labels[idx]
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ help="plugin to use")
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="bert",
+ help="bert or albert",
+ )
+
+ args = parser.parse_args()
+
+ # ==============================
+ # Launch Distributed Environment
+ # ==============================
+ colossalai.launch_from_torch(config={}, seed=42)
+ coordinator = DistCoordinator()
+
+ # local_batch_size = BATCH_SIZE // coordinator.world_size
+ lr = LEARNING_RATE * coordinator.world_size
+
+ # ==============================
+ # Instantiate Plugin and Booster
+ # ==============================
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
+
+ # ==============================
+ # Prepare Dataloader
+ # ==============================
+
+ train_dataset = RandintDataset(dataset_length=DATASET_LEN,
+ sequence_length=SEQ_LEN,
+ vocab_size=VOCAB_SIZE,
+ n_class=NUM_LABELS)
+ train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
+
+ # ====================================
+ # Prepare model, optimizer
+ # ====================================
+ # bert pretrained model
+
+ if args.model_type == "bert":
+ cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
+ model = BertForSequenceClassification(cfg)
+ elif args.model_type == "albert":
+ cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
+ model = AlbertForSequenceClassification(cfg)
+ else:
+        raise RuntimeError(f"Unsupported model type: {args.model_type}")
+
+ # optimizer
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+ # lr scheduler
+ total_steps = len(train_dataloader) * NUM_EPOCHS
+ num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+ )
+
+ # criterion
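+    # (HF sequence-classification models return (loss, logits, ...) when labels are passed, so the loss is outputs[0])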
+ criterion = lambda inputs: inputs[0]
+
+ # ==============================
+ # Boost with ColossalAI
+ # ==============================
+ model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+
+ # ==============================
+ # Benchmark model
+ # ==============================
+
+ results = benchmark(model,
+ booster,
+ optimizer,
+ lr_scheduler,
+ train_dataloader,
+ criterion=criterion,
+ epoch_num=NUM_EPOCHS)
+
+ coordinator.print_on_master(results)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/bert/benchmark.sh b/examples/language/bert/benchmark.sh
new file mode 100755
index 000000000000..9453d1373f2f
--- /dev/null
+++ b/examples/language/bert/benchmark.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+set -xe
+
+pip install -r requirements.txt
+
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+ torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "bert"
+ torchrun --standalone --nproc_per_node 2 benchmark.py --plugin $plugin --model_type "albert"
+done
diff --git a/examples/language/bert/benchmark_utils.py b/examples/language/bert/benchmark_utils.py
new file mode 100644
index 000000000000..886017a41826
--- /dev/null
+++ b/examples/language/bert/benchmark_utils.py
@@ -0,0 +1,146 @@
+import inspect
+from logging import getLogger
+from time import time
+from typing import Callable
+
+import torch
+import yaml
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
+logger = getLogger("colossalai-booster-benchmark")
+_INVALID = float("nan")
+
+
+def format_num(num: int, bytes=False):
+    """Scale a number to a human-readable string, e.g. 1253656 with bytes=True => '1.20 MB'."""
+ factor = 1024 if bytes else 1000
+ suffix = "B" if bytes else ""
+ for unit in ["", " K", " M", " G", " T", " P"]:
+ if num < factor:
+ return f"{num:.2f}{unit}{suffix}"
+ num /= factor
+
+
+def _is_valid(val):
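+    # NaN is the only value that compares unequal to itself, so this rejects the _INVALID sentinel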
+ return val == val
+
+
+def get_call_arg_names(module_or_fn):
+ if isinstance(module_or_fn, torch.nn.Module):
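+        # drop the implicit "self" argument of the bound forward method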
+ return inspect.getfullargspec(module_or_fn.forward)[0][1:]
+ return inspect.getfullargspec(module_or_fn)[0]
+
+
+def measure_params(model):
+ num_params = _INVALID
+
+ try:
+ num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ except AttributeError as e:
+ logger.error(f"Unable to measure model params due to error: {e}")
+
+ return num_params
+
+
+def warm_up(
+ model,
+ booster,
+ dataloader,
+ criterion,
+ optimizer,
+ lr_scheduler,
+ num_runs=10,
+):
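+    # run a few full optimizer steps so CUDA kernels and allocator caches are initialized before timing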
+ for i, data in enumerate(dataloader):
+        if i >= num_runs:
+ break
+ inputs, labels = data[0].cuda(), data[1].cuda()
+ outputs = model(inputs, labels=labels)
+ loss = criterion(outputs)
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+
+def fmt(d: dict):
+ return yaml.dump(d)
+
+
+def benchmark(
+ model: torch.nn.Module,
+ booster: Booster,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: LRScheduler,
+ dataloader: DataLoader,
+ criterion: Callable = None,
+ warm_up_fn=warm_up,
+ epoch_num: int = 3,
+ batch_size: int = 32,
+ warm_up_steps: int = 3,
+):
+ results = {}
+ model_device = torch.cuda.current_device()
+
+ # Warm up
+ warm_up_fn(
+ model,
+ booster,
+ dataloader,
+ criterion,
+ optimizer,
+ lr_scheduler,
+ num_runs=warm_up_steps,
+ )
+ # Measure params
+ params = measure_params(model)
+ if _is_valid(params):
+ results["params"] = format_num(params)
+ logger.info(f"Model parameters: {params} ({format_num(params)})")
+
+ # Measure Allocated Memory and Throughput
+ memory = {}
+ throughput = {}
+ torch.cuda.reset_peak_memory_stats(device=model_device)
+ pre_mem = torch.cuda.memory_allocated(device=model_device)
+
+ start_time = time()
+
+ for epoch in range(epoch_num):
+ with tqdm(dataloader, desc=f'Epoch [{epoch + 1}/{epoch_num}]',
+ disable=not DistCoordinator().is_master()) as pbar:
+ for data in pbar:
+ inputs, labels = data[0].cuda(), data[1].cuda()
+ outputs = model(inputs, labels=labels)
+ loss = criterion(outputs)
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ end_time = time()
+
+ all_sample = epoch_num * len(dataloader)
+
+ post_mem = torch.cuda.memory_allocated(device=model_device)
+ max_mem = torch.cuda.max_memory_allocated(device=model_device)
+
+ memory[f"batch_size_{batch_size}"] = {
+ "cuda_pre_training_bytes": format_num(pre_mem, bytes=True),
+ "cuda_max_training_bytes": format_num(max_mem, bytes=True),
+ "cuda_post_training_bytes": format_num(post_mem, bytes=True),
+ }
+ logger.info(fmt({f"Memory results (batch_size={batch_size})": memory[f"batch_size_{batch_size}"]}))
+
+    throughput[f"batch_size_{batch_size}"] = {
+        "throughput": "{:.1f}".format(all_sample * DistCoordinator().world_size / (end_time - start_time))
+    }
+ logger.info(fmt({f"Throughput results (batch_size={batch_size})": throughput[f"batch_size_{batch_size}"]}))
+
+ results["throughput"] = throughput
+ results["memory"] = memory
+
+ return results
diff --git a/examples/language/bert/data.py b/examples/language/bert/data.py
new file mode 100644
index 000000000000..981cedcca8c2
--- /dev/null
+++ b/examples/language/bert/data.py
@@ -0,0 +1,127 @@
+import datasets
+from transformers import AutoTokenizer, PreTrainedTokenizer
+
+from colossalai.booster.plugin.dp_plugin_base import DPPluginBase
+
+
+class GLUEDataBuilder:
+
+ task_text_field_map = {
+ "cola": ["sentence"],
+ "sst2": ["sentence"],
+ "mrpc": ["sentence1", "sentence2"],
+ "qqp": ["question1", "question2"],
+ "stsb": ["sentence1", "sentence2"],
+ "mnli": ["premise", "hypothesis"],
+ "qnli": ["question", "sentence"],
+ "rte": ["sentence1", "sentence2"],
+ "wnli": ["sentence1", "sentence2"],
+ "ax": ["premise", "hypothesis"],
+ }
+
+ glue_task_num_labels = {
+ "cola": 2,
+ "sst2": 2,
+ "mrpc": 2,
+ "qqp": 2,
+ "stsb": 1,
+ "mnli": 3,
+ "qnli": 2,
+ "rte": 2,
+ "wnli": 2,
+ "ax": 3,
+ }
+
+ loader_columns = [
+ "datasets_idx",
+ "input_ids",
+ "token_type_ids",
+ "attention_mask",
+ "start_positions",
+ "end_positions",
+ "labels",
+ ]
+
+ def __init__(
+ self,
+ model_name_or_path: str,
+ plugin: DPPluginBase,
+ task_name: str = "mrpc",
+ max_seq_length: int = 128,
+ train_batch_size: int = 32,
+ eval_batch_size: int = 32,
+ **kwargs,
+ ):
+ super().__init__()
+ self.model_name_or_path = model_name_or_path
+ self.task_name = task_name
+ self.max_seq_length = max_seq_length
+ self.train_batch_size = train_batch_size
+ self.eval_batch_size = eval_batch_size
+ self.plugin = plugin
+
+ self.text_fields = self.task_text_field_map[task_name]
+ self.num_labels = self.glue_task_num_labels[task_name]
+ self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+ self.setup()
+
+ def setup(self):
+ self.dataset = datasets.load_dataset("glue", self.task_name)
+
+ for split in self.dataset.keys():
+ self.dataset[split] = self.dataset[split].map(
+ self.convert_to_features,
+ batched=True,
+ remove_columns=["label"],
+ )
+            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
+            self.dataset[split].set_format(type="torch", columns=self.columns)
+
+ self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
+
+ def prepare_data(self):
+ datasets.load_dataset("glue", self.task_name)
+ AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
+
+ def train_dataloader(self):
+ return self.plugin.prepare_dataloader(self.dataset["train"],
+ batch_size=self.train_batch_size,
+ shuffle=True,
+ drop_last=True)
+
+ def val_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def test_dataloader(self):
+ if len(self.eval_splits) == 1:
+ return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+ elif len(self.eval_splits) > 1:
+ return [
+ self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+ for x in self.eval_splits
+ ]
+
+ def convert_to_features(self, example_batch):
+
+ # Either encode single sentence or sentence pairs
+ if len(self.text_fields) > 1:
+ texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
+ else:
+ texts_or_text_pairs = example_batch[self.text_fields[0]]
+
+ # Tokenize the text/text pairs
+ features = self.tokenizer.batch_encode_plus(texts_or_text_pairs,
+ max_length=self.max_seq_length,
+ padding='max_length',
+ truncation=True)
+
+ # Rename label to labels to make it easier to pass to model forward
+ features["labels"] = example_batch["label"]
+
+ return features
diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py
new file mode 100644
index 000000000000..b209ffde85a4
--- /dev/null
+++ b/examples/language/bert/finetune.py
@@ -0,0 +1,220 @@
+import argparse
+from typing import List, Union
+
+import evaluate
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from data import GLUEDataBuilder
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import (
+ AlbertForSequenceClassification,
+ AutoConfig,
+ BertForSequenceClassification,
+ get_linear_schedule_with_warmup,
+)
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.utils import get_current_device
+
+# ==============================
+# Prepare Hyperparameters
+# ==============================
+NUM_EPOCHS = 3
+BATCH_SIZE = 32
+LEARNING_RATE = 2.4e-5
+WEIGHT_DECAY = 0.01
+WARMUP_FRACTION = 0.1
+
+
+def move_to_cuda(batch):
+ return {k: v.cuda() for k, v in batch.items()}
+
+
+@torch.no_grad()
+def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str,
+ eval_splits: List[str], coordinator: DistCoordinator):
+ metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
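+    # distributed evaluation: each rank scores its own shard, and process_id/num_process let metric.compute() aggregate across ranks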
+ model.eval()
+
+ def evaluate_subset(dataloader: DataLoader):
+ accum_loss = torch.zeros(1, device=get_current_device())
+ for batch in dataloader:
+ batch = move_to_cuda(batch)
+ outputs = model(**batch)
+ val_loss, logits = outputs[:2]
+ accum_loss.add_(val_loss)
+
+ if num_labels > 1:
+ preds = torch.argmax(logits, axis=1)
+ elif num_labels == 1:
+ preds = logits.squeeze()
+
+ labels = batch["labels"]
+
+ metric.add_batch(predictions=preds, references=labels)
+
+        results = metric.compute()
+        # sum the per-rank mean losses; the master averages over world size below
+        dist.all_reduce(accum_loss.div_(len(dataloader)))
+        if coordinator.is_master():
+            results['loss'] = accum_loss.item() / coordinator.world_size
+        return results
+
+ if isinstance(test_dataloader, DataLoader):
+ return evaluate_subset(test_dataloader)
+ else:
+ assert len(test_dataloader) == len(eval_splits)
+ final_results = {}
+ for split, sub_loader in zip(eval_splits, test_dataloader):
+ results = evaluate_subset(sub_loader)
+ final_results.update({f'{k}_{split}': v for k, v in results.items()})
+ return final_results
+
+
+def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
+ booster: Booster, coordinator: DistCoordinator):
+ model.train()
+ with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
+ for batch in pbar:
+ # Forward pass
+ batch = move_to_cuda(batch)
+ outputs = model(**batch)
+ loss = outputs[0]
+
+ # Backward and optimize
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_scheduler.step()
+
+ # Print log info
+ pbar.set_postfix({'loss': loss.item()})
+
+
+def main():
+ # ==============================
+ # Parse Arguments
+ # ==============================
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--task', default='mrpc', help="GLUE task to run")
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ help="plugin to use")
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="bert",
+ help="bert or albert",
+ )
+ parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
+ args = parser.parse_args()
+
+ if args.model_type == 'bert':
+ model_name = "bert-base-uncased"
+ elif args.model_type == 'albert':
+ model_name = "albert-xxlarge-v2"
+ else:
+        raise RuntimeError(f"Unsupported model type: {args.model_type}")
+ # ==============================
+ # Launch Distributed Environment
+ # ==============================
+ colossalai.launch_from_torch(config={}, seed=42)
+ coordinator = DistCoordinator()
+
+ # local_batch_size = BATCH_SIZE // coordinator.world_size
+ lr = LEARNING_RATE * coordinator.world_size
+
+ # ==============================
+ # Instantiate Plugin and Booster
+ # ==============================
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+
+ booster = Booster(plugin=plugin, **booster_kwargs)
+
+ # ==============================
+ # Prepare Dataloader
+ # ==============================
+ data_builder = GLUEDataBuilder(model_name,
+ plugin,
+ args.task,
+ train_batch_size=BATCH_SIZE,
+ eval_batch_size=BATCH_SIZE)
+ train_dataloader = data_builder.train_dataloader()
+ test_dataloader = data_builder.test_dataloader()
+
+ # ====================================
+ # Prepare model, optimizer
+ # ====================================
+ # bert pretrained model
+
+ cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
+ if model_name == "bert-base-uncased":
+ model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
+ elif model_name == "albert-xxlarge-v2":
+ model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
+ else:
+        raise RuntimeError(f"Unsupported model: {model_name}")
+
+ # optimizer
+ no_decay = ["bias", "LayerNorm.weight"]
+ optimizer_grouped_parameters = [
+ {
+ "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+ "weight_decay": WEIGHT_DECAY,
+ },
+ {
+ "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)
+
+ # lr scheduler
+ total_steps = len(train_dataloader) * NUM_EPOCHS
+ num_warmup_steps = int(WARMUP_FRACTION * total_steps)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=total_steps,
+ )
+
+ # ==============================
+ # Boost with ColossalAI
+ # ==============================
+ model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
+
+ # ==============================
+ # Train model
+ # ==============================
+ for epoch in range(NUM_EPOCHS):
+ train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
+
+ results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
+ coordinator)
+
+ if coordinator.is_master():
+ print(results)
+ if args.target_f1 is not None and 'f1' in results:
+ assert results['f1'] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/bert/requirements.txt b/examples/language/bert/requirements.txt
new file mode 100644
index 000000000000..377422c260ad
--- /dev/null
+++ b/examples/language/bert/requirements.txt
@@ -0,0 +1,9 @@
+colossalai
+evaluate
+datasets
+torch
+tqdm
+transformers
+scipy
+scikit-learn
+ptflops
diff --git a/examples/language/bert/run_gemini.sh b/examples/language/bert/run_gemini.sh
deleted file mode 100644
index d791334e8c97..000000000000
--- a/examples/language/bert/run_gemini.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-set -x
-# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
-export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
-
-# The following options only valid when DISTPLAN="colossalai"
-export GPUNUM=${GPUNUM:-1}
-export PLACEMENT=${PLACEMENT:-"cpu"}
-export BATCH_SIZE=${BATCH_SIZE:-16}
-
-# bert | albert
-export MODEL_TYPE=${MODEL_TYPE:-"bert"}
-export TRAIN_STEP=${TRAIN_STEP:-10}
-
-mkdir -p gemini_logs
-
-env CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_bert_demo.py \
---model_type=${MODEL_TYPE} \
---batch_size=${BATCH_SIZE} \
---placement=${PLACEMENT} \
---distplan=${DISTPLAN} \
---train_step=${TRAIN_STEP} \
-2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log
diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh
old mode 100644
new mode 100755
index 42c63fec50c0..7fc6daabb2f3
--- a/examples/language/bert/test_ci.sh
+++ b/examples/language/bert/test_ci.sh
@@ -1,2 +1,8 @@
-set -x
-env GPUNUM=1 bash run_gemini.sh
+#!/bin/bash
+set -xe
+
+pip install -r requirements.txt
+
+for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
+ torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
+done
diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py
deleted file mode 100644
index 9a0278b2c711..000000000000
--- a/examples/language/bert/train_bert_demo.py
+++ /dev/null
@@ -1,331 +0,0 @@
-import os
-from functools import partial
-from time import time
-
-import psutil
-import torch
-from packaging import version
-from torch import nn
-from torch.nn.parallel import DistributedDataParallel as DDP
-from transformers import AlbertConfig, AlbertForSequenceClassification, BertConfig, BertForSequenceClassification
-
-import colossalai
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
-
-CAI_VERSION = colossalai.__version__
-
-
-def get_tflops(model_numel, batch_size, seq_len, step_time):
- return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
-
-
-def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
- from contextlib import nullcontext
-
- from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
- if enable_flag:
- return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
- schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
- on_trace_ready=tensorboard_trace_handler(save_dir),
- record_shapes=True,
- profile_memory=True)
- else:
-
- class DummyProfiler:
-
- def __init__(self):
- self.step_number = 0
-
- def step(self):
- self.step_number += 1
-
- return nullcontext(DummyProfiler())
-
-
-def get_time_stamp():
- import time
- cur_time = time.strftime("%d-%H:%M", time.localtime())
- return cur_time
-
-
-def get_bert_data(batch_size: int, sequence_length: int, vacob_size: int, n_class: int, device: torch.device):
- input = torch.randint(
- low=0,
- high=vacob_size,
- size=(batch_size, sequence_length),
- device=device,
- dtype=torch.long,
- )
- label = torch.randint(low=0, high=n_class, size=(batch_size,), device=device, dtype=torch.long)
- return input, label
-
-
-def parse_args():
- parser = colossalai.get_default_parser()
- parser.add_argument(
- "--distplan",
- type=str,
- default='CAI_Gemini',
- help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
- )
- parser.add_argument(
- "--placement",
- type=str,
- default='cpu',
- help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=8,
- help="batch size per DP group of training.",
- )
- parser.add_argument(
- "--model_type",
- type=str,
- default="bert",
- help="bert or albert",
- )
- parser.add_argument(
- "--train_step",
- type=int,
- default=10,
- help="training iterations for test",
- )
-
- args = parser.parse_args()
- return args
-
-
-SEQ_LEN = 512
-VOCAB_SIZE = 1000
-NUM_LABELS = 10
-
-
-# Parameter Sharding Strategies for Tensor Parallelism
-def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
- spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
- param.set_tensor_spec(*spec)
-
-
-def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
- split_param_single_dim_tp1d(0, param, pg)
-
-
-def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
- split_param_single_dim_tp1d(-1, param, pg)
-
-
-def get_cpu_mem():
- return psutil.Process().memory_info().rss / 1024**2
-
-
-def get_gpu_mem():
- return torch.cuda.memory_allocated() / 1024**2
-
-
-def get_mem_info(prefix=''):
- return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
-
-
-def get_model_size(model: nn.Module):
- total_numel = 0
- for module in model.modules():
- for p in module.parameters(recurse=False):
- total_numel += p.numel()
- return total_numel
-
-
-def model_builder(args):
- if args.model_type == "bert":
- cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
- return BertForSequenceClassification(cfg)
- elif args.model_type == "albert":
- cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
- return AlbertForSequenceClassification(cfg)
- else:
- raise RuntimeError
-
-
-def model_size_formatter(numel: int) -> str:
- GB_SIZE = 10**9
- MB_SIZE = 10**6
- KB_SIZE = 10**3
- if numel >= GB_SIZE:
- return f'{numel / GB_SIZE:.1f}B'
- elif numel >= MB_SIZE:
- return f'{numel / MB_SIZE:.1f}M'
- elif numel >= KB_SIZE:
- return f'{numel / KB_SIZE:.1f}K'
- else:
- return str(numel)
-
-
-def set_cpu_maximum_parallelism():
- conf_str = torch.__config__.parallel_info()
- inter_str = conf_str.split("hardware_concurrency() : ")[1]
- max_concurrency = inter_str.split('\n')[0]
- os.environ["OMP_NUM_THREADS"] = max_concurrency
- print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
-
-
-def main():
- # version check
- # this example is supposed to work for versions greater than 0.2.0
- assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
-
- set_cpu_maximum_parallelism()
- args = parse_args()
-
- # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
- if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
- raise TypeError(f"{args.distplan} is error")
-
- # batch size per DP degree
- BATCH_SIZE = args.batch_size
-
- NUM_STEPS = args.train_step
-
- WARMUP_STEPS = 1
- assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
- assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
- PROF_FLAG = False # The flag of profiling, False by default
-
- disable_existing_loggers()
- colossalai.launch_from_torch(config={})
-
- logger = get_dist_logger()
- logger.info(f" {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
-
- torch.manual_seed(123)
- if args.distplan.startswith("CAI"):
- # all param must use the same process group.
- world_size = torch.distributed.get_world_size()
-
- # build a base-bert model
- with ColoInitContext(device=get_current_device(), dtype=torch.half):
- model = model_builder(args)
- # model = BertForSequenceClassification(BertConfig(vocal_size = VOCAB_SIZE))
-
- # asign running configurations
- gemini_config = None
- if args.distplan.startswith("CAI_ZeRO"):
- optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
- elif args.distplan == "CAI_Gemini":
- gemini_config = dict(strict_ddp_mode=True,
- device=get_current_device(),
- placement_policy=args.placement,
- pin_memory=True,
- hidden_dim=model.config.hidden_size,
- search_range_mb=128)
- optim_config = dict(gpu_margin_mem_ratio=0.)
- else:
- raise RuntimeError
-
- # build a highly optimized gpu/cpu optimizer
- optimizer = HybridAdam(model.parameters(), lr=1e-3)
-
- if args.distplan == "CAI_ZeRO1":
- zero_stage = 1
- elif args.distplan == "CAI_ZeRO2":
- zero_stage = 2
- elif args.distplan == "CAI_Gemini":
- zero_stage = 3
- else:
- raise RuntimeError
-
- # wrap your model and optimizer
- model = zero_model_wrapper(model, zero_stage, gemini_config)
- optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
-
- logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
- elif args.distplan.startswith("Pytorch"):
- model = model_builder(args).cuda()
- model = DDP(model)
- if args.distplan.endswith("DDP"):
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
- elif args.distplan.endswith("ZeRO"):
- from torch.distributed.optim import ZeroRedundancyOptimizer
- optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
- else:
- raise RuntimeError
-
- # model is shared after TP
- numel = get_model_size(model)
- logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
- logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
-
- # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
- # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
- # = batch_per_DP_group * numel * seq_len * 8
- get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
-
- torch.cuda.synchronize()
- model.train()
- tflops_list = []
-
- def train_step():
- # we just use randomly generated data here
- input_ids, labels = get_bert_data(BATCH_SIZE,
- SEQ_LEN,
- VOCAB_SIZE,
- NUM_LABELS,
- device=torch.cuda.current_device())
- optimizer.zero_grad()
-
- start = time()
- outputs = model(input_ids, labels=labels)
- loss, logits = outputs[:2]
- torch.cuda.synchronize()
- fwd_end = time()
- fwd_time = fwd_end - start
- logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
-
- if args.distplan.startswith("CAI"):
- optimizer.backward(loss)
- elif args.distplan.startswith("Pytorch"):
- loss.backward()
- else:
- raise RuntimeError
-
- torch.cuda.synchronize()
- bwd_end = time()
- bwd_time = bwd_end - fwd_end
- logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
-
- optimizer.step()
- torch.cuda.synchronize()
- optim_time = time() - bwd_end
- step_time = time() - start
- logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
-
- step_tflops = get_tflops_func(step_time)
- logger.info(
- f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
- ranks=[0],
- )
- if n >= WARMUP_STEPS:
- tflops_list.append(step_tflops)
-
- demo_profiler = get_profile_context(PROF_FLAG,
- WARMUP_STEPS,
- NUM_STEPS - WARMUP_STEPS,
- save_dir=f"profile/{get_time_stamp()}-demo")
-
- with demo_profiler as prof:
- for n in range(NUM_STEPS):
- train_step()
- prof.step()
-
- tflops_list.sort()
- median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
- logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
- torch.cuda.synchronize()
-
-
-if __name__ == '__main__':
- main()
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 92751c7e2f47..4b78624f0110 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -162,7 +162,7 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
# shard it w.r.t tp pattern
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
@@ -173,9 +173,9 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
@@ -237,7 +237,7 @@ def main():
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
- # asign running configurations
+ # assign running configurations
if args.distplan == "CAI_ZeRO1":
zero_stage = 1
elif args.distplan == "CAI_ZeRO2":
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
index 6369b9f8c5a1..d825ae92a285 100644
--- a/examples/language/gpt/titans/model/embed.py
+++ b/examples/language/gpt/titans/model/embed.py
@@ -305,7 +305,7 @@ def forward(ctx, vocab_parallel_logits, target):
@staticmethod
def backward(ctx, grad_output):
- # Retreive tensors from the forward path.
+ # Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
diff --git a/examples/language/opt/README.md b/examples/language/opt/README.md
index c2fd254571c7..37e1ff4d9008 100644
--- a/examples/language/opt/README.md
+++ b/examples/language/opt/README.md
@@ -19,15 +19,35 @@ Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/fa
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.
-We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
-the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
## Our Modifications
-We adapt the OPT training code to ColossalAI by leveraging Gemini and ZeRO DDP.
-## Quick Start
-You can launch training by using the following bash script
+We use the pre-trained weights of the OPT model provided by the Hugging Face Hub.
+
+We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin.
+
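+In sketch form, the pattern used by the training scripts in this directory looks like the following (model, optimizer and dataloader construction elided):
+```python
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+
+colossalai.launch_from_torch(config={})
+plugin = GeminiPlugin()    # or TorchDDPPlugin() / LowLevelZeroPlugin()
+booster = Booster(plugin=plugin)
+model, optimizer, _, dataloader, lr_scheduler = booster.boost(
+    model=model, optimizer=optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler)
+# inside the training loop, the backward pass goes through the booster:
+# booster.backward(loss, optimizer)
+```
+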
+## Run Demo
+By running the following script:
```bash
-bash ./run_gemini.sh
+bash run_demo.sh
```
+You will finetune a [facebook/opt-350m](https://huggingface.co/facebook/opt-350m) model on this [dataset](https://huggingface.co/datasets/hugginglearners/netflix-shows), which contains descriptions of more than 8000 Netflix shows.
+
+The script can be modified if you want to try another set of hyperparameters or switch to an OPT model of a different size.
+
+The demo code is adapted from this [blog](https://medium.com/geekculture/fine-tune-eleutherai-gpt-neo-to-generate-netflix-movie-descriptions-in-only-47-lines-of-code-40c9b4c32475) and the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
+
+
+
+## Run Benchmark
+
+You can benchmark the OPT model by running the following script:
+```bash
+bash run_benchmark.sh
+```
+The script will test performance (throughput and peak memory usage) for each combination of hyperparameters. You can also edit this script to configure your own set of hyperparameters for testing.
+
+
+
diff --git a/examples/language/opt/args.py b/examples/language/opt/args.py
new file mode 100644
index 000000000000..16730be7ebea
--- /dev/null
+++ b/examples/language/opt/args.py
@@ -0,0 +1,120 @@
+from colossalai import get_default_parser
+
+
+def parse_demo_args():
+
+ parser = get_default_parser()
+ parser.add_argument(
+ "--model_name_or_path",
+ type=str,
+ default="facebook/opt-350m",
+ help="Path to pretrained model or model identifier from huggingface.co/models."
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ default="./output_model.bin",
+ help="The path of your saved model after finetuning."
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+ )
+ parser.add_argument(
+ "--num_epoch",
+ type=int,
+ default=10,
+ help="Number of epochs."
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader."
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use."
+ )
+ parser.add_argument(
+ "--warmup_ratio",
+ type=float,
+ default=0.1,
+ help="Ratio of warmup steps against total training steps."
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.01,
+ help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="A seed for reproducible training."
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+
+def parse_benchmark_args():
+
+ parser = get_default_parser()
+ parser.add_argument(
+ "--model_name_or_path",
+ type=str,
+ default="facebook/opt-125m",
+ help="Path to pretrained model or model identifier from huggingface.co/models."
+ )
+ parser.add_argument(
+ "--plugin",
+ type=str,
+ default="gemini",
+ help="Plugin to use. Valid plugins include 'torch_ddp','torch_ddp_fp16','gemini','low_level_zero'."
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=32,
+ help="Batch size (per dp group) for the training dataloader."
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-5,
+ help="Initial learning rate (after the potential warmup period) to use."
+ )
+ parser.add_argument(
+ "--weight_decay",
+ type=float,
+ default=0.0,
+ help="Weight decay to use."
+ )
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=20,
+ help="Total number of training steps to perform."
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="A seed for reproducible training."
+ )
+ parser.add_argument(
+ "--mem_cap",
+ type=int,
+ default=0,
+ help="Limit on the usage of space for each GPU (in GB)."
+ )
+ args = parser.parse_args()
+
+ return args
\ No newline at end of file
diff --git a/examples/language/opt/benchmark.sh b/examples/language/opt/benchmark.sh
deleted file mode 100644
index 0d04b5e9b33c..000000000000
--- a/examples/language/opt/benchmark.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-export BS=16
-export MEMCAP=0
-export MODEL="6.7b"
-export GPUNUM=1
-
-for MODEL in "6.7b" "13b" "1.3b"
-do
-for GPUNUM in 8 1
-do
-for BS in 16 24 32 8
-do
-for MEMCAP in 0 40
-do
-pkill -9 torchrun
-pkill -9 python
-
-env BS=$BS MEM_CAP=$MEMCAP MODEL=$MODEL GPUNUM=$GPUNUM bash ./run_gemini.sh
-done
-done
-done
-done
diff --git a/examples/language/opt/data.py b/examples/language/opt/data.py
new file mode 100644
index 000000000000..6cfffb5fc95b
--- /dev/null
+++ b/examples/language/opt/data.py
@@ -0,0 +1,37 @@
+import torch
+from torch.utils.data import Dataset
+from datasets import load_dataset
+
+
+class NetflixDataset(Dataset):
+
+ def __init__(self, tokenizer):
+
+ super().__init__()
+
+ self.tokenizer = tokenizer
+ self.input_ids = []
+ self.attn_masks = []
+ self.labels = []
+        self.txt_list = load_dataset("hugginglearners/netflix-shows", split="train")['description']
+        self.max_length = max(len(self.tokenizer.encode(description)) for description in self.txt_list)
+
+ for txt in self.txt_list:
+ encodings_dict = self.tokenizer('' + txt + '',
+ truncation=True,
+ max_length=self.max_length,
+ padding="max_length")
+ self.input_ids.append(torch.tensor(encodings_dict['input_ids']))
+ self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, idx):
+ return self.input_ids[idx], self.attn_masks[idx]
+
+
+def netflix_collator(data):
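+    # causal-LM finetuning: labels are the input ids themselves; the model shifts them internally to compute the loss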
+ return {'input_ids': torch.stack([x[0] for x in data]),
+ 'attention_mask': torch.stack([x[1] for x in data]),
+ 'labels': torch.stack([x[0] for x in data])}
diff --git a/examples/language/opt/opt_benchmark.py b/examples/language/opt/opt_benchmark.py
new file mode 100755
index 000000000000..2d69036b50c6
--- /dev/null
+++ b/examples/language/opt/opt_benchmark.py
@@ -0,0 +1,137 @@
+import time
+
+import torch
+import transformers
+from transformers import AutoConfig, OPTForCausalLM
+from transformers.utils.versions import require_version
+import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_benchmark_args
+
+require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
+
+
+def format_num(num: int, bytes=False):
+ """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
+ factor = 1024 if bytes else 1000
+ suffix = "B" if bytes else ""
+ for unit in ["", " K", " M", " G", " T", " P"]:
+ if num < factor:
+ return f"{num:.2f}{unit}{suffix}"
+ num /= factor
+
+
+def get_data(batch_size, seq_len, vocab_size):
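+    # synthetic batch: random token ids plus an all-ones attention mask, enough to drive a causal-LM training step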
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
+ attention_mask = torch.ones_like(input_ids)
+ return input_ids, attention_mask
+
+
+def colo_memory_cap(size_in_GB):
+ from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+ cuda_capacity = colo_device_memory_capacity(get_current_device())
+ if size_in_GB * (1024**3) < cuda_capacity:
+ colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+ print(f"Limiting GPU memory usage to {size_in_GB} GB")
+
+
+def main():
+
+ args = parse_benchmark_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+ world_size = coordinator.world_size
+
+ # Manage loggers
+ disable_existing_loggers()
+ logger = get_dist_logger()
+ if coordinator.is_master():
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+
+    # Optionally cap per-GPU memory usage
+ if args.mem_cap > 0:
+ colo_memory_cap(args.mem_cap)
+
+ # Build OPT model
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ model = OPTForCausalLM(config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Set plugin
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(device=get_current_device(),
+ placement_policy='cpu',
+ pin_memory=True,
+ strict_ddp_mode=True,
+ initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(), lr=args.learning_rate)
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
+ SEQ_LEN = 1024
+ VOCAB_SIZE = 50257
+
+ # Start training.
+    logger.info("Start testing", ranks=[0])
+ progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
+
+ torch.cuda.synchronize()
+ model.train()
+ start_time = time.time()
+
+ for _ in range(args.max_train_steps):
+
+ input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
+ optimizer.zero_grad()
+ outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
+ loss = outputs['loss']
+ booster.backward(loss, optimizer)
+ optimizer.step()
+
+ torch.cuda.synchronize()
+ progress_bar.update(1)
+
+ # Compute Statistics
+ end_time = time.time()
+ throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
+ max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
+
+ logger.info(f"Testing finished, "
+ f"batch size per gpu: {args.batch_size}, "
+ f"plugin: {args.plugin}, "
+ f"throughput: {throughput}, "
+ f"maximum memory usage per gpu: {max_mem}.",
+ ranks=[0])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/language/opt/opt_train_demo.py b/examples/language/opt/opt_train_demo.py
new file mode 100644
index 000000000000..fa7feca9c9a9
--- /dev/null
+++ b/examples/language/opt/opt_train_demo.py
@@ -0,0 +1,142 @@
+
+import torch
+import datasets
+import transformers
+from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
+from transformers import get_linear_schedule_with_warmup
+from transformers.utils.versions import require_version
+from tqdm import tqdm
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.utils import get_current_device
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.cluster import DistCoordinator
+
+from args import parse_demo_args
+from data import NetflixDataset, netflix_collator
+
+require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
+require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
+
+
+def move_to_cuda(batch, device):
+ return {k: v.to(device) for k, v in batch.items()}
+
+
+def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+
+ torch.cuda.synchronize()
+ model.train()
+
+ with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
+
+ for batch in pbar:
+
+ # Forward
+ optimizer.zero_grad()
+ batch = move_to_cuda(batch, torch.cuda.current_device())
+
+ outputs = model(use_cache=False, **batch)
+ loss = outputs['loss']
+
+ # Backward
+ booster.backward(loss, optimizer)
+ optimizer.step()
+ lr_scheduler.step()
+
+ # Print batch loss
+ pbar.set_postfix({'loss': loss.item()})
+
+
+def main():
+
+ args = parse_demo_args()
+
+ # Launch ColossalAI
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+ coordinator = DistCoordinator()
+ world_size = coordinator.world_size
+
+ # Manage loggers
+ disable_existing_loggers()
+ logger = get_dist_logger()
+ if coordinator.is_master():
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+
+ # Build OPT model
+ config = AutoConfig.from_pretrained(args.model_name_or_path)
+ model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
+ logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
+
+ # Enable gradient checkpointing
+ model.gradient_checkpointing_enable()
+
+ # Set plugin
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(device=get_current_device(),
+ placement_policy='cpu',
+ pin_memory=True,
+ strict_ddp_mode=True,
+ initial_scale=2**5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2**5)
+ logger.info(f"Set plugin as {args.plugin}", ranks=[0])
+
+ # Prepare tokenizer and dataloader
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+ dataset = NetflixDataset(tokenizer)
+ dataloader = plugin.prepare_dataloader(dataset,
+ batch_size=args.batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=netflix_collator)
+
+ # Set optimizer
+ optimizer = HybridAdam(model.parameters(),
+ lr=(args.learning_rate * world_size),
+ weight_decay=args.weight_decay)
+
+ # Set lr scheduler
+ total_steps = len(dataloader) * args.num_epoch
+ num_warmup_steps = int(args.warmup_ratio * total_steps)
+ lr_scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=num_warmup_steps,
+        num_training_steps=total_steps
+ )
+
+ # Set booster
+ booster = Booster(plugin=plugin, **booster_kwargs)
+ model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
+ optimizer=optimizer,
+ dataloader=dataloader,
+ lr_scheduler=lr_scheduler)
+
+ # Start finetuning
+    logger.info("Start finetuning", ranks=[0])
+ for epoch in range(args.num_epoch):
+ train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)
+
+ # Finish training and evaluate
+    logger.info("Finish finetuning", ranks=[0])
+ booster.save_model(model, args.output_path)
+ logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt
index 137a69e80498..4422216e6a1c 100644
--- a/examples/language/opt/requirements.txt
+++ b/examples/language/opt/requirements.txt
@@ -1,2 +1,4 @@
colossalai >= 0.1.12
torch >= 1.8.1
+datasets >= 1.8.0
+transformers >= 4.20.0
\ No newline at end of file
diff --git a/examples/language/opt/run_benchmark.sh b/examples/language/opt/run_benchmark.sh
new file mode 100644
index 000000000000..76c5e8601989
--- /dev/null
+++ b/examples/language/opt/run_benchmark.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+export BS=32
+export MEMCAP=0
+export GPUNUM=1
+
+# acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`
+export MODEL="125m"
+
+for BS in 8 32 128
+do
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
+do
+for GPUNUM in 1 4
+do
+
+MODEL_PATH="facebook/opt-${MODEL}"
+torchrun \
+    --standalone \
+    --nproc_per_node ${GPUNUM} \
+    opt_benchmark.py \
+    --model_name_or_path ${MODEL_PATH} \
+    --mem_cap ${MEMCAP} \
+    --plugin ${PLUGIN} \
+    --batch_size ${BS}
+
+done
+done
+done
diff --git a/examples/language/opt/run_demo.sh b/examples/language/opt/run_demo.sh
new file mode 100644
index 000000000000..0c9759c34039
--- /dev/null
+++ b/examples/language/opt/run_demo.sh
@@ -0,0 +1,44 @@
+#!/bin/bash
+set -xe
+pip install -r requirements.txt
+
+# model name or path
+MODEL="facebook/opt-350m"
+
+# path for saving model
+OUTPUT_PATH="./output_model.bin"
+
+# plugin(training strategy)
+# can only be one of "torch_ddp"/"torch_ddp_fp16"/"low_level_zero"/"gemini"
+PLUGIN="gemini"
+
+# number of gpus to use
+GPUNUM=4
+
+# batch size per gpu
+BS=16
+
+# learning rate
+LR="5e-5"
+
+# number of epoch
+EPOCH=10
+
+# weight decay
+WEIGHT_DECAY=0.01
+
+# ratio of warmup steps
+WARMUP_RATIO=0.1
+
+# run the script for demo
+torchrun \
+ --standalone \
+ --nproc_per_node ${GPUNUM} \
+ opt_train_demo.py \
+ --model_name_or_path ${MODEL} \
+ --output_path ${OUTPUT_PATH} \
+ --plugin ${PLUGIN} \
+ --batch_size ${BS} \
+ --num_epoch ${EPOCH} \
+ --learning_rate ${LR} \
+ --weight_decay ${WEIGHT_DECAY} \
+ --warmup_ratio ${WARMUP_RATIO}
diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh
deleted file mode 100644
index 73f231292a13..000000000000
--- a/examples/language/opt/run_gemini.sh
+++ /dev/null
@@ -1,28 +0,0 @@
-set -x
-export BS=${BS:-16}
-export MEMCAP=${MEMCAP:-0}
-# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
-export MODEL=${MODEL:-"125m"}
-export GPUNUM=${GPUNUM:-1}
-export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
-
-# make directory for logs
-mkdir -p ./logs
-
-if [ ${USE_SHARD_INIT} = "true" ]; then
- USE_SHARD_INIT="--shardinit"
-else
- USE_SHARD_INIT=""
-fi
-
-export MODLE_PATH="facebook/opt-${MODEL}"
-
-# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
-torchrun \
- --nproc_per_node ${GPUNUM} \
- --master_port 19198 \
- train_gemini_opt.py \
- --mem_cap ${MEMCAP} \
- --model_name_or_path ${MODLE_PATH} \
- ${USE_SHARD_INIT} \
- --batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh
index 317f602cda3c..fa14f52b70d2 100644
--- a/examples/language/opt/test_ci.sh
+++ b/examples/language/opt/test_ci.sh
@@ -1,4 +1,19 @@
-for GPUNUM in 2 1
+set -xe
+pip install -r requirements.txt
+
+BS=4
+for PLUGIN in "torch_ddp" "torch_ddp_fp16" "low_level_zero" "gemini"
do
-env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh
+for GPUNUM in 1 4
+do
+
+torchrun \
+ --standalone \
+ --nproc_per_node ${GPUNUM} \
+ opt_benchmark.py \
+ --model_name_or_path "facebook/opt-125m" \
+ --plugin ${PLUGIN} \
+ --batch_size ${BS}
+
+done
done
diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py
deleted file mode 100755
index 3614b689de26..000000000000
--- a/examples/language/opt/train_gemini_opt.py
+++ /dev/null
@@ -1,233 +0,0 @@
-#!/usr/bin/env python
-# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...)
-on a text file or a dataset without using HuggingFace Trainer.
-
-Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
-https://huggingface.co/models?filter=text-generation
-"""
-# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
-
-import time
-from functools import partial
-
-import datasets
-import torch
-import torch.distributed as dist
-import transformers
-from transformers import CONFIG_MAPPING, MODEL_MAPPING, AutoConfig, OPTForCausalLM
-from transformers.utils.versions import require_version
-
-import colossalai
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor import ProcessGroup, ShardSpec
-from colossalai.utils import get_current_device
-from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
-
-
-def get_data(batch_size, seq_len, vocab_size):
- input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
- attention_mask = torch.ones_like(input_ids)
- return input_ids, attention_mask
-
-
-require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
-
-MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
-MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
-def get_time_stamp():
- torch.cuda.synchronize()
- return time.time()
-
-
-def get_tflops(model_numel, batch_size, seq_len, step_time):
- return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
-
-
-def parse_args():
- parser = colossalai.get_default_parser()
- parser.add_argument(
- "--model_name_or_path",
- type=str,
- help="Path to pretrained model or model identifier from huggingface.co/models.",
- required=True,
- )
- parser.add_argument(
- "--config_name",
- type=str,
- default=None,
- help="Pretrained config name or path if not the same as model_name",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=8,
- help="Batch size (per dp group) for the training dataloader.",
- )
- parser.add_argument(
- "--learning_rate",
- type=float,
- default=5e-5,
- help="Initial learning rate (after the potential warmup period) to use.",
- )
- parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
- parser.add_argument(
- "--max_train_steps",
- type=int,
- default=20,
- help="Total number of training steps to perform.",
- )
- parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
- parser.add_argument(
- "--model_type",
- type=str,
- default=None,
- help="Model type to use if training from scratch.",
- choices=MODEL_TYPES,
- )
- parser.add_argument(
- "--shardinit",
- action="store_true",
- help="Initialize the model with tensor parallel",
- )
- parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
- parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
- args = parser.parse_args()
-
- return args
-
-
-def colo_memory_cap(size_in_GB):
- from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
- cuda_capacity = colo_device_memory_capacity(get_current_device())
- if size_in_GB * (1024**3) < cuda_capacity:
- colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
- print("Using {} GB of GPU memory".format(size_in_GB))
-
-
-def main():
- args = parse_args()
- disable_existing_loggers()
- colossalai.launch_from_torch({})
- logger = get_dist_logger()
- is_main_process = dist.get_rank() == 0
-
- if is_main_process:
- datasets.utils.logging.set_verbosity_warning()
- transformers.utils.logging.set_verbosity_info()
- else:
- datasets.utils.logging.set_verbosity_error()
- transformers.utils.logging.set_verbosity_error()
-
- if args.mem_cap > 0:
- colo_memory_cap(args.mem_cap)
-
- # If passed along, set the training seed now.
- if args.seed is not None:
- torch.mannul_seed(args.seed)
- logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")
-
- # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
- # https://huggingface.co/docs/datasets/loading_datasets.html.
-
- # Load pretrained model
- # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
- # download model & vocab.
- if args.config_name:
- config = AutoConfig.from_pretrained(args.config_name)
- elif args.model_name_or_path:
- config = AutoConfig.from_pretrained(args.model_name_or_path)
- else:
- config = CONFIG_MAPPING[args.model_type]()
- logger.warning("You are instantiating a new config instance from scratch.")
- logger.info("Model config has been created", ranks=[0])
-
- if args.init_in_cpu:
- init_dev = torch.device('cpu')
- else:
- init_dev = get_current_device()
-
- # shard init parameters
- if args.shardinit:
- logger.info("Sharding initialization !", ranks=[0])
- else:
- logger.info("Skipping sharding initialization", ranks=[0])
-
- world_size = torch.distributed.get_world_size()
- shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
- default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
-
- # build model
- if args.model_name_or_path is None:
- logger.info("Train a new model from scratch", ranks=[0])
- with ColoInitContext(device=init_dev,
- dtype=torch.half,
- default_dist_spec=default_dist_spec,
- default_pg=shard_pg):
- model = OPTForCausalLM(config)
- else:
- logger.info("Finetune a pre-trained model", ranks=[0])
- with ColoInitContext(device=init_dev,
- dtype=torch.half,
- default_dist_spec=default_dist_spec,
- default_pg=shard_pg):
- model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
- from_tf=bool(".ckpt" in args.model_name_or_path),
- config=config,
- local_files_only=False)
-
- # enable gradient checkpointing
- model.gradient_checkpointing_enable()
-
- numel = sum([p.numel() for p in model.parameters()])
- PLACEMENT_POLICY = 'cpu'
- model = GeminiDDP(model,
- device=get_current_device(),
- placement_policy=PLACEMENT_POLICY,
- pin_memory=True,
- strict_ddp_mode=args.shardinit)
- optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
-
- SEQ_LEN = 1024
- VOCAB_SIZE = 50257
-
- get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
-
- model.train()
- for step in range(args.max_train_steps):
- st_time = time.time()
- input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
-
- outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=input_ids, use_cache=False)
- loss = outputs['loss']
- optimizer.backward(loss)
-
- optimizer.step()
- optimizer.zero_grad()
- torch.cuda.synchronize()
- step_time = time.time() - st_time
- step_tflops = get_tflops_func(step_time)
-
- logger.info("step {} finished, Tflops {}".format(step, step_tflops), ranks=[0])
-
- logger.info("Training finished", ranks=[0])
-
-
-if __name__ == "__main__":
- main()
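For reference, the factor of 8 in the `get_tflops` helper removed above is the usual 2 FLOPs per multiply-accumulate times four effective passes (1 forward + 2 backward + 1 activation recompute), which applies because the script enabled gradient checkpointing. A quick sanity check of the arithmetic, with illustrative numbers only:

```python
# Reproduces the get_tflops estimate with illustrative numbers only.
# 8 = 2 FLOPs/MAC * (1 fwd + 2 bwd + 1 recompute) passes.
numel = 125e6                  # e.g. an OPT-125M-sized model (assumed)
batch_size, seq_len = 8, 1024
step_time = 0.5                # seconds per step (assumed)
tflops = numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
print(f"{tflops:.1f} TFLOPS")  # -> 16.4 for these numbers
```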
diff --git a/examples/language/palm/README.md b/examples/language/palm/README.md
index 486bf240f89c..3ff3939d63d4 100644
--- a/examples/language/palm/README.md
+++ b/examples/language/palm/README.md
@@ -43,6 +43,9 @@ palm = PaLM(
)
```
+## New API
+We have reworked our previous PaLM implementation around the new Booster API, which offers a more flexible and efficient way to train your model. The new API is user-friendly and easy to adopt; the updated training script is train.py. We also provide a shell script, test_ci.sh, that runs through all of the available booster plugins. For more information about the Booster API, see https://colossalai.org/docs/basics/booster_api/. A minimal sketch of the workflow follows.
+
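+The sketch below is illustrative: the model and data are toy stand-ins, and it assumes the script is launched with torchrun so distributed state can be initialized:
+
+```python
+import torch
+import torch.nn as nn
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin
+from colossalai.nn import HybridAdam
+
+colossalai.launch_from_torch(config={})
+
+plugin = GeminiPlugin(placement_policy='cpu', strict_ddp_mode=True, initial_scale=2**5)
+booster = Booster(plugin=plugin)
+
+model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 256))  # stand-in for PaLM
+optimizer = HybridAdam(model.parameters(), lr=2e-4)
+criterion = nn.MSELoss()
+
+# boost() wraps model/optimizer/criterion for the chosen plugin; unused slots come back as None
+model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion=criterion)
+
+x = torch.randn(8, 512).cuda()
+y = torch.randn(8, 256).cuda()
+loss = criterion(model(x), y)
+booster.backward(loss, optimizer)  # use instead of loss.backward() so the plugin can handle scaling
+optimizer.step()
+optimizer.zero_grad()
+```
+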
## Test on Enwik8
```bash
diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh
index 7a533509e009..2a846e81a9a7 100644
--- a/examples/language/palm/run.sh
+++ b/examples/language/palm/run.sh
@@ -3,9 +3,11 @@ export DISTPAN="colossalai"
# The following options only valid when DISTPAN="colossalai"
export TPDEGREE=1
-export GPUNUM=1
+export GPUNUM=4
export PLACEMENT='cpu'
export USE_SHARD_INIT=False
-export BATCH_SIZE=4
+export BATCH_SIZE=1
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py \
+--dummy_data=True --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --plugin='gemini' \
+--placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh
index f21095578077..4de6a44e5bf7 100644
--- a/examples/language/palm/test_ci.sh
+++ b/examples/language/palm/test_ci.sh
@@ -4,6 +4,6 @@ for BATCH_SIZE in 2
do
for GPUNUM in 1 4
do
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} train.py --dummy_data=True --batch_size=${BATCH_SIZE} --plugin='gemini' 2>&1 | tee run.log
done
done
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index b16da1c7744a..a0600db1bc5b 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -9,6 +9,8 @@
import torch.optim as optim
import tqdm
from packaging import version
+
+from colossalai.nn import HybridAdam
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset
@@ -18,6 +20,8 @@
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
+from colossalai.booster import Booster
+from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
# constants
@@ -58,6 +62,12 @@ def parse_args():
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
+ parser.add_argument('-p',
+ '--plugin',
+ type=str,
+ default='torch_ddp',
+ choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
+ help="plugin to use")
parser.add_argument(
"--batch_size",
type=int,
@@ -101,28 +111,6 @@ def get_model_size(model: nn.Module):
return total_numel
-# Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
- cai_version = colossalai.__version__
- if version.parse(cai_version) > version.parse("0.1.10"):
- from colossalai.nn.parallel import GeminiDDP
- model = GeminiDDP(model,
- device=get_current_device(),
- placement_policy=placement_policy,
- pin_memory=True,
- search_range_mb=32)
- elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
- from colossalai.gemini import ChunkManager, GeminiManager
- chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
- gemini_manager = GeminiManager(placement_policy, chunk_manager)
- chunk_manager = ChunkManager(chunk_size,
- pg,
- enable_distributed_storage=True,
- init_device=GeminiManager.get_default_device(placement_policy))
- model = ZeroDDP(model, gemini_manager)
- else:
- raise NotImplemented(f"CAI version {cai_version} is not supported")
- return model
# Parameter Sharding Strategies for Tensor Parallelism
@@ -152,15 +140,15 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
- split_param_col_tp1d(param, pg) # colmn slice
+ split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
@@ -218,6 +206,18 @@ def __len__(self):
if args.distplan == "colossalai":
# instantiate GPT-like decoder model
+ booster_kwargs = {}
+ if args.plugin == 'torch_ddp_fp16':
+ booster_kwargs['mixed_precision'] = 'fp16'
+ if args.plugin.startswith('torch_ddp'):
+ plugin = TorchDDPPlugin()
+ elif args.plugin == 'gemini':
+ plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5)
+ elif args.plugin == 'low_level_zero':
+ plugin = LowLevelZeroPlugin(initial_scale=2 ** 5)
+ logger.info(f"plugin: {plugin}")
+ booster = Booster(plugin=plugin, **booster_kwargs)
+
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
@@ -228,12 +228,12 @@ def __len__(self):
pg = default_pg
tensor_parallelize(model, pg)
- model = gemini_zero_dpp(model, pg, args.placement)
# optimizer
- #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
- optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)
+ optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
+ model, optimizer, _, _, _ = booster.boost(model, optimizer)
+
else:
model = PaLM(num_tokens=256, dim=512, depth=8)
model = AutoregressiveWrapper(model, max_seq_len=2048)
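The plugin dispatch added to train.py above relies on argparse's `choices` to reject unknown values, so it carries no fallback branch. If the same dispatch were reused outside argparse, a hypothetical helper with an explicit failure mode might look like this:

```python
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# Hypothetical helper mirroring the dispatch in train.py, with an explicit
# error for values argparse's `choices` would otherwise have rejected.
def build_plugin(name: str, placement: str = 'cpu'):
    if name.startswith('torch_ddp'):
        return TorchDDPPlugin()
    if name == 'gemini':
        return GeminiPlugin(placement_policy=placement, strict_ddp_mode=True, initial_scale=2**5)
    if name == 'low_level_zero':
        return LowLevelZeroPlugin(initial_scale=2**5)
    raise ValueError(f"unknown plugin {name!r}")
```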
diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py
index 7470327a65b6..6cc4c8ef370d 100644
--- a/tests/kit/model_zoo/registry.py
+++ b/tests/kit/model_zoo/registry.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Callable
-__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
+__all__ = ['ModelZooRegistry', 'ModelAttribute', 'model_zoo']
@dataclass
@@ -37,7 +37,7 @@ def register(self,
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
- >>> data = resnresnet18_data_gen() # do not input any argument
+ >>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
index b47b3508ad1b..c3ceef4c7adf 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_batch_norm_handler.py
@@ -27,7 +27,7 @@ def check_bn_module_handler(rank, world_size, port):
# the index of bn node in computation graph
node_index = 1
# the total number of bn strategies without sync bn mode
- # TODO: add sync bn stategies after related passes ready
+ # TODO: add sync bn strategies after related passes ready
strategy_number = 4
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
index 5259455d2179..1703d5ded2f2 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_output_handler.py
@@ -43,14 +43,14 @@ def test_output_handler(output_option):
output_strategies_vector = StrategiesVector(output_node)
# build handler
- otuput_handler = OutputHandler(node=output_node,
+ output_handler = OutputHandler(node=output_node,
device_mesh=device_mesh,
strategies_vector=output_strategies_vector,
output_option=output_option)
- otuput_handler.register_strategy(compute_resharding_cost=False)
+ output_handler.register_strategy(compute_resharding_cost=False)
# check operation data mapping
- mapping = otuput_handler.get_operation_data_mapping()
+ mapping = output_handler.get_operation_data_mapping()
for name, op_data in mapping.items():
op_data: OperationData
@@ -59,7 +59,7 @@ def test_output_handler(output_option):
assert mapping['output'].name == "output"
assert mapping['output'].type == OperationDataType.OUTPUT
- strategy_name_list = [val.name for val in otuput_handler.strategies_vector]
+ strategy_name_list = [val.name for val in output_handler.strategies_vector]
if output_option == 'distributed':
assert "Distributed Output" in strategy_name_list
else:
diff --git a/tests/test_device/test_device_mesh.py b/tests/test_device/test_device_mesh.py
index 19d41d23353f..3be057b3a98b 100644
--- a/tests/test_device/test_device_mesh.py
+++ b/tests/test_device/test_device_mesh.py
@@ -1,19 +1,20 @@
-import torch
-
from colossalai.device.device_mesh import DeviceMesh
+import torch
def test_device_mesh():
- physical_mesh_id = torch.arange(0, 16)
+ physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
- assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
- assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
+ assert device_mesh.convert_map[5] == [1, 1]
+ assert device_mesh.convert_map[11] == [2, 3]
+ assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
+ assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
+ assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
if __name__ == '__main__':
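The `convert_map` assertions in this test follow from row-major arithmetic on the logical mesh: for a (4, 4) mesh, a global rank's logical coordinate is simply `divmod(rank, 4)`.

```python
# Row-major rank -> logical coordinate for a (4, 4) mesh, matching the
# assertions above: 5 -> [1, 1], 11 -> [2, 3].
mesh_shape = (4, 4)
for global_rank in (5, 11):
    row, col = divmod(global_rank, mesh_shape[1])
    print(global_rank, [row, col])
```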
diff --git a/tests/test_device/test_init_logical_pg.py b/tests/test_device/test_init_logical_pg.py
index 7c6339eff67e..2b7060c4846a 100644
--- a/tests/test_device/test_init_logical_pg.py
+++ b/tests/test_device/test_init_logical_pg.py
@@ -20,12 +20,16 @@ def check_layer(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- for axis in range(len(mesh_shape)):
- tensor = torch.ones(4).cuda()
- pg = device_mesh.get_process_group(axis=axis)
- dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
- assert tensor.equal(tensor_to_check)
+ logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
+ logical_process_groups = device_mesh.process_groups_dict
+
+ for mesh_dim, pgs in logical_pg_dict.items():
+ for index, pg in enumerate(pgs):
+ if rank in pg:
+ tensor = torch.ones(4).cuda()
+ group = logical_process_groups[mesh_dim][index][1]
+ dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
+ assert tensor.equal(tensor_to_check)
gpc.destroy()
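Each group in `logical_pg_dict` above has two members, and every member contributes `torch.ones(4)`, so the off-screen `tensor_to_check` is presumably a tensor of 2s. The arithmetic, without any process group:

```python
import torch

# All-reduce SUM over a two-rank group where each rank holds torch.ones(4).
group = [0, 2]  # e.g. the first group on mesh axis 0
reduced = torch.stack([torch.ones(4) for _ in group]).sum(dim=0)
assert reduced.equal(torch.full((4,), 2.0))
```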
diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py
index 2911012fafa8..85bfd0e27801 100644
--- a/tests/test_lazy/lazy_init_utils.py
+++ b/tests/test_lazy/lazy_init_utils.py
@@ -6,9 +6,7 @@
import torch
from packaging import version
-from colossalai.device.device_mesh import DeviceMesh
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import to_global
from tests.kit.model_zoo.registry import ModelAttribute
@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
print(f'{model.__class__.__name__} pass')
-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
- sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
state = model.state_dict()
distributed_state = distributed_model.state_dict()
@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
assert n1 == n2
t1 = t1.cuda()
t2 = t2.cuda()
- if n2 in sharding_spec_dict:
- layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
- t2 = to_global(t2, layout)
+ if n2 in layout_dict:
+ t2 = to_global(t2, layout_dict[n2])
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py
index efa43eab5788..d515b175a9ea 100644
--- a/tests/test_lazy/test_distribute.py
+++ b/tests/test_lazy/test_distribute.py
@@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
return dim
-def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
shard_dim = find_shard_dim(original_tensor.shape)
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
- return target_sharding_spec
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=target_sharding_spec,
+ entire_shape=original_tensor.shape)
+ return layout
def _get_current_name(prefix: str, name: str) -> str:
return f'{prefix}.{name}'.lstrip('.')
-def generate_sharding_spec_dict(model: nn.Module) -> dict:
- sharding_spec_dict = {}
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+ layout_dict = {}
@torch.no_grad()
def generate_recursively(module: nn.Module, prefix: str = ''):
@@ -49,17 +53,17 @@ def generate_recursively(module: nn.Module, prefix: str = ''):
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if isinstance(param, LazyTensor):
- sharding_spec = make_sharding_spec(param)
- sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+ layout = make_layout(device_mesh, param)
+ layout_dict[_get_current_name(prefix, name)] = layout
for name, buf in module.named_buffers(recurse=False):
if isinstance(buf, LazyTensor):
- sharding_spec = make_sharding_spec(buf)
- sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+ layout = make_layout(device_mesh, buf)
+ layout_dict[_get_current_name(prefix, name)] = layout
generate_recursively(model)
- return sharding_spec_dict
+ return layout_dict
@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
- sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
- ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
- assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
+ layout_dict = generate_layout_dict(deferred_model, device_mesh)
+ ctx.distribute(deferred_model, layout_dict, verbose=True)
+ assert_dist_model_equal(model, deferred_model, layout_dict)
def run_dist(rank, world_size, port) -> None:
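One reading of the `ShardingSpec` arguments used throughout these tests, inferred from the shard_sequence comments elsewhere in this PR: `dim_partition_dict` maps a tensor dimension to the device-mesh axes it is sharded over, and unlisted dimensions stay replicated.

```python
# Illustrative dim_partition_dict values and the shard sequences they imply
# (my reading; unlisted dimensions are replicated).
spec_S0_R_R = {0: [0]}      # dim 0 sharded over mesh axis 0  -> "S0, R, R"
spec_S01_R_R = {0: [0, 1]}  # dim 0 sharded over both axes    -> "S01, R, R"
spec_replicated = {}        # nothing sharded                 -> "R, R, R"
```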
diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py
index 0797e01e7e9d..d1f5b9299397 100644
--- a/tests/test_tensor/test_dtensor/test_comm_spec.py
+++ b/tests/test_tensor/test_dtensor/test_comm_spec.py
@@ -125,6 +125,23 @@ def check_all_reduce_bwd(process_groups_dict, rank):
assert tensor_to_comm.equal(tensor_to_check)
+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+ # tensor to comm
+ tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce along logical process axis 0 of the flattened device mesh
+ # tensor to check
+ # tensor([[6., 6.],
+ # [6., 6.]])
+ tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0); axis 0 of the
+    # flattened device mesh corresponds to axes [0, 1] of the original mesh
+ comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+ tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+ assert tensor_to_comm.equal(tensor_to_check)
+
+
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -136,22 +153,24 @@ def check_comm(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- process_group_dict = device_mesh._process_group_dict[rank]
+ process_groups_dict = device_mesh.process_groups_dict
# test all gather
- check_all_gather(process_group_dict, rank)
+ check_all_gather(process_groups_dict, rank)
# test shard
- check_shard(process_group_dict, rank)
+ check_shard(process_groups_dict, rank)
# test all to all
- check_all_to_all(process_group_dict, rank)
+ check_all_to_all(process_groups_dict, rank)
# test all reduce
- check_all_reduce_fwd(process_group_dict, rank)
- check_all_reduce_bwd(process_group_dict, rank)
+ check_all_reduce_fwd(process_groups_dict, rank)
+ check_all_reduce_bwd(process_groups_dict, rank)
+ flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+    # test all-reduce on the 1D flattened device mesh
+ check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
gpc.destroy()
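The expected value in `check_all_reduce_in_flatten_device_mesh` comes from each rank r contributing `torch.ones(2, 2) * r`: summing over the four ranks of the flattened mesh gives 0 + 1 + 2 + 3 = 6 in every entry.

```python
import torch

# Standalone check of the all-reduce expectation above.
world_size = 4
expected = sum(torch.ones(2, 2) * r for r in range(world_size))
assert expected.equal(torch.full((2, 2), 6.0))
```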
diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py
index 50a3bfb15c38..3ca369acbf87 100644
--- a/tests/test_tensor/test_dtensor/test_dtensor.py
+++ b/tests/test_tensor/test_dtensor/test_dtensor.py
@@ -31,9 +31,13 @@ def check_dtensor(rank, world_size, port):
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
- d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=target_sharding_spec,
+ entire_shape=original_tensor.shape)
+ d_tensor = DTensor(original_tensor, layout)
- assert d_tensor.global_shape == original_tensor.shape
+ assert d_tensor.entire_shape == original_tensor.shape
assert d_tensor.data_type == original_tensor.dtype
if rank in (0, 1):
@@ -53,7 +57,12 @@ def check_dtensor(rank, world_size, port):
raise ValueError(f'rank {rank} is not in the device mesh')
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
- d_tensor.layout_convert(device_mesh, new_sharding_spec)
+ new_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=new_sharding_spec,
+ entire_shape=original_tensor.shape)
+
+ d_tensor.layout_convert(new_layout)
if rank == 0:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
@@ -66,7 +75,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')
- dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
+ dtensor_from_local = distribute_tensor(original_tensor, new_layout)
if rank == 0:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py
index 6608e4787273..5c3da5f2b9ff 100644
--- a/tests/test_tensor/test_dtensor/test_layout_converter.py
+++ b/tests/test_tensor/test_dtensor/test_layout_converter.py
@@ -12,9 +12,9 @@
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
-global_shape = torch.Size((64, 32, 16))
+entire_shape = torch.Size((64, 32, 16))
layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4)
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
mesh_shape = (2, 2)
@@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
- layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec,
+ entire_shape=entire_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
@@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
- layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
+ layout_all2all = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_all2all,
+ entire_shape=entire_shape)
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
- shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
+ shard_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_shard,
+ entire_shape=entire_shape)
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
@@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
- source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
- target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@@ -122,7 +137,7 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
- # checkout chached_spec_pairs_transform_path
+ # checkout cached_spec_pairs_transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
@@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
- source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
- target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
- original_tensor = torch.rand(global_shape).cuda()
+ original_tensor = torch.rand(entire_shape).cuda()
# tensor_to_apply: [R, S01, R]
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
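Two notes on the hunk above. First, the `narrow(1, rank * 8, 8)` slice follows from the spec: `[R, S01, R]` shards dim 1 (size 32) across all 4 ranks, so each rank holds 32 / 4 = 8 elements starting at `rank * 8`. Second, the four-argument `Layout(...)` construction recurs throughout these tests; a small hypothetical wrapper would cut the repetition:

```python
import torch
from colossalai.tensor.d_tensor.layout import Layout

# Hypothetical convenience wrapper for the repeated Layout construction above.
def cuda_layout(device_mesh, sharding_spec, entire_shape) -> Layout:
    return Layout(device_mesh=device_mesh,
                  device_type=torch.device('cuda'),
                  sharding_spec=sharding_spec,
                  entire_shape=entire_shape)
```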
diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py
index 859eef051256..6fe9ee292cd0 100644
--- a/tests/test_tensor/test_shape_consistency.py
+++ b/tests/test_tensor/test_shape_consistency.py
@@ -1,10 +1,9 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
-
+from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-physical_mesh_id = torch.arange(0, 16)
+physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py
index 9bd9805e9b8f..d66d4fec14d1 100644
--- a/tests/test_tensor/test_sharded_linear.py
+++ b/tests/test_tensor/test_sharded_linear.py
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
- physical_mesh_id = torch.arange(0, 4)
+ physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
row_id = rank // 2
diff --git a/tests/test_tensor/test_sharding_spec.py b/tests/test_tensor/test_sharding_spec.py
index 5007c4141849..909c84ef0f0e 100644
--- a/tests/test_tensor/test_sharding_spec.py
+++ b/tests/test_tensor/test_sharding_spec.py
@@ -5,7 +5,7 @@
def test_sharding_spec():
- physical_mesh_id = torch.arange(0, 16)
+ physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],