- [](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
+ [](https://github.com/hpcaitech/ColossalAI/stargazers)
+ [](https://github.com/hpcaitech/ColossalAI/actions/workflows/build_on_schedule.yml)
[](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[](https://www.codefactor.io/repository/github/hpcaitech/colossalai)
[](https://huggingface.co/hpcai-tech)
@@ -19,15 +20,16 @@
[](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png)
- | [English](README.md) | [中文](README-zh-Hans.md) |
+ | [English](README.md) | [中文](docs/README-zh-Hans.md) |
## Latest News
-* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
+* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
+* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
+* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
## Table of Contents
@@ -58,12 +60,13 @@
@@ -104,7 +107,7 @@ distributed training and inference in a few lines.
- 1D, [2D](https://arxiv.org/abs/2104.05343), [2.5D](https://arxiv.org/abs/2105.14500), [3D](https://arxiv.org/abs/2105.14450) Tensor Parallelism
- [Sequence Parallelism](https://arxiv.org/abs/2105.13120)
- [Zero Redundancy Optimizer (ZeRO)](https://arxiv.org/abs/1910.02054)
- - [Auto-Parallelism](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/auto_parallel_with_gpt)
+ - [Auto-Parallelism](https://arxiv.org/abs/2302.02599)
- Heterogeneous Memory Management
- [PatrickStar](https://arxiv.org/abs/2108.05818)
@@ -115,8 +118,6 @@ distributed training and inference in a few lines.
- Inference
- [Energon-AI](https://github.com/hpcaitech/EnergonAI)
-- Colossal-AI in the Real World
- - Biomedicine: [FastFold](https://github.com/hpcaitech/FastFold) accelerates training and inference of AlphaFold protein structure
## Parallel Training Demo
@@ -149,9 +150,9 @@ distributed training and inference in a few lines.
- [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model released by Meta, which stimulates AI programmers to perform various downstream tasks and application deployments because public pretrained model weights.
-- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/opt) [[Online Serving]](https://service.colossalai.org/opt)
+- 45% speedup fine-tuning OPT at low cost in lines. [[Example]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) [[Online Serving]](https://colossalai.org/docs/advanced_tutorials/opt_service)
-Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI-Examples) for more details.
+Please visit our [documentation](https://www.colossalai.org/) and [examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) for more details.
### ViT
@@ -199,20 +200,44 @@ Please visit our [documentation](https://www.colossalai.org/) and [examples](htt
- [Energon-AI](https://github.com/hpcaitech/EnergonAI): 50% inference acceleration on the same hardware
-
+
-- [OPT Serving](https://service.colossalai.org/opt): Try 175-billion-parameter OPT online services for free, without any registration whatsoever.
+- [OPT Serving](https://colossalai.org/docs/advanced_tutorials/opt_service): Try 175-billion-parameter OPT online services
-- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 175-billion-parameter BLOOM by more than 10 times.
+- [BLOOM](https://github.com/hpcaitech/EnergonAI/tree/main/examples/bloom): Reduce hardware deployment costs of 176-billion-parameter BLOOM by more than 10 times.
## Colossal-AI in the Real World
+### ChatGPT
+A low-cost [ChatGPT](https://openai.com/blog/chatgpt/) equivalent implementation process. [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/applications/ChatGPT) [[blog]](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
+
+
+
+
+- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
+
+
+
+
+
+- Up to 10.3x growth in model capacity on one GPU
+- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
+
+
+
+
+
+- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
+- Keep in a sufficiently high running speed
+
+
+
### AIGC
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
@@ -244,7 +269,13 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
-- [FastFold](https://github.com/hpcaitech/FastFold): accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues.
+- [FastFold](https://github.com/hpcaitech/FastFold): Accelerating training and inference on GPU Clusters, faster data processing, inference sequence containing more than 10000 residues.
+
+
+
+
+
+- [FastFold with Intel](https://github.com/hpcaitech/FastFold): 3x inference acceleration and 39% cost reduce.
@@ -257,10 +288,37 @@ Acceleration of [AlphaFold Protein Structure](https://alphafold.ebi.ac.uk/)
## Installation
-### Download From Official Releases
+Requirements:
+- PyTorch >= 1.11 (PyTorch 2.x in progress)
+- Python >= 3.7
+- CUDA >= 11.0
+
+If you encounter any problem about installation, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository.
+
+### Install from PyPI
+
+You can easily install Colossal-AI with the following command. **By default, we do not build PyTorch extensions during installation.**
-You can visit the [Download](https://www.colossalai.org/download) page to download Colossal-AI with pre-built CUDA extensions.
+```bash
+pip install colossalai
+```
+
+**Note: only Linux is supported for now.**
+
+However, if you want to build the PyTorch extensions during installation, you can set `CUDA_EXT=1`.
+
+```bash
+CUDA_EXT=1 pip install colossalai
+```
+
+**Otherwise, CUDA kernels will be built during runtime when you actually need it.**
+
+We also keep release the nightly version to PyPI on a weekly basis. This allows you to access the unreleased features and bug fixes in the main branch.
+Installation can be made via
+```bash
+pip install colossalai-nightly
+```
### Download From Source
@@ -270,9 +328,6 @@ You can visit the [Download](https://www.colossalai.org/download) page to downlo
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI
-# install dependency
-pip install -r requirements/requirements.txt
-
# install colossalai
pip install .
```
@@ -318,11 +373,15 @@ docker run -ti --gpus all --rm --ipc=host colossalai bash
Join the Colossal-AI community on [Forum](https://github.com/hpcaitech/ColossalAI/discussions),
[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
-and [WeChat](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your suggestions, feedback, and questions with our engineering team.
+and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your suggestions, feedback, and questions with our engineering team.
-## Contributing
+## Invitation to open-source contribution
+Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models!
-If you wish to contribute to this project, please follow the guideline in [Contributing](./CONTRIBUTING.md).
+You may contact us or participate in the following ways:
+1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
+2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md)
+3. Send your official proposal to email contact@hpcaitech.com
Thanks so much to all of our amazing contributors!
@@ -333,8 +392,17 @@ Thanks so much to all of our amazing contributors!
+## CI/CD
+
+We leverage the power of [GitHub Actions](https://github.com/features/actions) to automate our development, release and deployment workflows. Please check out this [documentation](.github/workflows/README.md) on how the automated workflows are operated.
+
+
## Cite Us
+This project is inspired by some related projects (some by our team and some by other organizations). We would like to credit these amazing projects as listed in the [Reference List](./docs/REFERENCE.md).
+
+To cite this project, you can use the following BibTeX citation.
+
```
@article{bian2021colossal,
title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
@@ -344,4 +412,6 @@ Thanks so much to all of our amazing contributors!
}
```
+Colossal-AI has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
+
diff --git a/applications/ChatGPT/.gitignore b/applications/ChatGPT/.gitignore
new file mode 100644
index 000000000000..40f3f6debeee
--- /dev/null
+++ b/applications/ChatGPT/.gitignore
@@ -0,0 +1,146 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+docs/.build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# IDE
+.idea/
+.vscode/
+
+# macos
+*.DS_Store
+#data/
+
+docs/.build
+
+# pytorch checkpoint
+*.pt
+
+# ignore version.py generated by setup.py
+colossalai/version.py
diff --git a/applications/ChatGPT/LICENSE b/applications/ChatGPT/LICENSE
new file mode 100644
index 000000000000..0528c89ea9ec
--- /dev/null
+++ b/applications/ChatGPT/LICENSE
@@ -0,0 +1,202 @@
+Copyright 2021- HPC-AI Technology Inc. All rights reserved.
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2021- HPC-AI Technology Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/applications/ChatGPT/README.md b/applications/ChatGPT/README.md
new file mode 100644
index 000000000000..206ede5f1843
--- /dev/null
+++ b/applications/ChatGPT/README.md
@@ -0,0 +1,209 @@
+# RLHF - Colossal-AI
+
+## Table of Contents
+
+- [What is RLHF - Colossal-AI?](#intro)
+- [How to Install?](#install)
+- [The Plan](#the-plan)
+- [How can you partcipate in open source?](#invitation-to-open-source-contribution)
+---
+## Intro
+Implementation of RLHF (Reinforcement Learning with Human Feedback) powered by Colossal-AI. It supports distributed training and offloading, which can fit extremly large models. More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).
+
+
+
+
+
+## Training process (step 3)
+
+
+
+
+
+
+
+
+## Install
+```shell
+pip install .
+```
+
+## Usage
+
+The main entrypoint is `Trainer`. We only support PPO trainer now. We support many training strategies:
+
+- NaiveStrategy: simplest strategy. Train on single GPU.
+- DDPStrategy: use `torch.nn.parallel.DistributedDataParallel`. Train on multi GPUs.
+- ColossalAIStrategy: use Gemini and Zero of ColossalAI. It eliminates model duplication on each GPU and supports offload. It's very useful when training large models on multi GPUs.
+
+Simplest usage:
+
+```python
+from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.base import RewardModel
+from copy import deepcopy
+from colossalai.nn.optimizer import HybridAdam
+
+strategy = ColossalAIStrategy()
+
+with strategy.model_init_context():
+ # init your model here
+ # load pretrained gpt2
+ actor = GPTActor(pretrained='gpt2')
+ critic = GPTCritic()
+ initial_model = deepcopy(actor).cuda()
+ reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+
+actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
+critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
+
+# prepare models and optimizers
+(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
+# load saved model checkpoint after preparing
+strategy.load_model(actor, 'actor_checkpoint.pt', strict=False)
+# load saved optimizer checkpoint after preparing
+strategy.load_optimizer(actor_optim, 'actor_optim_checkpoint.pt')
+
+trainer = PPOTrainer(strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ ...)
+
+trainer.fit(dataset, ...)
+
+# save model checkpoint after fitting on only rank0
+strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
+# save optimizer checkpoint on all ranks
+strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt', only_rank0=False)
+```
+
+For more details, see `examples/`.
+
+We also support training reward model with true-world data. See `examples/train_reward_model.py`.
+
+## FAQ
+
+### How to save/load checkpoint
+
+To load pretrained model, you can simply use huggingface pretrained models:
+
+```python
+# load OPT-350m pretrained model
+actor = OPTActor(pretrained='facebook/opt-350m')
+```
+
+To save model checkpoint:
+
+```python
+# save model checkpoint on only rank0
+strategy.save_model(actor, 'actor_checkpoint.pt', only_rank0=True)
+```
+
+This function must be called after `strategy.prepare()`.
+
+For DDP strategy, model weights are replicated on all ranks. And for ColossalAI strategy, model weights may be sharded, but all-gather will be applied before returning state dict. You can set `only_rank0=True` for both of them, which only saves checkpoint on rank0, to save disk space usage. The checkpoint is float32.
+
+To save optimizer checkpoint:
+
+```python
+# save optimizer checkpoint on all ranks
+strategy.save_optimizer(actor_optim, 'actor_optim_checkpoint.pt', only_rank0=False)
+```
+
+For DDP strategy, optimizer states are replicated on all ranks. You can set `only_rank0=True`. But for ColossalAI strategy, optimizer states are sharded over all ranks, and no all-gather will be applied. So for ColossalAI strategy, you can only set `only_rank0=False`. That is to say, each rank will save a cehckpoint. When loading, each rank should load the corresponding part.
+
+Note that different stategy may have different shapes of optimizer checkpoint.
+
+To load model checkpoint:
+
+```python
+# load saved model checkpoint after preparing
+strategy.load_model(actor, 'actor_checkpoint.pt', strict=False)
+```
+
+To load optimizer checkpoint:
+
+```python
+# load saved optimizer checkpoint after preparing
+strategy.load_optimizer(actor_optim, 'actor_optim_checkpoint.pt')
+```
+
+## The Plan
+
+- [x] implement PPO fine-tuning
+- [x] implement training reward model
+- [x] support LoRA
+- [x] support inference
+- [ ] open source the reward model weight
+- [ ] support llama from [facebook](https://github.com/facebookresearch/llama)
+- [ ] support BoN(best of N sample)
+- [ ] implement PPO-ptx fine-tuning
+- [ ] integrate with Ray
+- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
+- [ ] support chain of throught by [langchain](https://github.com/hwchase17/langchain)
+
+### Real-time progress
+You will find our progress in github project broad
+
+[Open ChatGPT](https://github.com/orgs/hpcaitech/projects/17/views/1)
+
+## Invitation to open-source contribution
+Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
+
+You may contact us or participate in the following ways:
+1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
+2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
+3. Join the Colossal-AI community on
+[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
+and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
+4. Send your official proposal to email contact@hpcaitech.com
+
+Thanks so much to all of our amazing contributors!
+
+## Quick Preview
+
+
+
+
+- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
+
+
+
+
+
+- Up to 10.3x growth in model capacity on one GPU
+- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
+
+
+
+
+
+- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
+- Keep in a sufficiently high running speed
+
+## Citations
+
+```bibtex
+@article{Hu2021LoRALA,
+ title = {LoRA: Low-Rank Adaptation of Large Language Models},
+ author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
+ journal = {ArXiv},
+ year = {2021},
+ volume = {abs/2106.09685}
+}
+
+@article{ouyang2022training,
+ title={Training language models to follow instructions with human feedback},
+ author={Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
+ journal={arXiv preprint arXiv:2203.02155},
+ year={2022}
+}
+```
diff --git a/applications/ChatGPT/benchmarks/README.md b/applications/ChatGPT/benchmarks/README.md
new file mode 100644
index 000000000000..b4e28ba1d764
--- /dev/null
+++ b/applications/ChatGPT/benchmarks/README.md
@@ -0,0 +1,94 @@
+# Benchmarks
+
+## Benchmark GPT on dummy prompt data
+
+We provide various GPT models (string in parentheses is the corresponding model name used in this script):
+
+- GPT2-S (s)
+- GPT2-M (m)
+- GPT2-L (l)
+- GPT2-XL (xl)
+- GPT2-4B (4b)
+- GPT2-6B (6b)
+- GPT2-8B (8b)
+- GPT2-10B (10b)
+- GPT2-12B (12b)
+- GPT2-15B (15b)
+- GPT2-18B (18b)
+- GPT2-20B (20b)
+- GPT2-24B (24b)
+- GPT2-28B (28b)
+- GPT2-32B (32b)
+- GPT2-36B (36b)
+- GPT2-40B (40b)
+- GPT3 (175b)
+
+We also provide various training strategies:
+
+- ddp: torch DDP
+- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
+- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
+- colossalai_zero2: ColossalAI zero2
+- colossalai_zero2_cpu: ColossalAI zero2-offload
+- colossalai_zero1: ColossalAI zero1
+- colossalai_zero1_cpu: ColossalAI zero1-offload
+
+We only support `torchrun` to launch now. E.g.
+
+```shell
+# run GPT2-S on single-node single-GPU with min batch size
+torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
+# run GPT2-XL on single-node 4-GPU
+torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
+# run GPT3 on 8-node 8-GPU
+torchrun --nnodes 8 --nproc_per_node 8 \
+ --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
+ benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
+```
+
+> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
+
+In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
+
+We also provide a simple shell script to run a set of benchmarks. But it only supports benchmark on single node. However, it's easy to run on multi-nodes by modifying launch command in this script.
+
+Usage:
+
+```shell
+# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
+./benchmark_gpt_dummy.sh
+# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
+./benchmark_gpt_dummy.sh 2
+# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
+./benchmark_gpt_dummy.sh 2 ddp
+# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
+./benchmark_gpt_dummy.sh 2 ddp l
+```
+
+## Benchmark OPT with LoRA on dummy prompt data
+
+We provide various OPT models (string in parentheses is the corresponding model name used in this script):
+
+- OPT-125M (125m)
+- OPT-350M (350m)
+- OPT-700M (700m)
+- OPT-1.3B (1.3b)
+- OPT-2.7B (2.7b)
+- OPT-3.5B (3.5b)
+- OPT-5.5B (5.5b)
+- OPT-6.7B (6.7b)
+- OPT-10B (10b)
+- OPT-13B (13b)
+
+We only support `torchrun` to launch now. E.g.
+
+```shell
+# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
+torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
+# run OPT-350M with lora_rank=4 on single-node 4-GPU
+torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
+```
+
+> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
+
+In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
new file mode 100644
index 000000000000..5ee65763b936
--- /dev/null
+++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.py
@@ -0,0 +1,184 @@
+import argparse
+from copy import deepcopy
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import PerformanceEvaluator
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
+from torch.optim import Adam
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
+ numel = sum(p.numel() for p in model.parameters())
+ if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
+ numel *= dist.get_world_size()
+ return numel
+
+
+def preprocess_batch(samples) -> dict:
+ input_ids = torch.stack(samples)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
+
+
+def print_rank_0(*args, **kwargs) -> None:
+ if dist.get_rank() == 0:
+ print(*args, **kwargs)
+
+
+def print_model_numel(model_dict: dict) -> None:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ outputs = ''
+ for name, numel in model_dict.items():
+ outputs += f'{name}: '
+ if numel >= B:
+ outputs += f'{numel / B:.2f} B\n'
+ elif numel >= M:
+ outputs += f'{numel / M:.2f} M\n'
+ elif numel >= K:
+ outputs += f'{numel / K:.2f} K\n'
+ else:
+ outputs += f'{numel}\n'
+ print_rank_0(outputs)
+
+
+def get_gpt_config(model_name: str) -> GPT2Config:
+ model_map = {
+ 's': GPT2Config(),
+ 'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
+ 'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
+ 'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
+ '2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
+ '4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
+ '6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
+ '8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
+ '10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
+ '12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
+ '15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
+ '18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
+ '20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
+ '24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
+ '28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
+ '32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
+ '36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
+ '40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
+ '175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
+ }
+ try:
+ return model_map[model_name]
+ except KeyError:
+ raise ValueError(f'Unknown model "{model_name}"')
+
+
+def main(args):
+ if args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+ elif args.strategy == 'colossalai_gemini_cpu':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero2_cpu':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == 'colossalai_zero1':
+ strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero1_cpu':
+ strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ model_config = get_gpt_config(args.model)
+
+ with strategy.model_init_context():
+ actor = GPTActor(config=model_config).cuda()
+ critic = GPTCritic(config=model_config).cuda()
+
+ initial_model = deepcopy(actor).cuda()
+ reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+
+ actor_numel = get_model_numel(actor, strategy)
+ critic_numel = get_model_numel(critic, strategy)
+ initial_model_numel = get_model_numel(initial_model, strategy)
+ reward_model_numel = get_model_numel(reward_model, strategy)
+ print_model_numel({
+ 'Actor': actor_numel,
+ 'Critic': critic_numel,
+ 'Initial model': initial_model_numel,
+ 'Reward model': reward_model_numel
+ })
+ performance_evaluator = PerformanceEvaluator(actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=1)
+
+ if args.strategy.startswith('colossalai'):
+ actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
+ critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
+ else:
+ actor_optim = Adam(actor.parameters(), lr=5e-6)
+ critic_optim = Adam(critic.parameters(), lr=5e-6)
+
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
+ trainer = PPOTrainer(strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ experience_batch_size=args.experience_batch_size,
+ tokenizer=preprocess_batch,
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ callbacks=[performance_evaluator])
+
+ random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
+ trainer.fit(random_prompts,
+ num_episodes=args.num_episodes,
+ max_timesteps=args.max_timesteps,
+ update_timesteps=args.update_timesteps)
+
+ print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', default='s')
+ parser.add_argument('--strategy',
+ choices=[
+ 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
+ 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
+ ],
+ default='ddp')
+ parser.add_argument('--num_episodes', type=int, default=3)
+ parser.add_argument('--max_timesteps', type=int, default=8)
+ parser.add_argument('--update_timesteps', type=int, default=8)
+ parser.add_argument('--max_epochs', type=int, default=3)
+ parser.add_argument('--train_batch_size', type=int, default=8)
+ parser.add_argument('--experience_batch_size', type=int, default=8)
+ args = parser.parse_args()
+ main(args)
diff --git a/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh
new file mode 100755
index 000000000000..d70f8872570a
--- /dev/null
+++ b/applications/ChatGPT/benchmarks/benchmark_gpt_dummy.sh
@@ -0,0 +1,45 @@
+#!/usr/bin/env bash
+# Usage: $0
+set -xu
+
+BASE=$(realpath $(dirname $0))
+
+
+PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
+export OMP_NUM_THREADS=8
+
+function tune_batch_size() {
+ # we found when experience batch size is equal to train batch size
+ # peak CUDA memory usage of making experience phase is less than or equal to that of training phase
+ # thus, experience batch size can be larger than or equal to train batch size
+ for bs in 1 2 4 8 16 32 64 128 256; do
+ torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
+ done
+}
+
+if [ $# -eq 0 ]; then
+ num_gpus=(1 2 4 8)
+else
+ num_gpus=($1)
+fi
+
+if [ $# -le 1 ]; then
+ strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
+else
+ strategies=($2)
+fi
+
+if [ $# -le 2 ]; then
+ models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
+else
+ models=($3)
+fi
+
+
+for num_gpu in ${num_gpus[@]}; do
+ for strategy in ${strategies[@]}; do
+ for model in ${models[@]}; do
+ tune_batch_size $num_gpu $model $strategy || break
+ done
+ done
+done
diff --git a/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
new file mode 100644
index 000000000000..207edbca94b5
--- /dev/null
+++ b/applications/ChatGPT/benchmarks/benchmark_opt_lora_dummy.py
@@ -0,0 +1,179 @@
+import argparse
+from copy import deepcopy
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from chatgpt.models.base import RewardModel
+from chatgpt.models.opt import OPTActor, OPTCritic
+from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import PerformanceEvaluator
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
+from torch.optim import Adam
+from transformers import AutoTokenizer
+from transformers.models.opt.configuration_opt import OPTConfig
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
+ numel = sum(p.numel() for p in model.parameters())
+ if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
+ numel *= dist.get_world_size()
+ return numel
+
+
+def preprocess_batch(samples) -> dict:
+ input_ids = torch.stack(samples)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
+
+
+def print_rank_0(*args, **kwargs) -> None:
+ if dist.get_rank() == 0:
+ print(*args, **kwargs)
+
+
+def print_model_numel(model_dict: dict) -> None:
+ B = 1024**3
+ M = 1024**2
+ K = 1024
+ outputs = ''
+ for name, numel in model_dict.items():
+ outputs += f'{name}: '
+ if numel >= B:
+ outputs += f'{numel / B:.2f} B\n'
+ elif numel >= M:
+ outputs += f'{numel / M:.2f} M\n'
+ elif numel >= K:
+ outputs += f'{numel / K:.2f} K\n'
+ else:
+ outputs += f'{numel}\n'
+ print_rank_0(outputs)
+
+
+def get_gpt_config(model_name: str) -> OPTConfig:
+ model_map = {
+ '125m': OPTConfig.from_pretrained('facebook/opt-125m'),
+ '350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
+ '700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
+ '1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
+ '2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
+ '3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
+ '5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
+ '6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
+ '10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
+ '13b': OPTConfig.from_pretrained('facebook/opt-13b'),
+ }
+ try:
+ return model_map[model_name]
+ except KeyError:
+ raise ValueError(f'Unknown model "{model_name}"')
+
+
+def main(args):
+ if args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+ elif args.strategy == 'colossalai_gemini_cpu':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero2_cpu':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
+ elif args.strategy == 'colossalai_zero1':
+ strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero1_cpu':
+ strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)
+
+ model_config = get_gpt_config(args.model)
+
+ with strategy.model_init_context():
+ actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
+ critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()
+
+ initial_model = deepcopy(actor).cuda()
+ reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+
+ actor_numel = get_model_numel(actor, strategy)
+ critic_numel = get_model_numel(critic, strategy)
+ initial_model_numel = get_model_numel(initial_model, strategy)
+ reward_model_numel = get_model_numel(reward_model, strategy)
+ print_model_numel({
+ 'Actor': actor_numel,
+ 'Critic': critic_numel,
+ 'Initial model': initial_model_numel,
+ 'Reward model': reward_model_numel
+ })
+ performance_evaluator = PerformanceEvaluator(actor_numel,
+ critic_numel,
+ initial_model_numel,
+ reward_model_numel,
+ enable_grad_checkpoint=False,
+ ignore_episodes=1)
+
+ if args.strategy.startswith('colossalai'):
+ actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
+ critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
+ else:
+ actor_optim = Adam(actor.parameters(), lr=5e-6)
+ critic_optim = Adam(critic.parameters(), lr=5e-6)
+
+ tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ tokenizer.pad_token = tokenizer.eos_token
+
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
+ trainer = PPOTrainer(strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ experience_batch_size=args.experience_batch_size,
+ tokenizer=preprocess_batch,
+ max_length=512,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ callbacks=[performance_evaluator])
+
+ random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
+ trainer.fit(random_prompts,
+ num_episodes=args.num_episodes,
+ max_timesteps=args.max_timesteps,
+ update_timesteps=args.update_timesteps)
+
+ print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', default='125m')
+ parser.add_argument('--strategy',
+ choices=[
+ 'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
+ 'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
+ ],
+ default='ddp')
+ parser.add_argument('--num_episodes', type=int, default=3)
+ parser.add_argument('--max_timesteps', type=int, default=8)
+ parser.add_argument('--update_timesteps', type=int, default=8)
+ parser.add_argument('--max_epochs', type=int, default=3)
+ parser.add_argument('--train_batch_size', type=int, default=8)
+ parser.add_argument('--experience_batch_size', type=int, default=8)
+ parser.add_argument('--lora_rank', type=int, default=4)
+ parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
+ args = parser.parse_args()
+ main(args)
diff --git a/examples/tutorial/stable_diffusion/ldm/data/__init__.py b/applications/ChatGPT/chatgpt/__init__.py
similarity index 100%
rename from examples/tutorial/stable_diffusion/ldm/data/__init__.py
rename to applications/ChatGPT/chatgpt/__init__.py
diff --git a/applications/ChatGPT/chatgpt/dataset/__init__.py b/applications/ChatGPT/chatgpt/dataset/__init__.py
new file mode 100644
index 000000000000..df484f46d24c
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/dataset/__init__.py
@@ -0,0 +1,5 @@
+from .reward_dataset import RmStaticDataset, HhRlhfDataset
+from .utils import is_rank_0
+from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
+
+__all__ = ['RmStaticDataset', 'HhRlhfDataset','is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
diff --git a/applications/ChatGPT/chatgpt/dataset/reward_dataset.py b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py
new file mode 100644
index 000000000000..9ee13490b893
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/dataset/reward_dataset.py
@@ -0,0 +1,109 @@
+from typing import Callable
+
+from torch.utils.data import Dataset
+from tqdm import tqdm
+
+from .utils import is_rank_0
+
+# Dahaos/rm-static
+class RmStaticDataset(Dataset):
+ """
+ Dataset for reward model
+
+ Args:
+ dataset: dataset for reward model
+ tokenizer: tokenizer for reward model
+ max_length: max length of input
+ special_token: special token at the end of sentence
+ """
+
+ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
+ super().__init__()
+ self.chosen = []
+ self.reject = []
+ if special_token is None:
+ self.end_token = tokenizer.eos_token
+ else:
+ self.end_token = special_token
+ for data in tqdm(dataset, disable=not is_rank_0()):
+ prompt = data['prompt']
+
+ chosen = prompt + data['chosen'] + self.end_token
+ chosen_token = tokenizer(chosen,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.chosen.append({
+ "input_ids": chosen_token['input_ids'],
+ "attention_mask": chosen_token['attention_mask']
+ })
+
+ reject = prompt + data['rejected'] + self.end_token
+ reject_token = tokenizer(reject,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.reject.append({
+ "input_ids": reject_token['input_ids'],
+ "attention_mask": reject_token['attention_mask']
+ })
+
+ def __len__(self):
+ length = len(self.chosen)
+ return length
+
+ def __getitem__(self, idx):
+ return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
+ "input_ids"], self.reject[idx]["attention_mask"]
+
+# Anthropic/hh-rlhf
+class HhRlhfDataset(Dataset):
+ """
+ Dataset for reward model
+
+ Args:
+ dataset: dataset for reward model
+ tokenizer: tokenizer for reward model
+ max_length: max length of input
+ special_token: special token at the end of sentence
+ """
+ def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
+ super().__init__()
+ self.chosen = []
+ self.reject = []
+ if special_token is None:
+ self.end_token = tokenizer.eos_token
+ else:
+ self.end_token = special_token
+ for data in tqdm(dataset, disable=not is_rank_0()):
+ chosen = data['chosen'] + self.end_token
+ chosen_token = tokenizer(chosen,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.chosen.append({
+ "input_ids": chosen_token['input_ids'],
+ "attention_mask": chosen_token['attention_mask']
+ })
+
+ reject = data['rejected'] + self.end_token
+ reject_token = tokenizer(reject,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+ self.reject.append({
+ "input_ids": reject_token['input_ids'],
+ "attention_mask": reject_token['attention_mask']
+ })
+
+ def __len__(self):
+ length = len(self.chosen)
+ return length
+
+ def __getitem__(self, idx):
+ return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
+ "input_ids"], self.reject[idx]["attention_mask"]
diff --git a/applications/ChatGPT/chatgpt/dataset/sft_dataset.py b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
new file mode 100644
index 000000000000..11ec61908aef
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/dataset/sft_dataset.py
@@ -0,0 +1,163 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from dataclasses import dataclass, field
+from typing import Callable, Dict, Sequence
+import random
+from torch.utils.data import Dataset
+import torch.distributed as dist
+from tqdm import tqdm
+import torch
+
+from .utils import is_rank_0, jload
+
+import transformers
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+IGNORE_INDEX = -100
+PROMPT_DICT = {
+ "prompt_input": (
+ "Below is an instruction that describes a task, paired with an input that provides further context. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+ ),
+ "prompt_no_input": (
+ "Below is an instruction that describes a task. "
+ "Write a response that appropriately completes the request.\n\n"
+ "### Instruction:\n{instruction}\n\n### Response:"
+ ),
+}
+
+class SFTDataset(Dataset):
+ """
+ Dataset for sft model
+
+ Args:
+ dataset: dataset for supervised model
+ tokenizer: tokenizer for supervised model
+ max_length: max length of input
+ """
+
+ def __init__(self, dataset, tokenizer: Callable, max_length: int=512) -> None:
+ super().__init__()
+ # self.prompts = []
+ self.input_ids = []
+
+ for data in tqdm(dataset, disable=not is_rank_0()):
+ prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+ prompt_token = tokenizer(prompt,
+ max_length=max_length,
+ padding="max_length",
+ truncation=True,
+ return_tensors="pt")
+
+ # self.prompts.append(prompt_token)s
+ self.input_ids.append(prompt_token)
+ self.labels = copy.deepcopy(self.input_ids)
+
+ def __len__(self):
+ length = len(self.prompts)
+ return length
+
+ def __getitem__(self, idx):
+ # dict(input_ids=self.input_ids[i], labels=self.labels[i])
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+ # return dict(self.prompts[idx], self.prompts[idx])
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+ for text in strings
+ ]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+def preprocess(
+ sources: Sequence[str],
+ targets: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ """Preprocess the data by tokenizing."""
+ examples = [s + t for s, t in zip(sources, targets)]
+ examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+ input_ids = examples_tokenized["input_ids"]
+ labels = copy.deepcopy(input_ids)
+ for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+ label[:source_len] = IGNORE_INDEX
+ return dict(input_ids=input_ids, labels=labels)
+
+class AlpacaDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+ super(AlpacaDataset, self).__init__()
+ logger.info("Loading data...")
+ list_data_dict = jload(data_path)
+
+ logger.info("Formatting inputs...")
+ prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+ sources = [
+ prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+ for example in list_data_dict
+ ]
+ targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+ logger.info("Tokenizing inputs... This may take some time...")
+ data_dict = preprocess(sources, targets, tokenizer)
+
+ self.input_ids = data_dict["input_ids"]
+ self.labels = data_dict["labels"]
+
+ def __len__(self):
+ return len(self.input_ids)
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+@dataclass
+class AlpacaDataCollator(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
diff --git a/applications/ChatGPT/chatgpt/dataset/utils.py b/applications/ChatGPT/chatgpt/dataset/utils.py
new file mode 100644
index 000000000000..0e88cc8c39b4
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/dataset/utils.py
@@ -0,0 +1,20 @@
+import io
+import json
+
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
+
+def _make_r_io_base(f, mode: str):
+ if not isinstance(f, io.IOBase):
+ f = open(f, mode=mode)
+ return f
+
+def jload(f, mode="r"):
+ """Load a .json file into a dictionary."""
+ f = _make_r_io_base(f, mode)
+ jdict = json.load(f)
+ f.close()
+ return jdict
\ No newline at end of file
diff --git a/applications/ChatGPT/chatgpt/experience_maker/__init__.py b/applications/ChatGPT/chatgpt/experience_maker/__init__.py
new file mode 100644
index 000000000000..39ca7576b227
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/experience_maker/__init__.py
@@ -0,0 +1,4 @@
+from .base import Experience, ExperienceMaker
+from .naive import NaiveExperienceMaker
+
+__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
diff --git a/applications/ChatGPT/chatgpt/experience_maker/base.py b/applications/ChatGPT/chatgpt/experience_maker/base.py
new file mode 100644
index 000000000000..f3640fc1e496
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/experience_maker/base.py
@@ -0,0 +1,77 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from chatgpt.models.base import Actor
+
+
+@dataclass
+class Experience:
+ """Experience is a batch of data.
+ These data should have the the sequence length and number of actions.
+ Left padding for sequences is applied.
+
+ Shapes of each tensor:
+ sequences: (B, S)
+ action_log_probs: (B, A)
+ values: (B)
+ reward: (B)
+ advatanges: (B)
+ attention_mask: (B, S)
+ action_mask: (B, A)
+
+ "A" is the number of actions.
+ """
+ sequences: torch.Tensor
+ action_log_probs: torch.Tensor
+ values: torch.Tensor
+ reward: torch.Tensor
+ advantages: torch.Tensor
+ attention_mask: Optional[torch.LongTensor]
+ action_mask: Optional[torch.BoolTensor]
+
+ @torch.no_grad()
+ def to_device(self, device: torch.device) -> None:
+ self.sequences = self.sequences.to(device)
+ self.action_log_probs = self.action_log_probs.to(device)
+ self.values = self.values.to(device)
+ self.reward = self.reward.to(device)
+ self.advantages = self.advantages.to(device)
+ if self.attention_mask is not None:
+ self.attention_mask = self.attention_mask.to(device)
+ if self.action_mask is not None:
+ self.action_mask = self.action_mask.to(device)
+
+ def pin_memory(self):
+ self.sequences = self.sequences.pin_memory()
+ self.action_log_probs = self.action_log_probs.pin_memory()
+ self.values = self.values.pin_memory()
+ self.reward = self.reward.pin_memory()
+ self.advantages = self.advantages.pin_memory()
+ if self.attention_mask is not None:
+ self.attention_mask = self.attention_mask.pin_memory()
+ if self.action_mask is not None:
+ self.action_mask = self.action_mask.pin_memory()
+ return self
+
+
+class ExperienceMaker(ABC):
+
+ def __init__(self,
+ actor: Actor,
+ critic: nn.Module,
+ reward_model: nn.Module,
+ initial_model: Actor,
+ kl_coef: float = 0.1) -> None:
+ super().__init__()
+ self.actor = actor
+ self.critic = critic
+ self.reward_model = reward_model
+ self.initial_model = initial_model
+ self.kl_coef = kl_coef
+
+ @abstractmethod
+ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
+ pass
diff --git a/applications/ChatGPT/chatgpt/experience_maker/naive.py b/applications/ChatGPT/chatgpt/experience_maker/naive.py
new file mode 100644
index 000000000000..64835cfa1918
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/experience_maker/naive.py
@@ -0,0 +1,36 @@
+import torch
+from chatgpt.models.utils import compute_reward, normalize
+
+from .base import Experience, ExperienceMaker
+
+
+class NaiveExperienceMaker(ExperienceMaker):
+ """
+ Naive experience maker.
+ """
+
+ @torch.no_grad()
+ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
+ self.actor.eval()
+ self.critic.eval()
+ self.initial_model.eval()
+ self.reward_model.eval()
+
+ sequences, attention_mask, action_mask = self.actor.generate(input_ids,
+ return_action_mask=True,
+ **generate_kwargs)
+ num_actions = action_mask.size(1)
+
+ action_log_probs = self.actor(sequences, num_actions, attention_mask)
+ base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
+ value = self.critic(sequences, action_mask, attention_mask)
+ r = self.reward_model(sequences, attention_mask)
+
+ reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
+
+ advantage = reward - value
+ # TODO(ver217): maybe normalize adv
+ if advantage.ndim == 1:
+ advantage = advantage.unsqueeze(-1)
+
+ return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
diff --git a/applications/ChatGPT/chatgpt/models/__init__.py b/applications/ChatGPT/chatgpt/models/__init__.py
new file mode 100644
index 000000000000..b274188a21df
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/__init__.py
@@ -0,0 +1,4 @@
+from .base import Actor, Critic, RewardModel
+from .loss import PolicyLoss, PPOPtxActorLoss, ValueLoss, LogSigLoss, LogExpLoss
+
+__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
diff --git a/applications/ChatGPT/chatgpt/models/base/__init__.py b/applications/ChatGPT/chatgpt/models/base/__init__.py
new file mode 100644
index 000000000000..7c7b1ceba257
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/__init__.py
@@ -0,0 +1,6 @@
+from .actor import Actor
+from .critic import Critic
+from .reward_model import RewardModel
+from .lm import LM
+
+__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
diff --git a/applications/ChatGPT/chatgpt/models/base/actor.py b/applications/ChatGPT/chatgpt/models/base/actor.py
new file mode 100644
index 000000000000..57db2bb11a6a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/actor.py
@@ -0,0 +1,62 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..generation import generate
+from ..lora import LoRAModule
+from ..utils import log_probs_from_logits
+
+
+class Actor(LoRAModule):
+ """
+ Actor model base class.
+
+ Args:
+ model (nn.Module): Actor Model.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
+ self.model = model
+ self.convert_to_lora()
+
+ @torch.no_grad()
+ def generate(
+ self,
+ input_ids: torch.Tensor,
+ return_action_mask: bool = True,
+ **kwargs
+ ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+ sequences = generate(self.model, input_ids, **kwargs)
+ attention_mask = None
+ pad_token_id = kwargs.get('pad_token_id', None)
+ if pad_token_id is not None:
+ attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
+ if not return_action_mask:
+ return sequences, attention_mask, None
+ input_len = input_ids.size(1)
+ eos_token_id = kwargs.get('eos_token_id', None)
+ if eos_token_id is None:
+ action_mask = torch.ones_like(sequences, dtype=torch.bool)
+ else:
+ # left padding may be applied, only mask action
+ action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
+ action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
+ action_mask[:, :input_len] = False
+ action_mask = action_mask[:, 1:]
+ return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
+
+ def forward(self,
+ sequences: torch.LongTensor,
+ num_actions: int,
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """Returns action log probs
+ """
+ output = self.model(sequences, attention_mask=attention_mask)
+ logits = output['logits']
+ log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+ return log_probs[:, -num_actions:]
diff --git a/applications/ChatGPT/chatgpt/models/base/critic.py b/applications/ChatGPT/chatgpt/models/base/critic.py
new file mode 100644
index 000000000000..e68a743a7762
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/critic.py
@@ -0,0 +1,54 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ..lora import LoRAModule
+from ..utils import masked_mean
+
+
+class Critic(LoRAModule):
+ """
+ Critic model base class.
+
+ Args:
+ model (nn.Module): Critic model.
+ value_head (nn.Module): Value head to get value.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ value_head: nn.Module,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none',
+ use_action_mask: bool = False,
+ ) -> None:
+
+ super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
+ self.model = model
+ self.value_head = value_head
+ self.use_action_mask = use_action_mask
+ self.convert_to_lora()
+
+ def forward(self,
+ sequences: torch.LongTensor,
+ action_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ outputs = self.model(sequences, attention_mask=attention_mask)
+ last_hidden_states = outputs['last_hidden_state']
+
+ values = self.value_head(last_hidden_states).squeeze(-1)
+
+ if action_mask is not None and self.use_action_mask:
+ num_actions = action_mask.size(1)
+ prompt_mask = attention_mask[:, :-num_actions]
+ values = values[:, :-num_actions]
+ value = masked_mean(values, prompt_mask, dim=1)
+ return value
+
+ values = values[:, :-1]
+ value = values.mean(dim=1)
+ return value
diff --git a/applications/ChatGPT/chatgpt/models/base/lm.py b/applications/ChatGPT/chatgpt/models/base/lm.py
new file mode 100644
index 000000000000..b6bd7aff8315
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/lm.py
@@ -0,0 +1,33 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..generation import generate
+from .actor import Actor
+
+
+class LM(Actor):
+ """
+ Language model base class.
+
+ Args:
+ model (nn.Module): Language Model.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
+
+ def forward(self,
+ sequences: torch.LongTensor,
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """Returns output log probs
+ """
+ output = self.model(sequences, attention_mask=attention_mask)
+ logits = output['logits']
+ log_probs = F.log_softmax(logits, dim=-1)
+ return log_probs
+
diff --git a/applications/ChatGPT/chatgpt/models/base/reward_model.py b/applications/ChatGPT/chatgpt/models/base/reward_model.py
new file mode 100644
index 000000000000..ce8c0a1d3568
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/base/reward_model.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from ..lora import LoRAModule
+
+
+class RewardModel(LoRAModule):
+ """
+ Reward model base class.
+
+ Args:
+ model (nn.Module): Reward model.
+ value_head (nn.Module): Value head to get reward score.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ model: nn.Module,
+ value_head: Optional[nn.Module] = None,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
+ self.model = model
+ self.convert_to_lora()
+
+ if value_head is not None:
+ if value_head.out_features != 1:
+ raise ValueError("The value head of reward model's output dim should be 1!")
+ self.value_head = value_head
+ else:
+ self.value_head = nn.Linear(model.config.n_embd, 1)
+
+ def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ outputs = self.model(sequences, attention_mask=attention_mask)
+ last_hidden_states = outputs['last_hidden_state']
+ values = self.value_head(last_hidden_states)[:, :-1]
+ value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
+ return value
diff --git a/applications/ChatGPT/chatgpt/models/bloom/__init__.py b/applications/ChatGPT/chatgpt/models/bloom/__init__.py
new file mode 100644
index 000000000000..7d6d7753bb9a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/__init__.py
@@ -0,0 +1,6 @@
+from .bloom_actor import BLOOMActor
+from .bloom_critic import BLOOMCritic
+from .bloom_rm import BLOOMRM
+from .bloom_lm import BLOOMLM
+
+__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py
new file mode 100644
index 000000000000..d7577f096493
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_actor.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+import torch
+from transformers import BloomConfig, BloomForCausalLM, BloomModel
+
+from ..base import Actor
+
+
+class BLOOMActor(Actor):
+ """
+ BLOOM Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (BloomConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = BloomForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = BloomForCausalLM(config)
+ else:
+ model = BloomForCausalLM(BloomConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py
new file mode 100644
index 000000000000..a32fb2e102f9
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import BloomConfig, BloomForCausalLM, BloomModel
+
+from ..base import Critic
+
+
+class BLOOMCritic(Critic):
+ """
+ BLOOM Critic model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (BloomConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none',
+ **kwargs) -> None:
+ if pretrained is not None:
+ model = BloomModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = BloomModel(config)
+ else:
+ model = BloomModel(BloomConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.hidden_size, 1)
+ super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
new file mode 100644
index 000000000000..81e17f27c11a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import torch
+from transformers import BloomConfig, BloomForCausalLM, BloomModel
+
+from ..base import LM
+
+
+class BLOOMLM(LM):
+ """
+ BLOOM language model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (BloomConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = BloomForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = BloomForCausalLM(config)
+ else:
+ model = BloomForCausalLM(BloomConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
new file mode 100644
index 000000000000..2dba227ff7d0
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/bloom/bloom_rm.py
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import BloomConfig, BloomForCausalLM, BloomModel
+
+from ..base import RewardModel
+
+
+class BLOOMRM(RewardModel):
+ """
+ BLOOM Reward model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (BloomConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[BloomConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = BloomModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = BloomModel(config)
+ else:
+ model = BloomModel(BloomConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.hidden_size, 1)
+ value_head.weight.data.normal_(mean=0.0, std=1/(model.config.hidden_size + 1))
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/deberta/__init__.py b/applications/ChatGPT/chatgpt/models/deberta/__init__.py
new file mode 100644
index 000000000000..b66888f34fd0
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/deberta/__init__.py
@@ -0,0 +1,4 @@
+from .deberta_critic import DebertaCritic
+from .deberta_rm import DebertaRM
+
+__all__ = ['DebertaCritic', 'DebertaRM']
diff --git a/applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py b/applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
new file mode 100644
index 000000000000..e84c1dbd8380
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/deberta/deberta_critic.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import DebertaV2Config, DebertaV2Model
+
+from ..base import Critic
+
+
+class DebertaCritic(Critic):
+ """
+ Deberta Critic model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (DebertaV2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the LO-RA decomposition.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[DebertaV2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = DebertaV2Model.from_pretrained(pretrained)
+ elif config is not None:
+ model = DebertaV2Model(config)
+ else:
+ model = DebertaV2Model(DebertaV2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.hidden_size, 1)
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py b/applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
new file mode 100644
index 000000000000..2448c879ec85
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/deberta/deberta_rm.py
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import DebertaV2Config, DebertaV2Model
+
+from ..base import RewardModel
+
+
+class DebertaRM(RewardModel):
+ """
+ Deberta Reward model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (DebertaV2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the LO-RA decomposition.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: str = None,
+ config: Optional[DebertaV2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = DebertaV2Model.from_pretrained(pretrained)
+ elif config is not None:
+ model = DebertaV2Model(config)
+ else:
+ model = DebertaV2Model(DebertaV2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.hidden_size, 1)
+ value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/generation.py b/applications/ChatGPT/chatgpt/models/generation.py
new file mode 100644
index 000000000000..eb30c36d0f84
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/generation.py
@@ -0,0 +1,146 @@
+from typing import Any, Callable, Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+try:
+ from transformers.generation_logits_process import (
+ LogitsProcessorList,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ )
+except ImportError:
+ from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
+
+
+def prepare_logits_processor(top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None) -> LogitsProcessorList:
+ processor_list = LogitsProcessorList()
+ if temperature is not None and temperature != 1.0:
+ processor_list.append(TemperatureLogitsWarper(temperature))
+ if top_k is not None and top_k != 0:
+ processor_list.append(TopKLogitsWarper(top_k))
+ if top_p is not None and top_p < 1.0:
+ processor_list.append(TopPLogitsWarper(top_p))
+ return processor_list
+
+
+def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ # consider DP
+ unfinished_sequences = unfinished_sequences.clone()
+ dist.all_reduce(unfinished_sequences)
+ return unfinished_sequences.max() == 0
+
+
+def sample(model: nn.Module,
+ input_ids: torch.Tensor,
+ max_length: int,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs) -> torch.Tensor:
+ if input_ids.size(1) >= max_length:
+ return input_ids
+
+ logits_processor = prepare_logits_processor(top_k, top_p, temperature)
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ for _ in range(input_ids.size(1), max_length):
+ model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
+ 'input_ids': input_ids
+ }
+ outputs = model(**model_inputs)
+
+ next_token_logits = outputs['logits'][:, -1, :]
+ # pre-process distribution
+ next_token_logits = logits_processor(input_ids, next_token_logits)
+ # sample
+ probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+ # update generated ids, model inputs for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ if update_model_kwargs_fn is not None:
+ model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+ # stop when each sentence is finished if early_stopping=True
+ if early_stopping and _is_sequence_finished(unfinished_sequences):
+ break
+
+ return input_ids
+
+
+def generate(model: nn.Module,
+ input_ids: torch.Tensor,
+ max_length: int,
+ num_beams: int = 1,
+ do_sample: bool = True,
+ early_stopping: bool = False,
+ eos_token_id: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ top_k: Optional[int] = None,
+ top_p: Optional[float] = None,
+ temperature: Optional[float] = None,
+ prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
+ update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
+ **model_kwargs) -> torch.Tensor:
+ """Generate token sequence. The returned sequence is input_ids + generated_tokens.
+
+ Args:
+ model (nn.Module): model
+ input_ids (torch.Tensor): input sequence
+ max_length (int): max length of the returned sequence
+ num_beams (int, optional): number of beams. Defaults to 1.
+ do_sample (bool, optional): whether to do sample. Defaults to True.
+ early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
+ eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
+ pad_token_id (Optional[int], optional): pad token id. Defaults to None.
+ top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
+ top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
+ temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
+ prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
+ update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
+ """
+ is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
+ is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
+ is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
+ if is_greedy_gen_mode:
+ # run greedy search
+ raise NotImplementedError
+ elif is_sample_gen_mode:
+ # run sample
+ return sample(model,
+ input_ids,
+ max_length,
+ early_stopping=early_stopping,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ top_k=top_k,
+ top_p=top_p,
+ temperature=temperature,
+ prepare_inputs_fn=prepare_inputs_fn,
+ update_model_kwargs_fn=update_model_kwargs_fn,
+ **model_kwargs)
+ elif is_beam_gen_mode:
+ raise NotImplementedError
+ else:
+ raise ValueError("Unsupported generation mode")
diff --git a/applications/ChatGPT/chatgpt/models/generation_utils.py b/applications/ChatGPT/chatgpt/models/generation_utils.py
new file mode 100644
index 000000000000..c7bc1b383fb9
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/generation_utils.py
@@ -0,0 +1,92 @@
+from typing import Optional
+
+import torch
+
+
+def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
+ token_type_ids = kwargs.get("token_type_ids", None)
+ # only last token for inputs_ids if past is defined in kwargs
+ if past:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ else:
+ position_ids = None
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
+
+
+def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
+ if "past_key_values" in outputs:
+ model_kwargs["past"] = outputs["past_key_values"]
+ else:
+ model_kwargs["past"] = None
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
+
+ return model_kwargs
+
+
+def opt_prepare_inputs_fn(input_ids: torch.Tensor,
+ past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs) -> dict:
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past:
+ input_ids = input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
+
+
+def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
+ past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ **kwargs) -> dict:
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_ids.shape)
+
+ if past:
+ input_ids = input_ids[:, -1:]
+ # first step, decoder_cached_states are empty
+ return {
+ "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
+ "attention_mask": attention_mask,
+ "past_key_values": past,
+ "use_cache": use_cache,
+ }
diff --git a/applications/ChatGPT/chatgpt/models/gpt/__init__.py b/applications/ChatGPT/chatgpt/models/gpt/__init__.py
new file mode 100644
index 000000000000..c6ae05113cc0
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/__init__.py
@@ -0,0 +1,6 @@
+from .gpt_actor import GPTActor
+from .gpt_critic import GPTCritic
+from .gpt_rm import GPTRM
+from .gpt_lm import GPTLM
+
+__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py
new file mode 100644
index 000000000000..6a53ad40b817
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+from ..base import Actor
+
+
+class GPTActor(Actor):
+ """
+ GPT Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (GPT2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the LoRa layer.
+ lora_train_bias (str): Bias training strategy for the LoRa layer.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = GPT2LMHeadModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = GPT2LMHeadModel(config)
+ else:
+ model = GPT2LMHeadModel(GPT2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py
new file mode 100644
index 000000000000..25bb1ed94de4
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py
@@ -0,0 +1,37 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+from ..base import Critic
+
+
+class GPTCritic(Critic):
+ """
+ GPT Critic model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (GPT2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the LO-RA decomposition.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = GPT2Model.from_pretrained(pretrained)
+ elif config is not None:
+ model = GPT2Model(config)
+ else:
+ model = GPT2Model(GPT2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.n_embd, 1)
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
new file mode 100644
index 000000000000..5740c80d3e77
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+from ..base import LM
+
+
+class GPTLM(LM):
+ """
+ GPT language model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (GPT2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the LoRa layer.
+ lora_train_bias (str): Bias training strategy for the LoRa layer.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = GPT2LMHeadModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = GPT2LMHeadModel(config)
+ else:
+ model = GPT2LMHeadModel(GPT2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
new file mode 100644
index 000000000000..19d673de6825
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/gpt/gpt_rm.py
@@ -0,0 +1,39 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+
+from ..base import RewardModel
+
+
+class GPTRM(RewardModel):
+ """
+ GPT Reward model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (GPT2Config): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the low-rank approximation.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[GPT2Config] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = GPT2Model.from_pretrained(pretrained)
+ elif config is not None:
+ model = GPT2Model(config)
+ else:
+ model = GPT2Model(GPT2Config())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ value_head = nn.Linear(model.config.n_embd, 1)
+ value_head.weight.data.normal_(mean=0.0, std=1/(model.config.n_embd + 1))
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/llama/__init__.py b/applications/ChatGPT/chatgpt/models/llama/__init__.py
new file mode 100644
index 000000000000..3edb51e14376
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/__init__.py
@@ -0,0 +1,6 @@
+from .llama_actor import LlamaActor
+from .llama_critic import LlamaCritic
+from .llama_rm import LlamaRM
+from .llama_lm import LlamaLM
+
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_actor.py b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
new file mode 100644
index 000000000000..2c7adb390d8b
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_actor.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Actor
+
+
+class LlamaActor(Actor):
+ """
+ Llama Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (LlamaConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+
+ if pretrained is not None:
+ model = LlamaForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = LlamaForCausalLM(config)
+ else:
+ model = LlamaForCausalLM(LlamaConfig())
+
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_critic.py b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
new file mode 100644
index 000000000000..cd565031e112
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_critic.py
@@ -0,0 +1,42 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
+
+from ..base import Critic
+
+
+class LlamaCritic(Critic):
+ """
+ Llama Critic model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (LlamaConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none',
+ **kwargs) -> None:
+
+ if pretrained is not None:
+ model = LlamaForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = LlamaForCausalLM(config)
+ else:
+ model = LlamaForCausalLM(LlamaConfig())
+
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ value_head = nn.Linear(model.config.hidden_size, 1)
+
+ super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_lm.py b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py
new file mode 100644
index 000000000000..c63077b1ac04
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_lm.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..base import LM
+
+
+class LlamaLM(LM):
+ """
+ Llama language model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (LlamaConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+
+ if pretrained is not None:
+ model = LlamaForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = LlamaForCausalLM(config)
+ else:
+ model = LlamaForCausalLM(LlamaConfig())
+
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/llama/llama_rm.py b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py
new file mode 100644
index 000000000000..81fa22d1969d
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/llama/llama_rm.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import LlamaConfig, LlamaForCausalLM
+
+from ..base import RewardModel
+
+
+class LlamaRM(RewardModel):
+ """
+ Llama Reward model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (LlamaConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): LoRA rank.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[LlamaConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+
+ if pretrained is not None:
+ model = LlamaForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = LlamaForCausalLM(config)
+ else:
+ model = LlamaForCausalLM(LlamaConfig())
+
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ value_head = nn.Linear(model.config.hidden_size, 1)
+ value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
+
+ super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/lora.py b/applications/ChatGPT/chatgpt/models/lora.py
new file mode 100644
index 000000000000..9c19f472d726
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/lora.py
@@ -0,0 +1,130 @@
+import math
+from typing import Optional
+
+import loralib as lora
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LoraLinear(lora.LoRALayer, nn.Module):
+ """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
+ """
+
+ def __init__(
+ self,
+ weight: nn.Parameter,
+ bias: Optional[nn.Parameter],
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ ):
+ nn.Module.__init__(self)
+ lora.LoRALayer.__init__(self,
+ r=r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+ self.weight = weight
+ self.bias = bias
+
+ out_features, in_features = weight.shape
+ self.in_features = in_features
+ self.out_features = out_features
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.T
+
+ def reset_parameters(self):
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ def train(self, mode: bool = True):
+
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+
+ nn.Module.train(self, mode)
+ if self.merge_weights and self.merged:
+ # Make sure that the weights are not merged
+ if self.r > 0:
+ self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ self.merged = False
+
+ def eval(self):
+
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+
+ nn.Module.eval(self)
+ if self.merge_weights and not self.merged:
+ # Merge the weights and mark it
+ if self.r > 0:
+ self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ delattr(self, 'lora_A')
+ delattr(self, 'lora_B')
+ self.merged = True
+
+ def forward(self, x: torch.Tensor):
+
+ def T(w):
+ return w.T if self.fan_in_fan_out else w
+
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ if self.r > 0:
+ result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
+def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
+ assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
+ lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
+ return lora_linear
+
+
+def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
+ for name, child in module.named_children():
+ if isinstance(child, nn.Linear):
+ setattr(module, name, lora_linear_wrapper(child, lora_rank))
+ else:
+ convert_to_lora_recursively(child, lora_rank)
+
+
+class LoRAModule(nn.Module):
+ """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
+ This calss will convert all torch.nn.Linear layer to LoraLinear layer.
+
+ Args:
+ lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
+ lora_train_bias (str, optional): Whether LoRA train biases.
+ 'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
+ Defaults to 'none'.
+ """
+
+ def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
+ super().__init__()
+ self.lora_rank = lora_rank
+ self.lora_train_bias = lora_train_bias
+
+ def convert_to_lora(self) -> None:
+ if self.lora_rank <= 0:
+ return
+ convert_to_lora_recursively(self, self.lora_rank)
+ lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/loss.py b/applications/ChatGPT/chatgpt/models/loss.py
new file mode 100644
index 000000000000..c5b1ccc93228
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/loss.py
@@ -0,0 +1,115 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .utils import masked_mean
+
+
+class GPTLMLoss(nn.Module):
+ """
+ GPT Language Model Loss
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss()
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+class PolicyLoss(nn.Module):
+ """
+ Policy Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.2) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+
+ def forward(self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ ratio = (log_probs - old_log_probs).exp()
+ surr1 = ratio * advantages
+ surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
+ loss = -torch.min(surr1, surr2)
+ if action_mask is not None:
+ loss = masked_mean(loss, action_mask)
+ loss = loss.mean()
+ return loss
+
+
+class ValueLoss(nn.Module):
+ """
+ Value Loss for PPO
+ """
+
+ def __init__(self, clip_eps: float = 0.4) -> None:
+ super().__init__()
+ self.clip_eps = clip_eps
+
+ def forward(self,
+ values: torch.Tensor,
+ old_values: torch.Tensor,
+ reward: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
+ surr1 = (values_clipped - reward)**2
+ surr2 = (values - reward)**2
+ loss = torch.max(surr1, surr2)
+ loss = loss.mean()
+ return loss
+
+
+class PPOPtxActorLoss(nn.Module):
+ """
+ To Do:
+
+ PPO-ptx Actor Loss
+ """
+
+ def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
+ super().__init__()
+ self.pretrain_coef = pretrain_coef
+ self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
+ self.pretrain_loss_fn = pretrain_loss_fn
+
+ def forward(self,
+ log_probs: torch.Tensor,
+ old_log_probs: torch.Tensor,
+ advantages: torch.Tensor,
+ lm_logits: torch.Tensor,
+ lm_input_ids: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
+ lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
+ return policy_loss + self.pretrain_coef * lm_loss
+
+
+class LogSigLoss(nn.Module):
+ """
+ Pairwise Loss for Reward Model
+ Details: https://arxiv.org/abs/2203.02155
+ """
+ def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
+ probs = torch.sigmoid(chosen_reward - reject_reward)
+ log_probs = torch.log(probs)
+ loss = -log_probs.mean()
+ return loss
+
+
+class LogExpLoss(nn.Module):
+ """
+ Pairwise Loss for Reward Model
+ Details: https://arxiv.org/abs/2204.05862
+ """
+ def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
+ loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
+ return loss
diff --git a/applications/ChatGPT/chatgpt/models/opt/__init__.py b/applications/ChatGPT/chatgpt/models/opt/__init__.py
new file mode 100644
index 000000000000..fccec3bdff99
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/__init__.py
@@ -0,0 +1,6 @@
+from .opt_actor import OPTActor
+from .opt_critic import OPTCritic
+from .opt_rm import OPTRM
+from .opt_lm import OPTLM
+
+__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_actor.py b/applications/ChatGPT/chatgpt/models/opt/opt_actor.py
new file mode 100644
index 000000000000..c14e4377ffb2
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_actor.py
@@ -0,0 +1,35 @@
+from typing import Optional
+
+from transformers.models.opt.configuration_opt import OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM
+
+from ..base import Actor
+
+
+class OPTActor(Actor):
+ """
+ OPT Actor model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (OPTConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the low-rank approximation.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = OPTForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = OPTForCausalLM(config)
+ else:
+ model = OPTForCausalLM(OPTConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_critic.py b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py
new file mode 100644
index 000000000000..fcfebd8a8b03
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_critic.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers.models.opt.configuration_opt import OPTConfig
+from transformers.models.opt.modeling_opt import OPTModel
+
+from ..base import Critic
+
+
+class OPTCritic(Critic):
+ """
+ OPT Critic model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (OPTConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the low-rank approximation.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none',
+ **kwargs) -> None:
+ if pretrained is not None:
+ model = OPTModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = OPTModel(config)
+ else:
+ model = OPTModel(OPTConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
+ super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_lm.py b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py
new file mode 100644
index 000000000000..35bfe198a225
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_lm.py
@@ -0,0 +1,36 @@
+from typing import Optional
+
+from transformers.models.opt.configuration_opt import OPTConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM
+
+from ..base import LM
+
+
+class OPTLM(LM):
+ """
+ OPT language model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (OPTConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the low-rank approximation.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = OPTForCausalLM.from_pretrained(pretrained)
+ elif config is not None:
+ model = OPTForCausalLM(config)
+ else:
+ model = OPTForCausalLM(OPTConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+ super().__init__(model, lora_rank, lora_train_bias)
+
diff --git a/applications/ChatGPT/chatgpt/models/opt/opt_rm.py b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py
new file mode 100644
index 000000000000..ef7f0fb16fd1
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/opt/opt_rm.py
@@ -0,0 +1,38 @@
+from typing import Optional
+
+import torch.nn as nn
+from transformers import OPTConfig, OPTModel
+
+from ..base import RewardModel
+
+
+class OPTRM(RewardModel):
+ """
+ OPT Reward model.
+
+ Args:
+ pretrained (str): Pretrained model name or path.
+ config (OPTConfig): Model config.
+ checkpoint (bool): Enable gradient checkpointing.
+ lora_rank (int): Rank of the low-rank approximation.
+ lora_train_bias (str): LoRA bias training mode.
+ """
+
+ def __init__(self,
+ pretrained: Optional[str] = None,
+ config: Optional[OPTConfig] = None,
+ checkpoint: bool = False,
+ lora_rank: int = 0,
+ lora_train_bias: str = 'none') -> None:
+ if pretrained is not None:
+ model = OPTModel.from_pretrained(pretrained)
+ elif config is not None:
+ model = OPTModel(config)
+ else:
+ model = OPTModel(OPTConfig())
+ if checkpoint:
+ model.gradient_checkpointing_enable()
+
+ value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
+ value_head.weight.data.normal_(mean=0.0, std=1/(model.config.word_embed_proj_dim + 1))
+ super().__init__(model, value_head, lora_rank, lora_train_bias)
diff --git a/applications/ChatGPT/chatgpt/models/utils.py b/applications/ChatGPT/chatgpt/models/utils.py
new file mode 100644
index 000000000000..0ff13181fcd2
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/models/utils.py
@@ -0,0 +1,92 @@
+from typing import Optional, Union
+
+import loralib as lora
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def compute_approx_kl(log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Compute the approximate KL divergence between two distributions.
+ Schulman blog: http://joschu.net/blog/kl-approx.html
+
+ Args:
+ log_probs: Log probabilities of the new distribution.
+ log_probs_base: Log probabilities of the base distribution.
+ action_mask: Mask for actions.
+ """
+
+ log_ratio = log_probs - log_probs_base
+ approx_kl = (log_ratio.exp() - 1) - log_ratio
+ if action_mask is not None:
+ approx_kl = masked_mean(approx_kl, action_mask, dim=1)
+ return approx_kl
+ approx_kl = approx_kl.mean(dim=1)
+ return approx_kl
+
+
+def compute_reward(r: Union[torch.Tensor, float],
+ kl_coef: float,
+ log_probs: torch.Tensor,
+ log_probs_base: torch.Tensor,
+ action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ if kl_coef <= 0.0:
+ return r
+ kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
+ reward = r - kl_coef * kl
+ return reward
+
+
+def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+ log_probs = F.log_softmax(logits, dim=-1)
+ log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
+ return log_probs_labels.squeeze(-1)
+
+
+def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
+ tensor = tensor * mask
+ tensor = tensor.sum(dim=dim)
+ mask_sum = mask.sum(dim=dim)
+ mean = tensor / (mask_sum + 1e-8)
+ return mean
+
+
+def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
+ tensor = tensor * mask
+ mean = masked_mean(tensor, mask, dim=dim)
+ mean_centered = tensor - mean
+ var = masked_mean(mean_centered**2, mask, dim=dim)
+ return mean_centered * var.clamp(min=eps).rsqrt()
+
+
+def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
+ mean = tensor.mean(dim)
+ mean_centered = tensor - mean
+ var = (mean_centered**2).mean(dim)
+ norm = mean_centered * var.clamp(min=eps).rsqrt()
+ return norm
+
+
+def convert_to_lora(model: nn.Module,
+ input_size: int,
+ output_size: int,
+ lora_rank: int = 16,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False,
+ merge_weights: bool = True):
+ if lora_rank > min(input_size, output_size):
+ raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
+
+ for name, module in model.named_modules():
+ if isinstance(module, nn.Linear):
+ module._modules[name] = lora.Linear(input_size,
+ output_size,
+ r=lora_rank,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ fan_in_fan_out=fan_in_fan_out,
+ merge_weights=merge_weights)
diff --git a/applications/ChatGPT/chatgpt/replay_buffer/__init__.py b/applications/ChatGPT/chatgpt/replay_buffer/__init__.py
new file mode 100644
index 000000000000..1ebf60382913
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/replay_buffer/__init__.py
@@ -0,0 +1,4 @@
+from .base import ReplayBuffer
+from .naive import NaiveReplayBuffer
+
+__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
diff --git a/applications/ChatGPT/chatgpt/replay_buffer/base.py b/applications/ChatGPT/chatgpt/replay_buffer/base.py
new file mode 100644
index 000000000000..5036b09045c4
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/replay_buffer/base.py
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+from chatgpt.experience_maker.base import Experience
+
+
+class ReplayBuffer(ABC):
+ """Replay buffer base class. It stores experience.
+
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ """
+
+ def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
+ super().__init__()
+ self.sample_batch_size = sample_batch_size
+ # limit <= 0 means unlimited
+ self.limit = limit
+
+ @abstractmethod
+ def append(self, experience: Experience) -> None:
+ pass
+
+ @abstractmethod
+ def clear(self) -> None:
+ pass
+
+ @abstractmethod
+ def sample(self) -> Experience:
+ pass
+
+ @abstractmethod
+ def __len__(self) -> int:
+ pass
+
+ @abstractmethod
+ def __getitem__(self, idx: int) -> Any:
+ pass
+
+ @abstractmethod
+ def collate_fn(self, batch: Any) -> Experience:
+ pass
diff --git a/applications/ChatGPT/chatgpt/replay_buffer/naive.py b/applications/ChatGPT/chatgpt/replay_buffer/naive.py
new file mode 100644
index 000000000000..3fc53da65bff
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/replay_buffer/naive.py
@@ -0,0 +1,57 @@
+import random
+from typing import List
+
+import torch
+from chatgpt.experience_maker.base import Experience
+
+from .base import ReplayBuffer
+from .utils import BufferItem, make_experience_batch, split_experience_batch
+
+
+class NaiveReplayBuffer(ReplayBuffer):
+ """Naive replay buffer class. It stores experience.
+
+ Args:
+ sample_batch_size (int): Batch size when sampling.
+ limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
+ cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
+ """
+
+ def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
+ super().__init__(sample_batch_size, limit)
+ self.cpu_offload = cpu_offload
+ self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
+ # TODO(ver217): add prefetch
+ self.items: List[BufferItem] = []
+
+ @torch.no_grad()
+ def append(self, experience: Experience) -> None:
+ if self.cpu_offload:
+ experience.to_device(torch.device('cpu'))
+ items = split_experience_batch(experience)
+ self.items.extend(items)
+ if self.limit > 0:
+ samples_to_remove = len(self.items) - self.limit
+ if samples_to_remove > 0:
+ self.items = self.items[samples_to_remove:]
+
+ def clear(self) -> None:
+ self.items.clear()
+
+ @torch.no_grad()
+ def sample(self) -> Experience:
+ items = random.sample(self.items, self.sample_batch_size)
+ experience = make_experience_batch(items)
+ if self.cpu_offload:
+ experience.to_device(self.target_device)
+ return experience
+
+ def __len__(self) -> int:
+ return len(self.items)
+
+ def __getitem__(self, idx: int) -> BufferItem:
+ return self.items[idx]
+
+ def collate_fn(self, batch) -> Experience:
+ experience = make_experience_batch(batch)
+ return experience
diff --git a/applications/ChatGPT/chatgpt/replay_buffer/utils.py b/applications/ChatGPT/chatgpt/replay_buffer/utils.py
new file mode 100644
index 000000000000..752f16704771
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/replay_buffer/utils.py
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from chatgpt.experience_maker.base import Experience
+
+
+@dataclass
+class BufferItem:
+ """BufferItem is an item of experience data.
+
+ Shapes of each tensor:
+ sequences: (S)
+ action_log_probs: (A)
+ values: (1)
+ reward: (1)
+ advatanges: (1)
+ attention_mask: (S)
+ action_mask: (A)
+
+ "A" is the number of actions.
+ """
+ sequences: torch.Tensor
+ action_log_probs: torch.Tensor
+ values: torch.Tensor
+ reward: torch.Tensor
+ advantages: torch.Tensor
+ attention_mask: Optional[torch.LongTensor]
+ action_mask: Optional[torch.BoolTensor]
+
+
+def split_experience_batch(experience: Experience) -> List[BufferItem]:
+ batch_size = experience.sequences.size(0)
+ batch_kwargs = [{} for _ in range(batch_size)]
+ keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+ for key in keys:
+ value = getattr(experience, key)
+ if isinstance(value, torch.Tensor):
+ vals = torch.unbind(value)
+ else:
+ # None
+ vals = [value for _ in range(batch_size)]
+ assert batch_size == len(vals)
+ for i, v in enumerate(vals):
+ batch_kwargs[i][key] = v
+ items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
+ return items
+
+
+def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
+ assert side in ('left', 'right')
+ max_len = max(seq.size(0) for seq in sequences)
+ padded_sequences = []
+ for seq in sequences:
+ pad_len = max_len - seq.size(0)
+ padding = (pad_len, 0) if side == 'left' else (0, pad_len)
+ padded_sequences.append(F.pad(seq, padding))
+ return torch.stack(padded_sequences, dim=0)
+
+
+def make_experience_batch(items: List[BufferItem]) -> Experience:
+ kwargs = {}
+ to_pad_keys = set(('action_log_probs', 'action_mask'))
+ keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
+ for key in keys:
+ vals = [getattr(item, key) for item in items]
+ if key in to_pad_keys:
+ batch_data = zero_pad_sequences(vals)
+ else:
+ batch_data = torch.stack(vals, dim=0)
+ kwargs[key] = batch_data
+ return Experience(**kwargs)
diff --git a/applications/ChatGPT/chatgpt/trainer/__init__.py b/applications/ChatGPT/chatgpt/trainer/__init__.py
new file mode 100644
index 000000000000..525b57bf21d3
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/__init__.py
@@ -0,0 +1,6 @@
+from .base import Trainer
+from .ppo import PPOTrainer
+from .rm import RewardModelTrainer
+from .sft import SFTTrainer
+
+__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
diff --git a/applications/ChatGPT/chatgpt/trainer/base.py b/applications/ChatGPT/chatgpt/trainer/base.py
new file mode 100644
index 000000000000..a2419a35b6cd
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/base.py
@@ -0,0 +1,162 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from chatgpt.experience_maker import Experience, ExperienceMaker
+from chatgpt.replay_buffer import ReplayBuffer
+from torch import Tensor
+from torch.utils.data import DistributedSampler
+from tqdm import tqdm
+
+from .callbacks import Callback
+from .strategies import Strategy
+from .utils import is_rank_0
+
+
+class Trainer(ABC):
+ """
+ Base class for rlhf trainers.
+
+ Args:
+ strategy (Strategy):the strategy to use for training
+ experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
+ replay_buffer (ReplayBuffer): the replay buffer to use for training
+ experience_batch_size (int, defaults to 8): the batch size to use for experience generation
+ max_epochs (int, defaults to 1): the number of epochs of training process
+ tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
+ sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
+ data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+ callbacks (List[Callback], defaults to []): the callbacks to call during training process
+ generate_kwargs (dict, optional): the kwargs to use while model generating
+ """
+
+ def __init__(self,
+ strategy: Strategy,
+ experience_maker: ExperienceMaker,
+ replay_buffer: ReplayBuffer,
+ experience_batch_size: int = 8,
+ max_epochs: int = 1,
+ tokenizer: Optional[Callable[[Any], dict]] = None,
+ sample_replay_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[Callback] = [],
+ **generate_kwargs) -> None:
+ super().__init__()
+ self.strategy = strategy
+ self.experience_maker = experience_maker
+ self.replay_buffer = replay_buffer
+ self.experience_batch_size = experience_batch_size
+ self.max_epochs = max_epochs
+ self.tokenizer = tokenizer
+ self.generate_kwargs = generate_kwargs
+ self.sample_replay_buffer = sample_replay_buffer
+ self.dataloader_pin_memory = dataloader_pin_memory
+ self.callbacks = callbacks
+
+ @abstractmethod
+ def training_step(self, experience: Experience) -> Dict[str, Any]:
+ pass
+
+ def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
+ if isinstance(inputs, Tensor):
+ return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
+ elif isinstance(inputs, dict):
+ return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
+ else:
+ raise ValueError(f'Unsupported input type "{type(inputs)}"')
+
+ def _sample_prompts(self, prompts) -> list:
+ indices = list(range(len(prompts)))
+ sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
+ return [prompts[i] for i in sampled_indices]
+
+ def _learn(self):
+ # replay buffer may be empty at first, we should rebuild at each training
+ if not self.sample_replay_buffer:
+ dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
+ device = torch.cuda.current_device()
+ if self.sample_replay_buffer:
+ pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
+ for _ in pbar:
+ experience = self.replay_buffer.sample()
+ metrics = self.training_step(experience)
+ pbar.set_postfix(metrics)
+ else:
+ for epoch in range(self.max_epochs):
+ self._on_learn_epoch_start(epoch)
+ if isinstance(dataloader.sampler, DistributedSampler):
+ dataloader.sampler.set_epoch(epoch)
+ pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
+ for experience in pbar:
+ self._on_learn_batch_start()
+ experience.to_device(device)
+ metrics = self.training_step(experience)
+ self._on_learn_batch_end(metrics, experience)
+ pbar.set_postfix(metrics)
+ self._on_learn_epoch_end(epoch)
+
+ def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
+ time = 0
+ sampler = self.strategy.setup_sampler(prompts)
+ self._on_fit_start()
+ for episode in range(num_episodes):
+ self._on_episode_start(episode)
+ for timestep in tqdm(range(max_timesteps),
+ desc=f'Episode [{episode+1}/{num_episodes}]',
+ disable=not is_rank_0()):
+ time += 1
+ rand_prompts = sampler.sample(self.experience_batch_size)
+ if self.tokenizer is not None:
+ inputs = self.tokenizer(rand_prompts)
+ else:
+ inputs = rand_prompts
+ self._on_make_experience_start()
+ experience = self._make_experience(inputs)
+ self._on_make_experience_end(experience)
+ self.replay_buffer.append(experience)
+ if time % update_timesteps == 0:
+ self._learn()
+ self.replay_buffer.clear()
+ self._on_episode_end(episode)
+ self._on_fit_end()
+
+ # TODO(ver217): maybe simplify these code using context
+ def _on_fit_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_fit_start()
+
+ def _on_fit_end(self) -> None:
+ for callback in self.callbacks:
+ callback.on_fit_end()
+
+ def _on_episode_start(self, episode: int) -> None:
+ for callback in self.callbacks:
+ callback.on_episode_start(episode)
+
+ def _on_episode_end(self, episode: int) -> None:
+ for callback in self.callbacks:
+ callback.on_episode_end(episode)
+
+ def _on_make_experience_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_make_experience_start()
+
+ def _on_make_experience_end(self, experience: Experience) -> None:
+ for callback in self.callbacks:
+ callback.on_make_experience_end(experience)
+
+ def _on_learn_epoch_start(self, epoch: int) -> None:
+ for callback in self.callbacks:
+ callback.on_learn_epoch_start(epoch)
+
+ def _on_learn_epoch_end(self, epoch: int) -> None:
+ for callback in self.callbacks:
+ callback.on_learn_epoch_end(epoch)
+
+ def _on_learn_batch_start(self) -> None:
+ for callback in self.callbacks:
+ callback.on_learn_batch_start()
+
+ def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ for callback in self.callbacks:
+ callback.on_learn_batch_end(metrics, experience)
diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py
new file mode 100644
index 000000000000..9ed0ee6f7640
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py
@@ -0,0 +1,5 @@
+from .base import Callback
+from .performance_evaluator import PerformanceEvaluator
+from .save_checkpoint import SaveCheckpoint
+
+__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/base.py b/applications/ChatGPT/chatgpt/trainer/callbacks/base.py
new file mode 100644
index 000000000000..0b01345f7872
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/base.py
@@ -0,0 +1,39 @@
+from abc import ABC
+
+from chatgpt.experience_maker import Experience
+
+
+class Callback(ABC):
+ """
+ Base callback class. It defines the interface for callbacks.
+ """
+
+ def on_fit_start(self) -> None:
+ pass
+
+ def on_fit_end(self) -> None:
+ pass
+
+ def on_episode_start(self, episode: int) -> None:
+ pass
+
+ def on_episode_end(self, episode: int) -> None:
+ pass
+
+ def on_make_experience_start(self) -> None:
+ pass
+
+ def on_make_experience_end(self, experience: Experience) -> None:
+ pass
+
+ def on_learn_epoch_start(self, epoch: int) -> None:
+ pass
+
+ def on_learn_epoch_end(self, epoch: int) -> None:
+ pass
+
+ def on_learn_batch_start(self) -> None:
+ pass
+
+ def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ pass
diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py b/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py
new file mode 100644
index 000000000000..faa38af1b84e
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py
@@ -0,0 +1,133 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from chatgpt.experience_maker import Experience
+
+from .base import Callback
+
+
+def get_world_size() -> int:
+ if dist.is_initialized():
+ return dist.get_world_size()
+ return 1
+
+
+def print_rank_0(*args, **kwargs) -> None:
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ print(*args, **kwargs)
+
+
+@torch.no_grad()
+def all_reduce_mean(x: float, world_size: int) -> float:
+ if world_size == 1:
+ return x
+ tensor = torch.tensor([x], device=torch.cuda.current_device())
+ dist.all_reduce(tensor)
+ tensor = tensor / world_size
+ return tensor.item()
+
+
+class PerformanceEvaluator(Callback):
+ """
+ Callback for valuate the performance of the model.
+ Args:
+ actor_num_params: The number of parameters of the actor model.
+ critic_num_params: The number of parameters of the critic model.
+ initial_model_num_params: The number of parameters of the initial model.
+ reward_model_num_params: The number of parameters of the reward model.
+ enable_grad_checkpoint: Whether to enable gradient checkpointing.
+ ignore_episodes: The number of episodes to ignore when calculating the performance.
+ """
+
+ def __init__(self,
+ actor_num_params: int,
+ critic_num_params: int,
+ initial_model_num_params: int,
+ reward_model_num_params: int,
+ enable_grad_checkpoint: bool = False,
+ ignore_episodes: int = 0) -> None:
+ super().__init__()
+ self.world_size = get_world_size()
+ self.actor_num_params = actor_num_params
+ self.critic_num_params = critic_num_params
+ self.initial_model_num_params = initial_model_num_params
+ self.reward_model_num_params = reward_model_num_params
+ self.enable_grad_checkpoint = enable_grad_checkpoint
+ self.ignore_episodes = ignore_episodes
+ self.disable: bool = False
+
+ self.make_experience_duration: float = 0.
+ self.make_experience_start_time: Optional[float] = None
+ self.make_experience_num_samples: int = 0
+ self.make_experience_flop: int = 0
+ self.learn_duration: float = 0.
+ self.learn_start_time: Optional[float] = None
+ self.learn_num_samples: int = 0
+ self.learn_flop: int = 0
+
+ def on_episode_start(self, episode: int) -> None:
+ self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
+
+ def on_make_experience_start(self) -> None:
+ if self.disable:
+ return
+ self.make_experience_start_time = time()
+
+ def on_make_experience_end(self, experience: Experience) -> None:
+ if self.disable:
+ return
+ self.make_experience_duration += time() - self.make_experience_start_time
+
+ batch_size, seq_len = experience.sequences.shape
+
+ self.make_experience_num_samples += batch_size
+
+ # actor generate
+ num_actions = experience.action_mask.size(1)
+ input_len = seq_len - num_actions
+ total_seq_len = (input_len + seq_len - 1) * num_actions / 2
+ self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
+ # actor forward
+ self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
+ # critic forward
+ self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
+ # initial model forward
+ self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
+ # reward model forward
+ self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
+
+ def on_learn_batch_start(self) -> None:
+ if self.disable:
+ return
+ self.learn_start_time = time()
+
+ def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+ if self.disable:
+ return
+ self.learn_duration += time() - self.learn_start_time
+
+ batch_size, seq_len = experience.sequences.shape
+
+ self.learn_num_samples += batch_size
+
+ # actor forward-backward, 3 means forward(1) + backward(2)
+ self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+ # critic foward-backward
+ self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
+
+ def on_fit_end(self) -> None:
+ avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
+ avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
+
+ avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
+ avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
+
+ avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
+ avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
+
+ print_rank_0(
+ f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
+ )
+ print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
diff --git a/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py b/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py
new file mode 100644
index 000000000000..8f2beb12db22
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/callbacks/save_checkpoint.py
@@ -0,0 +1,75 @@
+import os
+
+import torch.distributed as dist
+from chatgpt.trainer.strategies import ColossalAIStrategy, Strategy
+from chatgpt.trainer.utils import is_rank_0
+from torch import nn
+from torch.optim import Optimizer
+
+from .base import Callback
+
+
+class SaveCheckpoint(Callback):
+ """
+ The callback for saving checkpoint for chatgpt.
+
+ Only support saving actor and critic model.
+ A typical architecture of the saved checkpoint would be:
+ - checkpoint
+ - episode_x
+ - actor.pt
+ - actor-optim-rank-0.pt
+ - actor-optim-rank-1.pt
+ - critic.pt
+ - critic-optim-rank-0.pt
+ - critic-optim-rank-1.pt
+ - ...
+
+ Args:
+ path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
+ interval(int): the interval episode of saving checkpoint
+ strategy(Strategy): the strategy used to train
+ actor(nn.Module): the actor model
+ critic(nn.Module): the critic model
+ actor_optim(Optimizer): the optimizer of actor
+ critic_optim(Optimizer): the optimizer of critic
+
+ """
+
+ def __init__(self,
+ path: str,
+ interval: int,
+ strategy: Strategy,
+ actor: nn.Module = None,
+ critic: nn.Module = None,
+ actor_optim: Optimizer = None,
+ critic_optim: Optimizer = None) -> None:
+ super().__init__()
+ self.path = os.path.join(path, 'checkpoint')
+ self.interval = interval
+ self.strategy = strategy
+ self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+
+ def on_episode_end(self, episode: int) -> None:
+ if (episode + 1) % self.interval != 0:
+ return
+ base_path = os.path.join(self.path, f'episode_{episode}')
+ if not os.path.exists(base_path):
+ os.makedirs(base_path)
+
+ for model in self.model_dict.keys():
+
+ # save model
+ if self.model_dict[model][0] is None:
+ # saving only optimizer states is meaningless, so it would be skipped
+ continue
+ model_path = os.path.join(base_path, f'{model}.pt')
+ self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
+
+ # save optimizer
+ if self.model_dict[model][1] is None:
+ continue
+ only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+ rank = 0 if is_rank_0() else dist.get_rank()
+ optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+ self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
diff --git a/applications/ChatGPT/chatgpt/trainer/ppo.py b/applications/ChatGPT/chatgpt/trainer/ppo.py
new file mode 100644
index 000000000000..dacab4784039
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/ppo.py
@@ -0,0 +1,116 @@
+from typing import Any, Callable, Dict, List, Optional
+
+import torch.nn as nn
+from chatgpt.experience_maker import Experience, NaiveExperienceMaker
+from chatgpt.models.base import Actor, Critic
+from chatgpt.models.generation_utils import update_model_kwargs_fn
+from chatgpt.models.loss import PolicyLoss, ValueLoss
+from chatgpt.replay_buffer import NaiveReplayBuffer
+from torch.optim import Optimizer
+
+from .base import Trainer
+from .callbacks import Callback
+from .strategies import Strategy
+
+
+class PPOTrainer(Trainer):
+ """
+ Trainer for PPO algorithm.
+
+ Args:
+ strategy (Strategy): the strategy to use for training
+ actor (Actor): the actor model in ppo algorithm
+ critic (Critic): the critic model in ppo algorithm
+ reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
+ initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
+ actor_optim (Optimizer): the optimizer to use for actor model
+ critic_optim (Optimizer): the optimizer to use for critic model
+ kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
+ train_batch_size (int, defaults to 8): the batch size to use for training
+ buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
+ buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
+ eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
+ value_clip (float, defaults to 0.4): the clip coefficient of value loss
+ experience_batch_size (int, defaults to 8): the batch size to use for experience generation
+ max_epochs (int, defaults to 1): the number of epochs of training process
+ tokenier (Callable, optional): the tokenizer to use for tokenizing the input
+ sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
+ dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
+ callbacks (List[Callback], defaults to []): the callbacks to call during training process
+ generate_kwargs (dict, optional): the kwargs to use while model generating
+ """
+
+ def __init__(self,
+ strategy: Strategy,
+ actor: Actor,
+ critic: Critic,
+ reward_model: nn.Module,
+ initial_model: Actor,
+ actor_optim: Optimizer,
+ critic_optim: Optimizer,
+ kl_coef: float = 0.1,
+ train_batch_size: int = 8,
+ buffer_limit: int = 0,
+ buffer_cpu_offload: bool = True,
+ eps_clip: float = 0.2,
+ value_clip: float = 0.4,
+ experience_batch_size: int = 8,
+ max_epochs: int = 1,
+ tokenizer: Optional[Callable[[Any], dict]] = None,
+ sample_replay_buffer: bool = False,
+ dataloader_pin_memory: bool = True,
+ callbacks: List[Callback] = [],
+ **generate_kwargs) -> None:
+ experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
+ replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+ generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+ super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
+ sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
+ self.actor = actor
+ self.critic = critic
+
+ self.actor_loss_fn = PolicyLoss(eps_clip)
+ self.critic_loss_fn = ValueLoss(value_clip)
+
+ self.actor_optim = actor_optim
+ self.critic_optim = critic_optim
+
+ def training_step(self, experience: Experience) -> Dict[str, float]:
+ self.actor.train()
+ self.critic.train()
+
+ num_actions = experience.action_mask.size(1)
+ action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
+ actor_loss = self.actor_loss_fn(action_log_probs,
+ experience.action_log_probs,
+ experience.advantages,
+ action_mask=experience.action_mask)
+ self.strategy.backward(actor_loss, self.actor, self.actor_optim)
+ self.strategy.optimizer_step(self.actor_optim)
+ self.actor_optim.zero_grad()
+
+ values = self.critic(experience.sequences,
+ action_mask=experience.action_mask,
+ attention_mask=experience.attention_mask)
+ critic_loss = self.critic_loss_fn(values,
+ experience.values,
+ experience.reward,
+ action_mask=experience.action_mask)
+ self.strategy.backward(critic_loss, self.critic, self.critic_optim)
+ self.strategy.optimizer_step(self.critic_optim)
+ self.critic_optim.zero_grad()
+
+ return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
+
+
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
+ origin_model = strategy._unwrap_actor(actor)
+ new_kwargs = {**generate_kwargs}
+ # use huggingface models method directly
+ if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
+ new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
+
+ if 'update_model_kwargs_fn' not in generate_kwargs:
+ new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
+
+ return new_kwargs
diff --git a/applications/ChatGPT/chatgpt/trainer/rm.py b/applications/ChatGPT/chatgpt/trainer/rm.py
new file mode 100644
index 000000000000..7fa87a64968b
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/rm.py
@@ -0,0 +1,120 @@
+from abc import ABC
+import pandas as pd
+import loralib as lora
+import torch
+from datetime import datetime
+from torch.optim import Optimizer, lr_scheduler
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+from .strategies import Strategy
+from .utils import is_rank_0
+
+
+class RewardModelTrainer(ABC):
+ """
+ Trainer to use while training reward model.
+
+ Args:
+ model (torch.nn.Module): the model to train
+ strategy (Strategy): the strategy to use for training
+ optim(Optimizer): the optimizer to use for training
+ loss_fn (callable): the loss function to use for training
+ train_dataset (Dataset): the dataset to use for training
+ valid_dataset (Dataset): the dataset to use for validation
+ eval_dataset (Dataset): the dataset to use for evaluation
+ batch_size (int, defaults to 1): the batch size while training
+ max_epochs (int, defaults to 2): the number of epochs to train
+ """
+
+ def __init__(
+ self,
+ model,
+ strategy: Strategy,
+ optim: Optimizer,
+ loss_fn,
+ train_dataset: Dataset,
+ valid_dataset: Dataset,
+ eval_dataset: Dataset,
+ batch_size: int = 1,
+ max_epochs: int = 1,
+ ) -> None:
+ super().__init__()
+ self.strategy = strategy
+ self.epochs = max_epochs
+ self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+ self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
+ self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
+
+ self.model = strategy.setup_model(model)
+ self.loss_fn = loss_fn
+ self.optimizer = strategy.setup_optimizer(optim, self.model)
+ self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__()//100)
+
+
+ def eval_acc(self, dataloader):
+ dist = 0
+ on = 0
+ cnt = 0
+ self.model.eval()
+ with torch.no_grad():
+ for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
+ chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+ c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+ reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+ r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+ chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+ reject_reward = self.model(reject_ids, attention_mask=r_mask)
+ for i in range(len(chosen_reward)):
+ cnt += 1
+ if chosen_reward[i] > reject_reward[i]:
+ on += 1
+ dist += (chosen_reward - reject_reward).mean().item()
+ dist_mean = dist / len(dataloader)
+ acc = on / cnt
+ self.model.train()
+ return dist_mean, acc
+
+
+ def fit(self):
+ time = datetime.now()
+ epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
+ for epoch in range(self.epochs):
+ step_bar = tqdm(range(self.train_dataloader.__len__()),
+ desc='Train step of epoch %d' % epoch,
+ disable=not is_rank_0())
+ # train
+ self.model.train()
+ cnt = 0
+ acc = 0
+ dist = 0
+ for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
+ chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+ c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+ reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+ r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+ chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+ reject_reward = self.model(reject_ids, attention_mask=r_mask)
+ loss = self.loss_fn(chosen_reward, reject_reward)
+ self.strategy.backward(loss, self.model, self.optimizer)
+ self.strategy.optimizer_step(self.optimizer)
+ self.optimizer.zero_grad()
+ cnt += 1
+ if cnt == 100:
+ self.scheduler.step()
+ dist, acc = self.eval_acc(self.valid_dataloader)
+ cnt = 0
+ if is_rank_0():
+ log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
+ log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
+ step_bar.update()
+ step_bar.set_postfix({'dist': dist, 'acc': acc})
+
+ # eval
+ dist, acc = self.eval_acc(self.eval_dataloader)
+ if is_rank_0():
+ log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
+ log.to_csv('log.csv', mode='a', header=False, index=False)
+ epoch_bar.update()
+ step_bar.set_postfix({'dist': dist, 'acc': acc})
+ step_bar.close()
diff --git a/applications/ChatGPT/chatgpt/trainer/sft.py b/applications/ChatGPT/chatgpt/trainer/sft.py
new file mode 100644
index 000000000000..3b35f516816f
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/sft.py
@@ -0,0 +1,101 @@
+from abc import ABC
+from typing import Optional
+import loralib as lora
+import torch
+from chatgpt.models.loss import GPTLMLoss
+from torch.optim import Adam, Optimizer
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm
+import torch.distributed as dist
+from .strategies import Strategy
+from .utils import is_rank_0
+from colossalai.logging import get_dist_logger
+
+
+class SFTTrainer(ABC):
+ """
+ Trainer to use while training reward model.
+
+ Args:
+ model (torch.nn.Module): the model to train
+ strategy (Strategy): the strategy to use for training
+ optim(Optimizer): the optimizer to use for training
+ train_dataloader: the dataloader to use for training
+ eval_dataloader: the dataloader to use for evaluation
+ batch_size (int, defaults to 1): the batch size while training
+ max_epochs (int, defaults to 2): the number of epochs to train
+ optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
+ """
+
+ def __init__(
+ self,
+ model,
+ strategy: Strategy,
+ optim: Optimizer,
+ train_dataloader: DataLoader,
+ eval_dataloader: DataLoader = None,
+ sampler: Optional[DistributedSampler] = None,
+ batch_size: int = 1,
+ max_epochs: int = 2,
+ ) -> None:
+ super().__init__()
+ self.strategy = strategy
+ self.epochs = max_epochs
+ self.sampler = sampler
+
+ self.train_dataloader = train_dataloader
+ self.eval_dataloader = eval_dataloader
+
+ self.model = strategy.setup_model(model)
+ if "DDP" in str(self.strategy):
+ self.model = self.model.module
+ self.loss_fn = GPTLMLoss()
+ self.optimizer = strategy.setup_optimizer(optim, self.model)
+
+ def fit(self, logger, use_lora, log_interval=10):
+ epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
+ for epoch in range(self.epochs):
+ if isinstance(self.sampler, DistributedSampler):
+ self.sampler.set_epoch(epoch)
+ # train
+ self.model.train()
+ for batch_id, batch in enumerate(self.train_dataloader):
+ prompt_ids = batch["input_ids"]
+ p_mask = batch["attention_mask"]
+ labels = batch["labels"]
+ prompt_ids = prompt_ids.squeeze(1).cuda()
+ p_mask = p_mask.squeeze(1).cuda()
+ # prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+ loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
+
+ # loss = self.loss_fn(prompt_logits, labels)
+ self.strategy.backward(loss, self.model, self.optimizer)
+ self.strategy.optimizer_step(self.optimizer)
+ self.optimizer.zero_grad()
+ if batch_id % log_interval == 0:
+ logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
+
+ # eval
+ if self.eval_dataloader is not None:
+ self.model.eval()
+ with torch.no_grad():
+ loss_sum = 0
+ num_seen = 0
+ for batch in self.eval_dataloader:
+ prompt_ids = batch["input_ids"]
+ p_mask = batch["attention_mask"]
+ prompt_ids = prompt_ids.squeeze(1).cuda()
+ p_mask = p_mask.squeeze(1).cuda()
+
+ prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+ loss = self.loss_fn(prompt_logits, prompt_ids)
+ loss_sum += loss.item()
+ num_seen += prompt_ids.size(0)
+
+ loss_mean = loss_sum / num_seen
+ if dist.get_rank() == 0:
+ logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+
+ epoch_bar.update()
+
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py b/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py
new file mode 100644
index 000000000000..f258c9b8a873
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/__init__.py
@@ -0,0 +1,6 @@
+from .base import Strategy
+from .colossalai import ColossalAIStrategy
+from .ddp import DDPStrategy
+from .naive import NaiveStrategy
+
+__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/base.py b/applications/ChatGPT/chatgpt/trainer/strategies/base.py
new file mode 100644
index 000000000000..4347c08b4333
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/base.py
@@ -0,0 +1,131 @@
+from abc import ABC, abstractmethod
+from contextlib import nullcontext
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from chatgpt.models.base import Actor, Critic, RewardModel
+from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from .sampler import DistributedSampler
+
+ModelOptimPair = Tuple[nn.Module, Optimizer]
+ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
+
+
+class Strategy(ABC):
+ """
+ Base class for training strategies.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.setup_distributed()
+
+ @abstractmethod
+ def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
+ pass
+
+ @abstractmethod
+ def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
+ pass
+
+ @abstractmethod
+ def setup_distributed(self) -> None:
+ pass
+
+ @abstractmethod
+ def setup_model(self, model: nn.Module) -> nn.Module:
+ pass
+
+ @abstractmethod
+ def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
+ pass
+
+ @abstractmethod
+ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
+ pass
+
+ def model_init_context(self):
+ return nullcontext()
+
+ def prepare(
+ self, *models_or_model_optim_pairs: ModelOrModelOptimPair
+ ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
+ """Prepare models or model-optimizer-pairs based on each strategy.
+
+ Example::
+ >>> # when fine-tuning actor and critic
+ >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+ >>> # or when training reward model
+ >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
+ >>> # or just inference
+ >>> actor, critic = strategy.prepare(actor, critic)
+
+ Returns:
+ Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
+ """
+
+ def prepare_model(model: nn.Module):
+ if isinstance(model, Actor):
+ return Actor(self.setup_model(self._unwrap_model(model)))
+ return self.setup_model(self._unwrap_model(model))
+
+ rets = []
+ for arg in models_or_model_optim_pairs:
+ if isinstance(arg, tuple):
+ assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
+ model, optimizer = arg
+ model = prepare_model(model)
+ optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
+ rets.append((model, optimizer))
+ elif isinstance(arg, nn.Module):
+ rets.append(prepare_model(arg))
+ else:
+ raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
+
+ if len(rets) == 1:
+ return rets[0]
+ return rets
+
+ @staticmethod
+ def _unwrap_model(model: nn.Module) -> nn.Module:
+ """Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
+
+ Args:
+ model (nn.Module): an actor or a critic
+ """
+ if isinstance(model, Actor):
+ return model.model
+ return model
+
+ @staticmethod
+ def _unwrap_actor(actor: Actor) -> nn.Module:
+ """Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
+
+ Args:
+ actor (Actor): a wrapped actor
+ """
+ return Strategy._unwrap_model(actor)
+
+ @abstractmethod
+ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+ pass
+
+ @abstractmethod
+ def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+ pass
+
+ @abstractmethod
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+ pass
+
+ @abstractmethod
+ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+ pass
+
+ def setup_sampler(self, dataset) -> DistributedSampler:
+ return DistributedSampler(dataset, 1, 0)
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
new file mode 100644
index 000000000000..64ebf12f1922
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
@@ -0,0 +1,181 @@
+import warnings
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.optim as optim
+from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
+from torch.optim import Optimizer
+
+
+from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+import colossalai
+from colossalai.nn.optimizer import CPUAdam, HybridAdam
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.tensor import ProcessGroup, ShardSpec
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+from .base import Strategy
+from .ddp import DDPStrategy
+
+
+class ColossalAIStrategy(DDPStrategy):
+ """
+ The strategy for training with ColossalAI.
+
+ Args:
+ stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
+ seed(int): The seed for the random number generator.
+ shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
+ This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
+ placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
+ If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
+ If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
+ pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
+ force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
+ search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
+ hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
+ min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
+ gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
+ reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
+ overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
+ initial_scale(float): The initial scale for the optimizer.
+ growth_factor(float): The growth factor for the optimizer.
+ backoff_factor(float): The backoff factor for the optimizer.
+ growth_interval(int): The growth interval for the optimizer.
+ hysteresis(int): The hysteresis for the optimizer.
+ min_scale(float): The minimum scale for the optimizer.
+ max_scale(float): The maximum scale for the optimizer.
+ max_norm(float): The maximum norm for the optimizer.
+ norm_type(float): The norm type for the optimizer.
+
+ """
+
+ def __init__(
+ self,
+ stage: int = 3,
+ seed: int = 42,
+ shard_init: bool = False, # only for stage 3
+ placement_policy: str = 'cuda',
+ pin_memory: bool = True, # only for stage 3
+ force_outputs_fp32: bool = False, # only for stage 3
+ search_range_mb: int = 32, # only for stage 3
+ hidden_dim: Optional[int] = None, # only for stage 3
+ min_chunk_size_mb: float = 32, # only for stage 3
+ gpu_margin_mem_ratio: float = 0.0, # only for stage 3
+ reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
+ overlap_communication: bool = True, # only for stage 1&2
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0) -> None:
+ super().__init__(seed)
+ assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
+ self.stage = stage
+ # TODO(ver217): support shard_init when using from_pretrained()
+ if shard_init:
+ warnings.warn(
+ f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
+ )
+ self.shard_init = shard_init
+ self.gemini_config = dict(device=get_current_device(),
+ placement_policy=placement_policy,
+ pin_memory=pin_memory,
+ force_outputs_fp32=force_outputs_fp32,
+ strict_ddp_mode=shard_init,
+ search_range_mb=search_range_mb,
+ hidden_dim=hidden_dim,
+ min_chunk_size_mb=min_chunk_size_mb)
+ if stage == 3:
+ self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
+ else:
+ self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
+ overlap_communication=overlap_communication,
+ cpu_offload=(placement_policy == 'cpu'))
+ self.optim_kwargs = dict(initial_scale=initial_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ min_scale=min_scale,
+ max_scale=max_scale,
+ max_norm=max_norm,
+ norm_type=norm_type)
+
+ def setup_distributed(self) -> None:
+ colossalai.launch_from_torch({}, seed=self.seed)
+
+ def model_init_context(self):
+ if self.stage == 3:
+ world_size = dist.get_world_size()
+ shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
+ default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
+ return ColoInitContext(device=get_current_device(),
+ dtype=torch.half,
+ default_pg=shard_pg,
+ default_dist_spec=default_dist_spec)
+ return super().model_init_context()
+
+ def setup_model(self, model: nn.Module) -> nn.Module:
+ return zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
+
+ def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+ assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
+ return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
+
+ def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+ optimizer.backward(loss)
+
+ def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+ optimizer.step()
+
+ @staticmethod
+ def _unwrap_actor(actor: Actor) -> nn.Module:
+ model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
+ if isinstance(model, ZeroDDP):
+ return model.module
+ return model
+
+ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
+ unwrapped_model = self._unwrap_model(model)
+ # TODO : better way to get torch model from gemini model
+ # to get torch model from gemini model
+ if isinstance(unwrapped_model, ZeroDDP):
+ state_dict = unwrapped_model.state_dict()
+ unwrapped_model = get_static_torch_model(unwrapped_model)
+ if only_rank0 and dist.get_rank() != 0:
+ return
+ unwrapped_model.load_state_dict(state_dict)
+ # merge lora_weights into weights
+ for module in unwrapped_model.modules():
+ if isinstance(module, LoraLinear):
+ module.merge_weights=True
+ module.eval()
+ # get state_dict and save
+
+ if not isinstance(self.model, PreTrainedModel):
+ state_dict = unwrapped_model.state_dict()
+ if only_rank0 and dist.get_rank() != 0:
+ return
+ torch.save(state_dict, path)
+ else:
+ self.model.save_pretrained(path)
+ if tokenizer is not None:
+ tokenizer.save_pretrained(path)
+
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+ if only_rank0:
+ raise RuntimeError(
+ f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
+ torch.save(optimizer.state_dict(), path)
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
new file mode 100644
index 000000000000..c9f92c12fe0a
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/ddp.py
@@ -0,0 +1,93 @@
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from chatgpt.models.base import Actor
+from chatgpt.models.lora import LoraLinear
+from chatgpt.replay_buffer import ReplayBuffer
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from .base import Strategy
+from .naive import NaiveStrategy
+from .sampler import DistributedSampler
+
+
+class DDPStrategy(NaiveStrategy):
+ """
+ Strategy for distributed training using torch.distributed.
+ """
+
+ def __init__(self, seed: int = 42) -> None:
+ self.seed = seed
+ super().__init__()
+
+ def setup_distributed(self) -> None:
+ try:
+ rank = int(os.environ['RANK'])
+ local_rank = int(os.environ['LOCAL_RANK'])
+ world_size = int(os.environ['WORLD_SIZE'])
+ host = os.environ['MASTER_ADDR']
+ port = int(os.environ['MASTER_PORT'])
+ except KeyError as e:
+ raise RuntimeError(
+ f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
+ )
+ dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
+ self.set_seed(self.seed)
+ torch.cuda.set_device(local_rank)
+
+ def set_seed(self, seed: int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ def setup_model(self, model: nn.Module) -> nn.Module:
+ device = torch.cuda.current_device()
+ return DDP(model, device_ids=[device])
+
+ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
+ # DDP only mode, replay buffers on each rank are different.
+ # sampler = DistributedSampler(replay_buffer,
+ # num_replicas=dist.get_world_size(),
+ # rank=dist.get_rank(),
+ # shuffle=True,
+ # seed=self.seed,
+ # drop_last=True)
+ return DataLoader(
+ replay_buffer,
+ batch_size=replay_buffer.sample_batch_size,
+ # sampler=sampler,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=pin_memory,
+ collate_fn=replay_buffer.collate_fn)
+
+ @staticmethod
+ def _unwrap_actor(actor: Actor) -> nn.Module:
+ model: DDP = Strategy._unwrap_actor(actor)
+ return model.module
+
+ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+ for module in model.modules():
+ if isinstance(module, LoraLinear):
+ module.merge_weights=True
+ module.eval()
+
+ if only_rank0 and dist.get_rank() != 0:
+ return
+ model = model.model.module
+ state_dict = model.state_dict()
+ torch.save(state_dict, path)
+
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+ if only_rank0 and dist.get_rank() != 0:
+ return
+ super().save_optimizer(optimizer, path, only_rank0)
+
+ def setup_sampler(self, dataset) -> DistributedSampler:
+ return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/naive.py b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py
new file mode 100644
index 000000000000..99b8d6635394
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/naive.py
@@ -0,0 +1,55 @@
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from chatgpt.replay_buffer import ReplayBuffer
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from .base import Strategy
+
+
+class NaiveStrategy(Strategy):
+ """
+ Strategy for single GPU. No parallelism is used.
+ """
+
+ def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
+ loss.backward()
+
+ def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
+ optimizer.step()
+
+ def setup_distributed(self) -> None:
+ pass
+
+ def setup_model(self, model: nn.Module) -> nn.Module:
+ return model
+
+ def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
+ return optimizer
+
+ def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
+ return DataLoader(replay_buffer,
+ batch_size=replay_buffer.sample_batch_size,
+ shuffle=True,
+ drop_last=True,
+ pin_memory=pin_memory,
+ collate_fn=replay_buffer.collate_fn)
+
+ def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
+ unwrapped_model = self._unwrap_model(model)
+ torch.save(unwrapped_model.state_dict(), path)
+
+ def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
+ unwrapped_model = self._unwrap_model(model)
+ state_dict = torch.load(path, map_location=map_location)
+ unwrapped_model.load_state_dict(state_dict, strict=strict)
+
+ def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
+ torch.save(optimizer.state_dict(), path)
+
+ def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
+ state_dict = torch.load(path, map_location=map_location)
+ optimizer.load_state_dict(state_dict)
diff --git a/applications/ChatGPT/chatgpt/trainer/strategies/sampler.py b/applications/ChatGPT/chatgpt/trainer/strategies/sampler.py
new file mode 100644
index 000000000000..d726fa640fa2
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/strategies/sampler.py
@@ -0,0 +1,32 @@
+import math
+
+import numpy as np
+
+
+class DistributedSampler:
+
+ def __init__(self, dataset, num_replicas: int, rank: int) -> None:
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+
+ if len(self.dataset) % self.num_replicas != 0:
+ self.num_samples = math.ceil(
+ (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ )
+ else:
+ self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
+
+ self.total_size = self.num_samples * self.num_replicas
+
+ indices = list(range(len(self.dataset)))
+ indices = indices[:self.total_size]
+ assert len(indices) == self.total_size
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+ self.indices = indices
+
+ def sample(self, batch_size: int) -> list:
+ sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
+ return [self.dataset[idx] for idx in sampled_indices]
diff --git a/applications/ChatGPT/chatgpt/trainer/utils.py b/applications/ChatGPT/chatgpt/trainer/utils.py
new file mode 100644
index 000000000000..6c9f7f085f8c
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/trainer/utils.py
@@ -0,0 +1,5 @@
+import torch.distributed as dist
+
+
+def is_rank_0() -> bool:
+ return not dist.is_initialized() or dist.get_rank() == 0
diff --git a/applications/ChatGPT/chatgpt/utils/__init__.py b/applications/ChatGPT/chatgpt/utils/__init__.py
new file mode 100644
index 000000000000..8f526d7efdad
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/utils/__init__.py
@@ -0,0 +1,3 @@
+from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding
+
+__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
diff --git a/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
new file mode 100644
index 000000000000..8699bf64c7b5
--- /dev/null
+++ b/applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
@@ -0,0 +1,74 @@
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import transformers
+
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+DEFAULT_UNK_TOKEN = ""
+
+def prepare_llama_tokenizer_and_embedding(
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+ special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+ """prepare llama tokenizer and embedding.
+
+ """
+
+ if tokenizer.pad_token is None:
+ smart_tokenizer_and_embedding_resize(
+ special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+ tokenizer=tokenizer,
+ model=model,
+ )
+
+ tokenizer.add_special_tokens(
+ {
+ "eos_token": DEFAULT_EOS_TOKEN,
+ "bos_token": DEFAULT_BOS_TOKEN,
+ "unk_token": DEFAULT_UNK_TOKEN,
+ }
+ )
+
+ return tokenizer
+
+
+def smart_tokenizer_and_embedding_resize(
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+ special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+
+ if tokenizer.pad_token is None:
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
\ No newline at end of file
diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md
new file mode 100644
index 000000000000..ce73a5407944
--- /dev/null
+++ b/applications/ChatGPT/examples/README.md
@@ -0,0 +1,141 @@
+# Examples
+
+## Install requirements
+
+```shell
+pip install -r requirements.txt
+```
+
+## Train the reward model (Stage 2)
+Use these code to train your reward model.
+```shell
+# Take naive reward model training with opt-350m as example
+python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
+# use colossalai_zero2
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
+```
+
+### Features and tricks in RM training
+- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
+- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
+- We change the loss to valid_acc and pair_dist to monitor progress during training.
+- We add special token to the end of the sequence to get better result.
+- We use cosine-reducing lr-scheduler for RM training.
+- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution.
+- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2112.00861).
+
+### Experiment result
+Model performance in [Anthropics paper](https://arxiv.org/abs/2112.00861):
+
+
+
+
Our training & test result of bloom-560m for 1 epoch:
+
+
+
+
+
+## Train with dummy prompt data (Stage 3)
+
+This script supports 4 kinds of strategies:
+
+- naive
+- ddp
+- colossalai_zero2
+- colossalai_gemini
+
+It uses random generated prompt data.
+
+Naive strategy only support single GPU training:
+
+```shell
+python train_dummy.py --strategy naive
+# display cli help
+python train_dummy.py -h
+```
+
+DDP strategy and ColossalAI strategy support multi GPUs training:
+
+```shell
+# run DDP on 2 GPUs
+torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
+# run ColossalAI on 2 GPUs
+torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
+```
+
+## Train with real prompt data (Stage 3)
+
+We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as example dataset. It is a small dataset with hundreds of prompts.
+
+You should download `prompts.csv` first.
+
+This script also supports 4 strategies.
+
+```shell
+# display cli help
+python train_dummy.py -h
+# run naive on 1 GPU
+python train_prompts.py prompts.csv --strategy naive
+# run DDP on 2 GPUs
+torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
+# run ColossalAI on 2 GPUs
+torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
+```
+
+## Inference example(After Stage3)
+We support naive inference demo after training.
+```shell
+# inference, using pretrain path to configure model
+python inference.py --model_path --model --pretrain
+# example
+python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
+```
+
+## Attention
+The examples is just a demo for testing our progress of RM and PPO training.
+
+
+#### data
+- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
+- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
+- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
+- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses)
+
+## Support Model
+
+### GPT
+- [x] GPT2-S (s)
+- [x] GPT2-M (m)
+- [x] GPT2-L (l)
+- [ ] GPT2-XL (xl)
+- [x] GPT2-4B (4b)
+- [ ] GPT2-6B (6b)
+- [ ] GPT2-8B (8b)
+- [ ] GPT2-10B (10b)
+- [ ] GPT2-12B (12b)
+- [ ] GPT2-15B (15b)
+- [ ] GPT2-18B (18b)
+- [ ] GPT2-20B (20b)
+- [ ] GPT2-24B (24b)
+- [ ] GPT2-28B (28b)
+- [ ] GPT2-32B (32b)
+- [ ] GPT2-36B (36b)
+- [ ] GPT2-40B (40b)
+- [ ] GPT3 (175b)
+
+### BLOOM
+- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
+- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
+- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
+- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1)
+- [ ] BLOOM-175b
+
+### OPT
+- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
+- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
+- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
+- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
+- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
+- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
+- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
diff --git a/applications/ChatGPT/examples/inference.py b/applications/ChatGPT/examples/inference.py
new file mode 100644
index 000000000000..08885c33b194
--- /dev/null
+++ b/applications/ChatGPT/examples/inference.py
@@ -0,0 +1,59 @@
+import argparse
+
+import torch
+from chatgpt.models.bloom import BLOOMActor
+from chatgpt.models.gpt import GPTActor
+from chatgpt.models.opt import OPTActor
+from transformers import AutoTokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+
+def eval(args):
+ # configure model
+ if args.model == 'gpt2':
+ actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ elif args.model == 'bloom':
+ actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ elif args.model == 'opt':
+ actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ state_dict = torch.load(args.model_path)
+ actor.model.load_state_dict(state_dict)
+
+ # configure tokenizer
+ if args.model == 'gpt2':
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'bloom':
+ tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'opt':
+ tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ actor.eval()
+ input = args.input
+ input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
+ outputs = actor.generate(input_ids,
+ max_length=args.max_length,
+ do_sample=True,
+ top_k=50,
+ top_p=0.95,
+ num_return_sequences=1)
+ output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+ print(output)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+ # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
+ parser.add_argument('--pretrain', type=str, default=None)
+ parser.add_argument('--model_path', type=str, default=None)
+ parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
+ parser.add_argument('--max_length', type=int, default=100)
+ args = parser.parse_args()
+ eval(args)
diff --git a/applications/ChatGPT/examples/requirements.txt b/applications/ChatGPT/examples/requirements.txt
new file mode 100644
index 000000000000..40e6edc7ea73
--- /dev/null
+++ b/applications/ChatGPT/examples/requirements.txt
@@ -0,0 +1,2 @@
+pandas>=1.4.1
+sentencepiece
diff --git a/applications/ChatGPT/examples/test_ci.sh b/applications/ChatGPT/examples/test_ci.sh
new file mode 100755
index 000000000000..1d05c4c58341
--- /dev/null
+++ b/applications/ChatGPT/examples/test_ci.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+set -xue
+
+if [ -z "$PROMPT_PATH" ]; then
+ echo "Please set \$PROMPT_PATH to the path to prompts csv."
+ exit 1
+fi
+
+BASE=$(realpath $(dirname $0))
+
+export OMP_NUM_THREADS=8
+
+# install requirements
+pip install -r ${BASE}/requirements.txt
+
+# train dummy
+python ${BASE}/train_dummy.py --strategy naive --num_episodes 1 \
+ --max_timesteps 2 --update_timesteps 2 \
+ --max_epochs 1 --train_batch_size 2 --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+ --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+ --strategy ddp --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
+ --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'gpt2' --model gpt2 --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_dummy.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2
+
+rm -rf ${BASE}/actor_checkpoint_dummy.pt
+
+# train prompts
+python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 1 \
+ --max_timesteps 2 --update_timesteps 2 \
+ --max_epochs 1 --train_batch_size 2 --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+ --strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'facebook/opt-350m' --model opt
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+ --strategy ddp --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'gpt2' --model gpt2 --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
+ --strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
+ --update_timesteps 2 --max_epochs 1 --train_batch_size 2\
+ --pretrain 'gpt2' --model gpt2 --lora_rank 4\
+ --save_path ${BASE}/actor_checkpoint_prompts.pt
+python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
+
+rm -rf ${BASE}/actor_checkpoint_prompts.pt
+
+# train rm
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
+ --pretrain 'facebook/opt-350m' --model 'opt' \
+ --strategy colossalai_zero2 --loss_fn 'log_sig'\
+ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
+ --test True --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
+ --pretrain 'gpt2' --model 'gpt2' \
+ --strategy colossalai_gemini --loss_fn 'log_exp'\
+ --dataset 'Dahoas/rm-static' --test True --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
+ --pretrain 'bigscience/bloom-560m' --model 'bloom' \
+ --strategy colossalai_zero2 --loss_fn 'log_sig'\
+ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
+ --test True --lora_rank 4
+
+torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
+ --pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
+ --strategy colossalai_zero2 --loss_fn 'log_sig'\
+ --dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
+ --test True --lora_rank 4
+
+rm -rf ${BASE}/rm_ckpt.pt
diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py
new file mode 100644
index 000000000000..c0ebf8f9b7b6
--- /dev/null
+++ b/applications/ChatGPT/examples/train_dummy.py
@@ -0,0 +1,148 @@
+import argparse
+from copy import deepcopy
+
+import torch
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
+from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.callbacks import SaveCheckpoint
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from torch.optim import Adam
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+def preprocess_batch(samples):
+ input_ids = torch.stack(samples)
+ attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+ return {'input_ids': input_ids, 'attention_mask': attention_mask}
+
+
+def main(args):
+ # configure strategy
+ if args.strategy == 'naive':
+ strategy = NaiveStrategy()
+ elif args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ # configure model
+ with strategy.model_init_context():
+ if args.model == 'gpt2':
+ actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'bloom':
+ actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'opt':
+ actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ initial_model = deepcopy(actor).to(torch.cuda.current_device())
+ reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
+
+ # configure optimizer
+ if args.strategy.startswith('colossalai'):
+ actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
+ critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
+ else:
+ actor_optim = Adam(actor.parameters(), lr=5e-6)
+ critic_optim = Adam(critic.parameters(), lr=5e-6)
+
+ # configure tokenizer
+ if args.model == 'gpt2':
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'bloom':
+ tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'opt':
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
+ callbacks = []
+ if args.save_ckpt_path:
+ ckpt_callback = SaveCheckpoint(
+ args.save_ckpt_path,
+ args.save_ckpt_interval,
+ strategy,
+ actor,
+ critic,
+ actor_optim,
+ critic_optim,
+ )
+ callbacks.append(ckpt_callback)
+
+ # configure trainer
+
+ trainer = PPOTrainer(strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ tokenizer=preprocess_batch,
+ max_length=128,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ callbacks=callbacks)
+
+ random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
+ trainer.fit(random_prompts,
+ num_episodes=args.num_episodes,
+ max_timesteps=args.max_timesteps,
+ update_timesteps=args.update_timesteps)
+
+ # save model checkpoint after fitting
+ strategy.save_model(actor, args.save_path, only_rank0=True)
+ # save optimizer checkpoint on all ranks
+ if args.need_optim_ckpt:
+ strategy.save_optimizer(actor_optim,
+ 'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
+ only_rank0=False)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--strategy',
+ choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+ default='naive')
+ parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+ parser.add_argument('--pretrain', type=str, default=None)
+ parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
+ parser.add_argument('--need_optim_ckpt', type=bool, default=False)
+ parser.add_argument('--num_episodes', type=int, default=50)
+ parser.add_argument('--max_timesteps', type=int, default=10)
+ parser.add_argument('--update_timesteps', type=int, default=10)
+ parser.add_argument('--max_epochs', type=int, default=5)
+ parser.add_argument('--train_batch_size', type=int, default=8)
+ parser.add_argument('--experience_batch_size', type=int, default=8)
+ parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument('--save_ckpt_path',
+ type=str,
+ default=None,
+ help="path to save checkpoint, None means not to save")
+ parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
+ args = parser.parse_args()
+ main(args)
diff --git a/applications/ChatGPT/examples/train_dummy.sh b/applications/ChatGPT/examples/train_dummy.sh
new file mode 100755
index 000000000000..595da573e2b1
--- /dev/null
+++ b/applications/ChatGPT/examples/train_dummy.sh
@@ -0,0 +1,18 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
+
+torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py
new file mode 100644
index 000000000000..8f48a11c33e8
--- /dev/null
+++ b/applications/ChatGPT/examples/train_prompts.py
@@ -0,0 +1,132 @@
+import argparse
+from copy import deepcopy
+
+import pandas as pd
+import torch
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMActor, BLOOMCritic
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.models.opt import OPTActor, OPTCritic
+from chatgpt.trainer import PPOTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from torch.optim import Adam
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+
+def main(args):
+ # configure strategy
+ if args.strategy == 'naive':
+ strategy = NaiveStrategy()
+ elif args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ # configure model
+ with strategy.model_init_context():
+ if args.model == 'gpt2':
+ actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'bloom':
+ actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'opt':
+ actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ initial_model = deepcopy(actor)
+ reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
+
+ # configure optimizer
+ if args.strategy.startswith('colossalai'):
+ actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
+ critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
+ else:
+ actor_optim = Adam(actor.parameters(), lr=5e-6)
+ critic_optim = Adam(critic.parameters(), lr=5e-6)
+
+ # configure tokenizer
+ if args.model == 'gpt2':
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'bloom':
+ tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'opt':
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ dataset = pd.read_csv(args.prompt_path)['prompt']
+
+ def tokenize_fn(texts):
+ # MUST padding to max length to ensure inputs of all ranks have the same length
+ # Different length may lead to hang when using gemini, as different generation steps
+ batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
+ return {k: v.cuda() for k, v in batch.items()}
+
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
+ (actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
+
+ # configure trainer
+ trainer = PPOTrainer(
+ strategy,
+ actor,
+ critic,
+ reward_model,
+ initial_model,
+ actor_optim,
+ critic_optim,
+ max_epochs=args.max_epochs,
+ train_batch_size=args.train_batch_size,
+ experience_batch_size=args.experience_batch_size,
+ tokenizer=tokenize_fn,
+ max_length=128,
+ do_sample=True,
+ temperature=1.0,
+ top_k=50,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+
+ trainer.fit(dataset,
+ num_episodes=args.num_episodes,
+ max_timesteps=args.max_timesteps,
+ update_timesteps=args.update_timesteps)
+ # save model checkpoint after fitting
+ strategy.save_model(actor, args.save_path, only_rank0=True)
+ # save optimizer checkpoint on all ranks
+ if args.need_optim_ckpt:
+ strategy.save_optimizer(actor_optim,
+ 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+ only_rank0=False)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('prompt_path')
+ parser.add_argument('--strategy',
+ choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+ default='naive')
+ parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+ parser.add_argument('--pretrain', type=str, default=None)
+ parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts.pt')
+ parser.add_argument('--need_optim_ckpt', type=bool, default=False)
+ parser.add_argument('--num_episodes', type=int, default=10)
+ parser.add_argument('--max_timesteps', type=int, default=10)
+ parser.add_argument('--update_timesteps', type=int, default=10)
+ parser.add_argument('--max_epochs', type=int, default=5)
+ parser.add_argument('--train_batch_size', type=int, default=8)
+ parser.add_argument('--experience_batch_size', type=int, default=8)
+ parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+ args = parser.parse_args()
+ main(args)
diff --git a/applications/ChatGPT/examples/train_prompts.sh b/applications/ChatGPT/examples/train_prompts.sh
new file mode 100755
index 000000000000..db73ac8e8e85
--- /dev/null
+++ b/applications/ChatGPT/examples/train_prompts.sh
@@ -0,0 +1,18 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
+
+torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
diff --git a/applications/ChatGPT/examples/train_reward_model.py b/applications/ChatGPT/examples/train_reward_model.py
new file mode 100644
index 000000000000..a9c844b7b1f8
--- /dev/null
+++ b/applications/ChatGPT/examples/train_reward_model.py
@@ -0,0 +1,143 @@
+import argparse
+
+import loralib as lora
+import torch
+from chatgpt.dataset import HhRlhfDataset, RmStaticDataset
+from chatgpt.models import LogSigLoss, LogExpLoss
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMRM
+from chatgpt.models.gpt import GPTRM
+from chatgpt.models.opt import OPTRM
+from chatgpt.models.deberta import DebertaRM
+from chatgpt.trainer import RewardModelTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from datasets import load_dataset
+from random import randint
+from torch.optim import Adam
+from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+
+def train(args):
+ # configure strategy
+ if args.strategy == 'naive':
+ strategy = NaiveStrategy()
+ elif args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ # configure model
+ with strategy.model_init_context():
+ if args.model == 'bloom':
+ model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'opt':
+ model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'gpt2':
+ model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ elif args.model == 'deberta':
+ model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ if args.model_path is not None:
+ state_dict = torch.load(args.model_path)
+ model.load_state_dict(state_dict)
+
+ # configure tokenizer
+ if args.model == 'gpt2':
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'bloom':
+ tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
+ elif args.model == 'opt':
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ elif args.model == 'deberta':
+ tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+ max_len = args.max_len
+
+ # configure optimizer
+ if args.strategy.startswith('colossalai'):
+ optim = HybridAdam(model.parameters(), lr=1.5e-5)
+ else:
+ optim = Adam(model.parameters(), lr=1.5e-5)
+
+ # configure loss function
+ if args.loss_fn == 'log_sig':
+ loss_fn = LogSigLoss()
+ elif args.loss_fn == 'log_exp':
+ loss_fn = LogExpLoss()
+ else:
+ raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
+
+ # prepare for data and dataset
+ if args.subset is not None:
+ data = load_dataset(args.dataset, data_dir=args.subset)
+ else:
+ data = load_dataset(args.dataset)
+
+ if args.test:
+ train_data = data['train'].select(range(100))
+ eval_data = data['test'].select(range(10))
+ else:
+ train_data = data['train']
+ eval_data = data['test']
+ valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data)//10)))
+
+ if args.dataset == 'Dahoas/rm-static':
+ train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
+ valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
+ eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
+ elif args.dataset == 'Anthropic/hh-rlhf':
+ train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
+ valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
+ eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
+ else:
+ raise ValueError(f'Unsupported dataset "{args.dataset}"')
+
+ trainer = RewardModelTrainer(model=model,
+ strategy=strategy,
+ optim=optim,
+ loss_fn = loss_fn,
+ train_dataset=train_dataset,
+ valid_dataset=valid_dataset,
+ eval_dataset=eval_dataset,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs)
+
+ trainer.fit()
+ # save model checkpoint after fitting on only rank0
+ strategy.save_model(trainer.model, args.save_path, only_rank0=True)
+ # save optimizer checkpoint on all ranks
+ if args.need_optim_ckpt:
+ strategy.save_optimizer(trainer.optimizer, 'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--strategy',
+ choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+ default='naive')
+ parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta'], default='bloom')
+ parser.add_argument('--pretrain', type=str, default=None)
+ parser.add_argument('--model_path', type=str, default=None)
+ parser.add_argument('--need_optim_ckpt', type=bool, default=False)
+ parser.add_argument('--dataset', type=str,
+ choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
+ default='Dahoas/rm-static')
+ parser.add_argument('--subset', type=str, default=None)
+ parser.add_argument('--save_path', type=str, default='rm_ckpt.pt')
+ parser.add_argument('--max_epochs', type=int, default=1)
+ parser.add_argument('--batch_size', type=int, default=1)
+ parser.add_argument('--max_len', type=int, default=512)
+ parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
+ parser.add_argument('--test', type=bool, default=False)
+ args = parser.parse_args()
+ train(args)
diff --git a/applications/ChatGPT/examples/train_rm.sh b/applications/ChatGPT/examples/train_rm.sh
new file mode 100755
index 000000000000..4f9f55b6b59a
--- /dev/null
+++ b/applications/ChatGPT/examples/train_rm.sh
@@ -0,0 +1,8 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES 1
+
+python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
+ --model 'deberta' \
+ --strategy naive \
+ --loss_fn 'log_exp'\
+ --save_path 'rmstatic.pt' \
+ --test True
diff --git a/applications/ChatGPT/examples/train_sft.py b/applications/ChatGPT/examples/train_sft.py
new file mode 100644
index 000000000000..83b34f9dd1ea
--- /dev/null
+++ b/applications/ChatGPT/examples/train_sft.py
@@ -0,0 +1,141 @@
+import argparse
+
+import loralib as lora
+import torch
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
+from chatgpt.models.base import RewardModel
+from chatgpt.models.bloom import BLOOMLM
+from chatgpt.models.gpt import GPTLM
+from chatgpt.models.opt import OPTLM
+from chatgpt.models.llama import LlamaLM
+from chatgpt.trainer import SFTTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from chatgpt.utils import prepare_llama_tokenizer_and_embedding
+from datasets import load_dataset
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, BloomTokenizerFast
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.logging import get_dist_logger
+
+
+def train(args):
+ # configure strategy
+ if args.strategy == 'naive':
+ strategy = NaiveStrategy()
+ elif args.strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif args.strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+ elif args.strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+ # configure model
+ with strategy.model_init_context():
+ if args.model == 'bloom':
+ model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+ elif args.model == 'opt':
+ model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+ elif args.model == 'gpt2':
+ model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+ elif args.model == 'llama':
+ model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ # configure tokenizer
+ if args.model == 'gpt2':
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'bloom':
+ tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
+ tokenizer.pad_token = tokenizer.eos_token
+ elif args.model == 'opt':
+ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+ elif args.model == 'llama':
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrain,
+ padding_side="right",
+ use_fast=False,
+ )
+ else:
+ raise ValueError(f'Unsupported model "{args.model}"')
+
+ if args.model == 'llama':
+ tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
+ else:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ max_len = 512
+
+ # configure optimizer
+ if args.strategy.startswith('colossalai'):
+ optim = HybridAdam(model.parameters(), lr=5e-5)
+ else:
+ optim = Adam(model.parameters(), lr=5e-5)
+
+ logger = get_dist_logger()
+
+ # configure dataset
+ if args.dataset == 'yizhongw/self_instruct':
+ train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
+ eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+
+ train_dataset = SFTDataset(train_data, tokenizer, max_len)
+ eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+
+ elif 'alpaca' in args.dataset:
+ train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
+ eval_dataset = None
+ eval_dataset
+ data_collator = AlpacaDataCollator(tokenizer=tokenizer)
+
+ if dist.is_initialized() and dist.get_world_size() > 1:
+ sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
+ logger.info("Using Distributed Sampler")
+ else:
+ sampler = None
+
+ train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
+ if eval_dataset is not None:
+ eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+
+ trainer = SFTTrainer(model=model,
+ strategy=strategy,
+ optim=optim,
+ train_dataloader=train_dataloader,
+ eval_dataloader=eval_dataloader,
+ sampler=sampler,
+ batch_size=args.batch_size,
+ max_epochs=args.max_epochs)
+
+ trainer.fit(logger=logger, use_lora=args.lora_rank, log_interval=args.log_interval)
+
+ # save model checkpoint after fitting on only rank0
+ strategy.save_model(model, 'sft_checkpoint.pt', only_rank0=True)
+ # save optimizer checkpoint on all ranks
+ strategy.save_optimizer(optim, 'sft_optim_checkpoint_%d.pt' % (torch.cuda.current_device()), only_rank0=False)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--strategy',
+ choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+ default='naive')
+ parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
+ parser.add_argument('--pretrain', type=str, default=None)
+ parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
+ parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
+ parser.add_argument('--max_epochs', type=int, default=1)
+ parser.add_argument('--batch_size', type=int, default=4)
+ parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
+ parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
+ args = parser.parse_args()
+ train(args)
+
diff --git a/applications/ChatGPT/examples/train_sft.sh b/applications/ChatGPT/examples/train_sft.sh
new file mode 100755
index 000000000000..9f747b24689e
--- /dev/null
+++ b/applications/ChatGPT/examples/train_sft.sh
@@ -0,0 +1,20 @@
+set_n_least_used_CUDA_VISIBLE_DEVICES() {
+ local n=${1:-"9999"}
+ echo "GPU Memory Usage:"
+ local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
+ | tail -n +2 \
+ | nl -v 0 \
+ | tee /dev/tty \
+ | sort -g -k 2 \
+ | awk '{print $1}' \
+ | head -n $n)
+ export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
+ echo "Now CUDA_VISIBLE_DEVICES is set to:"
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+}
+
+set_n_least_used_CUDA_VISIBLE_DEVICES 8
+
+#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
+#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
+torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10
diff --git a/applications/ChatGPT/pytest.ini b/applications/ChatGPT/pytest.ini
new file mode 100644
index 000000000000..01e5cd217c5d
--- /dev/null
+++ b/applications/ChatGPT/pytest.ini
@@ -0,0 +1,6 @@
+[pytest]
+markers =
+ cpu: tests which can run on CPU
+ gpu: tests which requires a single GPU
+ dist: tests which are run in a multi-GPU or multi-machine environment
+ experiment: tests for experimental features
diff --git a/applications/ChatGPT/requirements-test.txt b/applications/ChatGPT/requirements-test.txt
new file mode 100644
index 000000000000..e079f8a6038d
--- /dev/null
+++ b/applications/ChatGPT/requirements-test.txt
@@ -0,0 +1 @@
+pytest
diff --git a/applications/ChatGPT/requirements.txt b/applications/ChatGPT/requirements.txt
new file mode 100644
index 000000000000..3469111925ff
--- /dev/null
+++ b/applications/ChatGPT/requirements.txt
@@ -0,0 +1,7 @@
+transformers>=4.20.1
+tqdm
+datasets
+loralib
+colossalai>=0.2.4
+torch==1.12.1
+langchain
diff --git a/applications/ChatGPT/setup.py b/applications/ChatGPT/setup.py
new file mode 100644
index 000000000000..deec10e0c841
--- /dev/null
+++ b/applications/ChatGPT/setup.py
@@ -0,0 +1,41 @@
+from setuptools import find_packages, setup
+
+
+def fetch_requirements(path):
+ with open(path, 'r') as fd:
+ return [r.strip() for r in fd.readlines()]
+
+
+def fetch_readme():
+ with open('README.md', encoding='utf-8') as f:
+ return f.read()
+
+
+def fetch_version():
+ with open('version.txt', 'r') as f:
+ return f.read().strip()
+
+
+setup(
+ name='chatgpt',
+ version=fetch_version(),
+ packages=find_packages(exclude=(
+ 'tests',
+ 'benchmarks',
+ '*.egg-info',
+ )),
+ description='A RLFH implementation (ChatGPT) powered by ColossalAI',
+ long_description=fetch_readme(),
+ long_description_content_type='text/markdown',
+ license='Apache Software License 2.0',
+ url='https://github.com/hpcaitech/ChatGPT',
+ install_requires=fetch_requirements('requirements.txt'),
+ python_requires='>=3.6',
+ classifiers=[
+ 'Programming Language :: Python :: 3',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Environment :: GPU :: NVIDIA CUDA',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'Topic :: System :: Distributed Computing',
+ ],
+)
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py b/applications/ChatGPT/tests/__init__.py
similarity index 100%
rename from examples/tutorial/stable_diffusion/ldm/models/diffusion/__init__.py
rename to applications/ChatGPT/tests/__init__.py
diff --git a/applications/ChatGPT/tests/test_checkpoint.py b/applications/ChatGPT/tests/test_checkpoint.py
new file mode 100644
index 000000000000..1bbd133f76d3
--- /dev/null
+++ b/applications/ChatGPT/tests/test_checkpoint.py
@@ -0,0 +1,98 @@
+import os
+import tempfile
+from contextlib import nullcontext
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from chatgpt.models.gpt import GPTActor
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
+
+
+def get_data(batch_size: int, seq_len: int = 10) -> dict:
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ attention_mask = torch.ones_like(input_ids)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def run_test_checkpoint(strategy):
+ BATCH_SIZE = 2
+
+ if strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif strategy == 'colossalai_gemini':
+ strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
+ elif strategy == 'colossalai_zero2':
+ strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{strategy}"')
+
+ with strategy.model_init_context():
+ actor = GPTActor(config=GPT_CONFIG).cuda()
+
+ actor_optim = HybridAdam(actor.parameters())
+
+ actor, actor_optim = strategy.prepare((actor, actor_optim))
+
+ def run_step():
+ data = get_data(BATCH_SIZE)
+ action_mask = torch.ones_like(data['attention_mask'], dtype=torch.bool)
+ action_log_probs = actor(data['input_ids'], action_mask.size(1), data['attention_mask'])
+ loss = action_log_probs.sum()
+ strategy.backward(loss, actor, actor_optim)
+ strategy.optimizer_step(actor_optim)
+
+ run_step()
+
+ ctx = tempfile.TemporaryDirectory() if dist.get_rank() == 0 else nullcontext()
+
+ with ctx as dirname:
+ rank0_dirname = [dirname]
+ dist.broadcast_object_list(rank0_dirname)
+ rank0_dirname = rank0_dirname[0]
+
+ model_path = os.path.join(rank0_dirname, 'model.pt')
+ optim_path = os.path.join(rank0_dirname, f'optim-r{dist.get_rank()}.pt')
+
+ strategy.save_model(actor, model_path, only_rank0=True)
+ strategy.save_optimizer(actor_optim, optim_path, only_rank0=False)
+
+ dist.barrier()
+
+ strategy.load_model(actor, model_path, strict=False)
+ strategy.load_optimizer(actor_optim, optim_path)
+
+ dist.barrier()
+
+ run_step()
+
+
+def run_dist(rank, world_size, port, strategy):
+ os.environ['RANK'] = str(rank)
+ os.environ['LOCAL_RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = str(port)
+ run_test_checkpoint(strategy)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('strategy', ['ddp', 'colossalai_zero2', 'colossalai_gemini'])
+@rerun_if_address_is_in_use()
+def test_checkpoint(world_size, strategy):
+ run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_checkpoint(2, 'colossalai_zero2')
diff --git a/applications/ChatGPT/tests/test_data.py b/applications/ChatGPT/tests/test_data.py
new file mode 100644
index 000000000000..3d8fe912cb27
--- /dev/null
+++ b/applications/ChatGPT/tests/test_data.py
@@ -0,0 +1,122 @@
+import os
+from copy import deepcopy
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from chatgpt.experience_maker import NaiveExperienceMaker
+from chatgpt.models.base import RewardModel
+from chatgpt.models.gpt import GPTActor, GPTCritic
+from chatgpt.replay_buffer import NaiveReplayBuffer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+GPT_CONFIG = GPT2Config(n_embd=128, n_layer=4, n_head=4)
+
+
+def get_data(batch_size: int, seq_len: int = 10) -> dict:
+ input_ids = torch.randint(0, 50257, (batch_size, seq_len), device='cuda')
+ attention_mask = torch.ones_like(input_ids)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def gather_and_equal(tensor: torch.Tensor) -> bool:
+ world_size = dist.get_world_size()
+ outputs = [torch.empty_like(tensor) for _ in range(world_size)]
+ dist.all_gather(outputs, tensor.contiguous())
+ for t in outputs[1:]:
+ if not torch.equal(outputs[0], t):
+ return False
+ return True
+
+
+def run_test_data(strategy):
+ EXPERINCE_BATCH_SIZE = 4
+ SAMPLE_BATCH_SIZE = 2
+
+ if strategy == 'ddp':
+ strategy = DDPStrategy()
+ elif strategy == 'colossalai':
+ strategy = ColossalAIStrategy(placement_policy='cuda')
+ else:
+ raise ValueError(f'Unsupported strategy "{strategy}"')
+
+ actor = GPTActor(config=GPT_CONFIG).cuda()
+ critic = GPTCritic(config=GPT_CONFIG).cuda()
+
+ initial_model = deepcopy(actor)
+ reward_model = RewardModel(deepcopy(critic.model)).cuda()
+
+ experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model)
+ replay_buffer = NaiveReplayBuffer(SAMPLE_BATCH_SIZE, cpu_offload=False)
+
+ # experience of all ranks should be the same
+ for _ in range(2):
+ data = get_data(EXPERINCE_BATCH_SIZE)
+ assert gather_and_equal(data['input_ids'])
+ assert gather_and_equal(data['attention_mask'])
+ experience = experience_maker.make_experience(**data,
+ do_sample=True,
+ max_length=16,
+ eos_token_id=50256,
+ pad_token_id=50256)
+ assert gather_and_equal(experience.sequences)
+ assert gather_and_equal(experience.action_log_probs)
+ assert gather_and_equal(experience.values)
+ assert gather_and_equal(experience.reward)
+ assert gather_and_equal(experience.advantages)
+ assert gather_and_equal(experience.action_mask)
+ assert gather_and_equal(experience.attention_mask)
+ replay_buffer.append(experience)
+
+ # replay buffer's data should be the same
+ buffer_size = torch.tensor([len(replay_buffer)], device='cuda')
+ assert gather_and_equal(buffer_size)
+ for item in replay_buffer.items:
+ assert gather_and_equal(item.sequences)
+ assert gather_and_equal(item.action_log_probs)
+ assert gather_and_equal(item.values)
+ assert gather_and_equal(item.reward)
+ assert gather_and_equal(item.advantages)
+ assert gather_and_equal(item.action_mask)
+ assert gather_and_equal(item.attention_mask)
+
+ # dataloader of each rank should have the same size and different batch
+ dataloader = strategy.setup_dataloader(replay_buffer)
+ dataloader_size = torch.tensor([len(dataloader)], device='cuda')
+ assert gather_and_equal(dataloader_size)
+ for experience in dataloader:
+ assert not gather_and_equal(experience.sequences)
+ assert not gather_and_equal(experience.action_log_probs)
+ assert not gather_and_equal(experience.values)
+ assert not gather_and_equal(experience.reward)
+ assert not gather_and_equal(experience.advantages)
+ # action mask and attention mask may be same
+
+
+def run_dist(rank, world_size, port, strategy):
+ os.environ['RANK'] = str(rank)
+ os.environ['LOCAL_RANK'] = str(rank)
+ os.environ['WORLD_SIZE'] = str(world_size)
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = str(port)
+ run_test_data(strategy)
+
+
+@pytest.mark.skip
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('strategy', ['ddp', 'colossalai'])
+@rerun_if_address_is_in_use()
+def test_data(world_size, strategy):
+ run_func = partial(run_dist, world_size=world_size, port=free_port(), strategy=strategy)
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_data(2, 'colossalai')
diff --git a/applications/ChatGPT/version.txt b/applications/ChatGPT/version.txt
new file mode 100644
index 000000000000..6e8bf73aa550
--- /dev/null
+++ b/applications/ChatGPT/version.txt
@@ -0,0 +1 @@
+0.1.0
diff --git a/colossalai/_analyzer/README.md b/colossalai/_analyzer/README.md
new file mode 100644
index 000000000000..c5c55eddd325
--- /dev/null
+++ b/colossalai/_analyzer/README.md
@@ -0,0 +1,306 @@
+# Analyzer
+
+# Overview
+The Analyzer is a collection of static graph utils including Colossal-AI FX. Features include:
+- MetaTensor -- enabling:
+ - Ahead-of-time Profiling
+ - Shape Propagation
+ - Ideal Flop Counter
+- symbolic_trace()
+ - Robust Control-flow Tracing / Recompile
+ - Robust Activation Checkpoint Tracing / CodeGen
+ - Easy-to-define Bias-Addition Split
+- symbolic_profile()
+ - Support ``MetaTensorMode``, where all Tensor operations are executed symbolically.
+ - Shape Inference Across Device and Unified ``MetaInfo``
+ - Ideal Flop Counter https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
+
+# Quickstart
+## Analyzer.FX
+**Reference:**
+
+ https://pytorch.org/docs/stable/fx.html [[paper](https://arxiv.org/pdf/2112.08429)]
+
+
+torch.FX is a toolkit for developers to use to transform nn.Module instances. FX consists of three main components: a symbolic tracer, an intermediate representation, and Python code generation. FX.Tracer hacks _\_\_torch_function\_\__ and use a Proxy object to propagate through any forward function of torch.nn.Module.
+
+ColossalAI FX is modified from torch.FX, with the extra capability of ahead-of-time profiling enabled by the subclass of ``MetaTensor``.
+
+### Analyzer.FX.symbolic_trace()
+A drawback of the original torch.FX implementation is that it is poor at handling control flow. All control flow is not PyTorch native operands and requires actual instances that specify the branches to execute on. For example,
+
+```python
+class MyModule(nn.Module):
+ def forward(self, x):
+ if x.dim() == 3:
+ return x * 2 + 1
+ else:
+ return x - 5
+```
+
+The above function has the computation graph of
+
+
+
+However, since Proxy does not have concrete data, applying ``x.dim()`` will return nothing. In the context of the auto-parallel system, at least the control-flow dependencies for tensor shape should be removed, since any searched strategy could only auto-parallelize a specific computation graph with the same tensor shape. It is native to attach concrete data onto a Proxy, and propagate them through control flow.
+
+
+
+
+With ``MetaTensor``, the computation during shape propagation can be virtualized. This speeds up tracing by avoiding allocating actual memory on devices.
+
+#### Remarks
+There is no free lunch for PyTorch to unify all operands in both its repo and other repos in its eco-system. For example, the einops library currently has no intention to support torch.FX (See https://github.com/arogozhnikov/einops/issues/188). To support different PyTorch-based libraries without modifying source code, good practices can be to allow users to register their implementation to substitute the functions not supported by torch.FX, or to avoid entering incompatible submodules.
+
+### Analyzer.FX.symbolic_profile()
+
+``symbolic_profile`` is another important feature of Colossal-AI's auto-parallel system. Profiling DNN can be costly, as you need to allocate memory and execute on real devices. However, since the profiling requirements for auto-parallel is enough if we can detect when and where the intermediate activations (i.e. Tensor) are generated, we can profile the whole procedure without actually executing it. ``symbolic_profile``, as its name infers, profiles the whole network with symbolic information only.
+
+```python
+with MetaTensorMode():
+ model = MyModule().cuda()
+ sample = torch.rand(100, 3, 224, 224).cuda()
+meta_args = dict(
+ x = sample,
+)
+gm = symbolic_trace(model, meta_args=meta_args)
+gm = symbolic_profile(gm, sample)
+```
+
+``symbolic_profile`` is enabled by ``ShapeProp`` and ``GraphProfile``.
+
+#### ShapeProp
+Both Tensor Parallel and Activation Checkpoint solvers need to know the shape information ahead of time. Unlike PyTorch's implementation, this ``ShapeProp`` can be executed under MetaTensorMode. With this, all the preparation for auto-parallel solvers can be done in milliseconds.
+
+Meanwhile, it is easy to keep track of the memory usage of each node when doing shape propagation. However, the drawbacks of FX is that not every ``call_function`` saves its input for backward, and different tensor that flows within one FX.Graph can actually have the same layout. This raises problems for fine-grained profiling.
+
+
+
+To address this problem, I came up with a simulated environment enabled by ``torch.autograd.graph.saved_tensor_hooks`` and fake ``data_ptr`` (check ``_subclasses/meta_tensor.py`` for more details of ``data_ptr`` updates).
+
+```python
+class sim_env(saved_tensors_hooks):
+ """
+ A simulation of memory allocation and deallocation in the forward pass
+ using ``saved_tensor_hooks``.
+
+ Attributes:
+ ctx (Dict[int, torch.Tensor]): A dictionary that maps the
+ data pointer of a tensor to the tensor itself. This is used
+ to track the memory allocation and deallocation.
+
+ param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
+ data pointer of all model parameters to the parameter itself.
+ This avoids overestimating the memory usage of the intermediate activations.
+ """
+
+ def __init__(self, module: Optional[torch.nn.Module] = None):
+ super().__init__(self.pack_hook, self.unpack_hook)
+ self.ctx = {}
+ self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
+ self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}
+
+ def pack_hook(self, tensor: torch.Tensor):
+ if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
+ self.ctx[tensor.data_ptr()] = tensor
+ return tensor
+
+ def unpack_hook(self, tensor):
+ return tensor
+```
+The ``ctx`` variable will keep track of all saved tensors with a unique identifier. It is likely that ``nn.Parameter`` is also counted in the ``ctx``, which is not desired. To avoid this, we can use ``param_ctx`` to keep track of all parameters in the model. The ``buffer_ctx`` is used to keep track of all buffers in the model. The ``local_ctx`` that is attached to each ``Node`` marks the memory usage of the stage to which the node belongs. With simple ``intersect``, ``union`` and ``subtract`` operations, we can get any memory-related information. For non-profileable nodes, you might add your customized profile rules to simulate the memory allocation. If a ``Graph`` is modified with some non-PyTorch functions, such as fused operands, you can register the shape propagation rule with the decorator.
+
+```python
+@register_shape_impl(fuse_conv_bn)
+def fuse_conv_bn_shape_impl(*args, **kwargs):
+ # infer output shape here
+ return torch.empty(output_shape, device=output_device)
+```
+
+An important notice is that ``ShapeProp`` will attach additional information to the graph, which will be exactly the input of ``Profiler``.
+
+#### GraphProfiler
+``GraphProfiler`` executes at the node level, and profiles both forward and backward within one node. For example, ``FlopProfiler`` will profile the forward and backward FLOPs of a node, and ``CommunicationProfiler`` will profile the forward and backward communication cost of a node. The ``GraphProfiler`` will attach the profiling results to the ``Node``. These procedures are decoupled for better extensibility.
+
+To provide a general insight of the profiled results, you can set ``verbose=True`` to print the summary as well.
+```python
+model = tm.resnet18()
+sample = torch.rand(100, 3, 224, 224)
+meta_args = dict(x=sample)
+gm = symbolic_trace(model, meta_args=meta_args)
+gm = symbolic_profile(gm, sample, verbose=True)
+
+============================================================ Results =====================================================================
+ Op type Op Accumulate size Incremental size Output size Temp size Param size Backward size Fwd FLOPs Bwd FLOPs
+------------- ---------------------------------------------- ----------------- ------------------ ------------- ----------- ------------ --------------- ------------- -------------
+ placeholder x 4.59 Mb 0 b 4.59 Mb 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module conv_proj 4.59 Mb 0 b 0 b 4.59 Mb 2.25 Mb 4.59 Mb 924.84 MFLOPs 924.84 MFLOPs
+ call_method reshape 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
+ call_method permute 4.59 Mb 0 b 0 b 4.59 Mb 0 b 4.59 Mb 0 FLOPs 0 FLOPs
+ get_attr class_token 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_method expand 4.59 Mb 0 b 0 b 24.00 Kb 3.00 Kb 0 b 0 FLOPs 6.14 kFLOPs
+call_function cat 4.59 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+ get_attr encoder_pos_embedding 4.59 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+call_function add 9.21 Mb 4.62 Mb 4.62 Mb 0 b 591.00 Kb 4.62 Mb 1.21 MFLOPs 1.21 MFLOPs
+ call_module encoder_dropout 9.21 Mb 0 b 4.62 Mb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_0_ln_1 9.22 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_0_self_attention 46.52 Mb 37.30 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem 46.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_1 46.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_0_dropout 46.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_1 51.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_0_ln_2 51.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_0_mlp_0 74.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_0_mlp_1 92.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_0_mlp_2 92.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_0_mlp_3 92.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_0_mlp_4 92.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_2 97.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_1_ln_1 101.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_1_self_attention 134.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_2 134.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_3 134.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_1_dropout 134.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_3 139.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_1_ln_2 139.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_1_mlp_0 162.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_1_mlp_1 180.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_1_mlp_2 180.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_1_mlp_3 180.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_1_mlp_4 180.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_4 185.43 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_2_ln_1 190.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_2_self_attention 222.74 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_4 222.74 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_5 222.74 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_2_dropout 222.74 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_5 227.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_2_ln_2 227.37 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_2_mlp_0 250.46 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_2_mlp_1 268.93 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_2_mlp_2 268.93 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_2_mlp_3 268.93 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_2_mlp_4 268.93 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_6 273.54 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_3_ln_1 278.17 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_3_self_attention 310.86 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_6 310.86 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_7 310.86 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_3_dropout 310.86 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_7 315.47 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_3_ln_2 315.48 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_3_mlp_0 338.57 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_3_mlp_1 357.04 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_3_mlp_2 357.04 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_3_mlp_3 357.04 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_3_mlp_4 357.04 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_8 361.66 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_4_ln_1 366.29 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_4_self_attention 398.97 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_8 398.97 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_9 398.97 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_4_dropout 398.97 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_9 403.58 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_4_ln_2 403.60 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_4_mlp_0 426.68 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_4_mlp_1 445.15 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_4_mlp_2 445.15 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_4_mlp_3 445.15 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_4_mlp_4 445.15 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_10 449.77 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_5_ln_1 454.40 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_5_self_attention 487.08 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_10 487.08 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_11 487.08 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_5_dropout 487.08 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_11 491.70 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_5_ln_2 491.71 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_5_mlp_0 514.79 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_5_mlp_1 533.26 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_5_mlp_2 533.26 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_5_mlp_3 533.26 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_5_mlp_4 533.26 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_12 537.88 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_6_ln_1 542.51 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_6_self_attention 575.19 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_12 575.19 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_13 575.19 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_6_dropout 575.19 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_13 579.81 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_6_ln_2 579.82 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_6_mlp_0 602.90 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_6_mlp_1 621.37 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_6_mlp_2 621.37 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_6_mlp_3 621.37 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_6_mlp_4 621.37 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_14 625.99 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_7_ln_1 630.62 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_7_self_attention 663.30 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_14 663.30 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_15 663.30 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_7_dropout 663.30 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_15 667.92 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_7_ln_2 667.93 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_7_mlp_0 691.02 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_7_mlp_1 709.48 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_7_mlp_2 709.48 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_7_mlp_3 709.48 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_7_mlp_4 709.48 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_16 714.10 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_8_ln_1 718.73 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_8_self_attention 751.41 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_16 751.41 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_17 751.41 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_8_dropout 751.41 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_17 756.03 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_8_ln_2 756.04 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_8_mlp_0 779.13 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_8_mlp_1 797.60 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_8_mlp_2 797.60 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_8_mlp_3 797.60 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_8_mlp_4 797.60 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_18 802.21 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_9_ln_1 806.84 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_9_self_attention 839.52 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_18 839.52 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_19 839.52 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_9_dropout 839.52 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_19 844.14 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_9_ln_2 844.15 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_9_mlp_0 867.24 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_9_mlp_1 885.71 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_9_mlp_2 885.71 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_9_mlp_3 885.71 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_9_mlp_4 885.71 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_20 890.32 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_10_ln_1 894.95 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_10_self_attention 927.63 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_20 927.63 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_21 927.63 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_10_dropout 927.63 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_21 932.25 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_10_ln_2 932.26 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_10_mlp_0 955.35 Mb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_10_mlp_1 973.82 Mb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_10_mlp_2 973.82 Mb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_10_mlp_3 973.82 Mb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_10_mlp_4 973.82 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_22 978.44 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_11_ln_1 983.06 Mb 4.63 Mb 4.62 Mb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_11_self_attention 1015.75 Mb 32.68 Mb 0 b 4.62 Mb 9.01 Mb 13.85 Mb 4.20 GFLOPs 8.40 GFLOPs
+call_function getitem_22 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 0 b 0 FLOPs 0 FLOPs
+call_function getitem_23 1015.75 Mb 0 b 0 b 0 b 0 b 0 b 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_11_dropout 1015.75 Mb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_23 1020.36 Mb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_11_ln_2 1020.38 Mb 12.31 Kb 0 b 4.62 Mb 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+ call_module encoder_layers_encoder_layer_11_mlp_0 1.02 Gb 23.09 Mb 18.47 Mb 0 b 9.01 Mb 4.62 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_11_mlp_1 1.04 Gb 18.47 Mb 18.47 Mb 0 b 0 b 18.47 Mb 4.84 MFLOPs 4.84 MFLOPs
+ call_module encoder_layers_encoder_layer_11_mlp_2 1.04 Gb 0 b 18.47 Mb 0 b 0 b 18.47 Mb 0 FLOPs 0 FLOPs
+ call_module encoder_layers_encoder_layer_11_mlp_3 1.04 Gb 0 b 0 b 4.62 Mb 9.00 Mb 18.47 Mb 3.72 GFLOPs 7.44 GFLOPs
+ call_module encoder_layers_encoder_layer_11_mlp_4 1.04 Gb 0 b 0 b 4.62 Mb 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+call_function add_24 1.04 Gb 4.62 Mb 4.62 Mb 0 b 0 b 9.23 Mb 1.21 MFLOPs 0 FLOPs
+ call_module encoder_ln 1.04 Gb 36.31 Kb 24.00 Kb 0 b 6.00 Kb 4.62 Mb 6.05 MFLOPs 6.05 MFLOPs
+call_function getitem_24 1.04 Gb 0 b 24.00 Kb 0 b 0 b 4.62 Mb 0 FLOPs 0 FLOPs
+ call_module heads_head 1.04 Gb 0 b 0 b 31.25 Kb 2.93 Mb 24.00 Kb 6.14 MFLOPs 12.30 MFLOPs
+ output output 1.04 Gb 0 b 0 b 31.25 Kb 0 b 31.25 Kb 0 FLOPs 0 FLOPs
+```
diff --git a/colossalai/_analyzer/_subclasses/__init__.py b/colossalai/_analyzer/_subclasses/__init__.py
new file mode 100644
index 000000000000..8464fed25edf
--- /dev/null
+++ b/colossalai/_analyzer/_subclasses/__init__.py
@@ -0,0 +1,4 @@
+from ._meta_registration import *
+from ._monkey_patch import *
+from .flop_tensor import flop_count, flop_mapping
+from .meta_tensor import MetaTensor, MetaTensorMode
diff --git a/colossalai/_analyzer/_subclasses/_meta_registration.py b/colossalai/_analyzer/_subclasses/_meta_registration.py
new file mode 100644
index 000000000000..2af7e05399af
--- /dev/null
+++ b/colossalai/_analyzer/_subclasses/_meta_registration.py
@@ -0,0 +1,464 @@
+# meta patch from https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
+# should be activated for PyTorch version 1.12.0 and below
+# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+# for more meta_registrations
+
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from packaging import version
+from torch.utils._pytree import tree_map
+
+aten = torch.ops.aten
+
+try:
+ meta_lib = torch.library.Library("aten", "IMPL", "Meta")
+except AttributeError:
+ meta_lib = None
+
+meta_table = {}
+
+orig_empty = torch.empty
+orig_empty_strided = torch.empty_strided
+orig_empty_like = torch.empty_like
+
+
+def new(*args, **kwargs):
+ return orig_empty(*args, **kwargs, device=torch.device('meta'))
+
+
+def new_strided(*args, **kwargs):
+ return orig_empty_strided(*args, **kwargs, device=torch.device('meta'))
+
+
+def new_like(*args, **kwargs):
+ return orig_empty_like(*args, **kwargs, device=torch.device('meta'))
+
+
+def register_meta(op, register_dispatcher=True):
+
+ def wrapper(f):
+
+ def add_func(op):
+ meta_table[op] = f
+ if register_dispatcher:
+ name = (op.__name__ if op._overloadname != "default" else op.overloadpacket.__name__)
+ try:
+ meta_lib.impl(name, f)
+ except:
+ pass
+
+ tree_map(add_func, op)
+ return f
+
+ return wrapper
+
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+ # ============================== Convolutions ======================================
+ # https://github.com/pytorch/pytorch/pull/79834
+ @register_meta(aten.convolution.default)
+ def meta_conv(
+ input_tensor: torch.Tensor,
+ weight: torch.Tensor,
+ bias: torch.Tensor,
+ stride: List[int],
+ padding: List[int],
+ dilation: List[int],
+ is_transposed: bool,
+ output_padding: List[int],
+ groups: int,
+ ):
+
+ def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
+ """
+ Formula to apply to calculate the length of some dimension of the output
+ See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+ Args:
+ ln: length of the dimension
+ p: padding in that dim
+ d: dilation in that dim
+ k: kernel size in that dim
+ s: stride in that dim
+ Returns:
+ The output length
+ """
+ return (ln + 2 * p - d * (k - 1) - 1) // s + 1
+
+ def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
+ """
+ Formula to apply to calculate the length of some dimension of the output
+ if transposed convolution is used.
+ See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
+ Args:
+ ln: length of the dimension
+ p: padding in that dim
+ d: dilation in that dim
+ k: kernel size in that dim
+ s: stride in that dim
+ op: output padding in that dim
+ Returns:
+ The output length
+ """
+ return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
+
+ def calc_conv_nd_return_shape(
+ dims: torch.Size,
+ kernel_size: torch.Size,
+ stride: Union[List[int], int],
+ padding: Union[List[int], int],
+ dilation: Union[List[int], int],
+ output_padding: Optional[Union[List[int], int]] = None,
+ ):
+ ret_shape = []
+ if isinstance(stride, int):
+ stride = [stride] * len(dims)
+ elif len(stride) == 1:
+ stride = [stride[0]] * len(dims)
+
+ if isinstance(padding, int):
+ padding = [padding] * len(dims)
+ elif len(padding) == 1:
+ padding = [padding[0]] * len(dims)
+
+ if isinstance(dilation, int):
+ dilation = [dilation] * len(dims)
+ elif len(dilation) == 1:
+ dilation = [dilation[0]] * len(dims)
+
+ output_padding_list: Optional[List[int]] = None
+ if output_padding:
+ if isinstance(output_padding, int):
+ output_padding_list = [output_padding] * len(dims)
+ elif len(output_padding) == 1:
+ output_padding_list = [output_padding[0]] * len(dims)
+ else:
+ output_padding_list = output_padding
+
+ for i in range(len(dims)):
+ # If output_padding is present, we are dealing with a transposed convolution
+ if output_padding_list:
+ ret_shape.append(
+ _formula_transposed(
+ dims[i],
+ padding[i],
+ dilation[i],
+ kernel_size[i],
+ stride[i],
+ output_padding_list[i],
+ ))
+ else:
+ ret_shape.append(_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]))
+ return ret_shape
+
+ def pick_memory_format():
+ if input_tensor.is_contiguous(memory_format=torch.channels_last):
+ return torch.channels_last
+ elif input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+ return torch.contiguous_format
+ elif input_tensor.is_contiguous(memory_format=torch.preserve_format):
+ return torch.preserve_format
+
+ kernel_size = weight.shape[2:]
+ dims = input_tensor.shape[2:]
+ if is_transposed:
+ out_channels = groups * weight.shape[1]
+
+ shape_out = calc_conv_nd_return_shape(
+ dims,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ output_padding,
+ )
+
+ else:
+ out_channels = weight.shape[0]
+ if weight.shape[1] != input_tensor.shape[1] / groups:
+ raise RuntimeError("Invalid channel dimensions")
+ shape_out = calc_conv_nd_return_shape(dims, kernel_size, stride, padding, dilation)
+ out = input_tensor.new_empty((input_tensor.shape[0], out_channels, *shape_out))
+ mem_fmt = pick_memory_format()
+ out = out.to(memory_format=mem_fmt) # type: ignore[call-overload]
+ return out
+
+ @register_meta(aten._convolution.default)
+ def meta__conv(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+ padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+ *extra_args):
+ out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
+ return out
+
+ @register_meta(aten.convolution_backward.default)
+ def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
+ padding, dilation, transposed, output_padding, groups, output_mask):
+ return new_like(input), new_like(weight), new((bias_sizes))
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+ @register_meta(aten._adaptive_avg_pool2d_backward.default)
+ def meta_adaptive_avg_pool2d_backward(
+ grad_output: torch.Tensor,
+ input: torch.Tensor,
+ ):
+ return new_like(input)
+
+ # ================================ RNN =============================================
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+ @register_meta(aten._cudnn_rnn.default)
+ def meta_cuda_rnn(
+ input,
+ weight,
+ weight_stride0,
+ weight_buf,
+ hx,
+ cx,
+ mode,
+ hidden_size,
+ proj_size,
+ num_layers,
+ batch_first,
+ dropout,
+ train,
+ bidirectional,
+ batch_sizes,
+ dropout_state,
+ ):
+
+ is_input_packed = len(batch_sizes) != 0
+ if is_input_packed:
+ seq_length = len(batch_sizes)
+ mini_batch = batch_sizes[0]
+ batch_sizes_sum = input.shape[0]
+ else:
+ seq_length = input.shape[1] if batch_first else input.shape[0]
+ mini_batch = input.shape[0] if batch_first else input.shape[1]
+ batch_sizes_sum = -1
+
+ num_directions = 2 if bidirectional else 1
+ out_size = proj_size if proj_size != 0 else hidden_size
+ if is_input_packed:
+ out_shape = [batch_sizes_sum, out_size * num_directions]
+ else:
+ out_shape = ([mini_batch, seq_length, out_size *
+ num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
+ output = input.new_empty(out_shape)
+
+ cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
+ cy = new(0) if cx is None else cx.new_empty(cell_shape)
+
+ hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
+
+ # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
+ reserve_shape = 0 if train else 0
+ reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
+
+ return output, hy, cy, reserve, weight_buf
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/RNN.cpp
+ @register_meta(aten._cudnn_rnn_backward.default)
+ def meta_cudnn_rnn_backward(input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_stride0: int,
+ hx: torch.Tensor,
+ cx: Optional[torch.Tensor] = None,
+ *args,
+ **kwargs):
+ return new_like(input), new_like(weight), new_like(hx), new_like(cx) if cx is not None else new(
+ ()) # (grad_input, grad_weight, grad_hx, grad_cx)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Activation.cpp
+ # ============================== Activations =======================================
+ _unregistered_ewise = [
+ aten.relu.default,
+ aten.prelu.default,
+ aten.hardswish.default,
+ aten.hardtanh.default,
+ aten.prelu_backward.default,
+ aten.hardswish_backward.default,
+ aten.hardtanh_backward.default,
+ ]
+
+ @register_meta(_unregistered_ewise)
+ def meta_unregistered_ewise(input: torch.Tensor, *args):
+ return new_like(input)
+
+ # ============================== Normalization =====================================
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+ @register_meta(aten.native_batch_norm.default)
+ def meta_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
+ n_input = input.size(1)
+ return new_like(input), new((n_input)), new((n_input))
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+ @register_meta(aten.native_batch_norm_backward.default)
+ def meta_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+ save_mean, save_invstd, train, eps, output_mask):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+ @register_meta(aten.cudnn_batch_norm.default)
+ def meta_cudnn_bn(input: torch.Tensor, weight, bias, running_mean, running_var, training, momentum, eps):
+ n_input = input.size(1)
+ return new_like(input), new((n_input)), new((n_input)), new(
+ (0), dtype=torch.uint8) # (output, running_mean, running_var, reserve)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cudnn/BatchNorm.cpp
+ # NB: CuDNN only implements the backward algorithm for batchnorm
+ # in training mode (evaluation mode batchnorm has a different algorithm),
+ # which is why this doesn't accept a 'training' parameter.
+ @register_meta(aten.cudnn_batch_norm_backward.default)
+ def meta_cudnn_bn_backward(dY: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, running_mean, running_var,
+ save_mean, save_invstd, eps, reserve):
+ return new_like(input), new_like(weight), new_like(weight) # (dX, dgamma, dbeta)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
+ @register_meta(aten.native_layer_norm.default)
+ def meta_ln(input: torch.Tensor, normalized_shape, weight, bias, eps):
+ bs, n_input = input.size(0), input.size(1)
+ return new_like(input), new((bs, n_input, 1)), new((bs, n_input, 1)) # (output, running_mean, running_var)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/layer_norm.cpp
+ @register_meta(aten.native_layer_norm_backward.default)
+ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, mean, rstd, weight, bias,
+ grad_input_mask):
+ return new_like(input), new_like(weight), new_like(bias) # (dX, dgamma, dbeta)
+
+ # ================================== Misc ==========================================
+ # Maybe incorrect
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Im2Col.cpp
+ @register_meta(aten.im2col.default)
+ def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
+ return new_like(input)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+ @register_meta(aten.eye.m_out)
+ def meta_eye(n: int, m: int, out: torch.Tensor):
+ return out
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+ @register_meta(aten.roll.default)
+ def meta_roll(input: torch.Tensor, shifts, dims):
+ return input
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Scalar.cpp
+ @register_meta(aten._local_scalar_dense.default)
+ def meta_local_scalar_dense(self: torch.Tensor):
+ return 0
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorCompare.cpp
+ @register_meta(aten.where.self)
+ def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
+ result_type = torch.result_type(self, other)
+ return new_like(condition + self + other, dtype=result_type)
+
+ @register_meta(aten.index.Tensor)
+ def meta_index_Tensor(self, indices):
+ assert indices, "at least one index must be provided"
+ # aten::index is the internal advanced indexing implementation
+ # checkIndexTensorTypes and expandTensors
+ result: List[Optional[torch.Tensor]] = []
+ for i, index in enumerate(indices):
+ if index is not None:
+ assert index.dtype in [torch.long, torch.int8, torch.bool],\
+ "tensors used as indices must be long, byte or bool tensors"
+ if index.dtype in [torch.int8, torch.bool]:
+ nonzero = index.nonzero()
+ k = len(result)
+ assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
+ for j in range(index.ndim):
+ assert index.shape[j] == self.shape[
+ k +
+ j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
+ result.append(nonzero.select(1, j))
+ else:
+ result.append(index)
+ else:
+ result.append(index)
+ indices = result
+ assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+ # expand_outplace
+ import torch._refs as refs
+
+ indices = list(refs._maybe_broadcast(*indices))
+ # add missing null tensors
+ while len(indices) < self.ndim:
+ indices.append(None)
+
+ # hasContiguousSubspace
+ # true if all non-null tensors are adjacent
+ # See:
+ # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
+ # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
+ state = 0
+ has_contiguous_subspace = False
+ for index in indices:
+ if state == 0:
+ if index is not None:
+ state = 1
+ elif state == 1:
+ if index is None:
+ state = 2
+ else:
+ if index is not None:
+ break
+ else:
+ has_contiguous_subspace = True
+
+ # transposeToFront
+ # This is the logic that causes the newly inserted dimensions to show up
+ # at the beginning of the tensor, if they're not contiguous
+ if not has_contiguous_subspace:
+ dims = []
+ transposed_indices = []
+ for i, index in enumerate(indices):
+ if index is not None:
+ dims.append(i)
+ transposed_indices.append(index)
+ for i, index in enumerate(indices):
+ if index is None:
+ dims.append(i)
+ transposed_indices.append(index)
+ self = self.permute(dims)
+ indices = transposed_indices
+
+ # AdvancedIndex::AdvancedIndex
+ # Now we can assume the indices have contiguous subspace
+ # This is simplified from AdvancedIndex which goes to more effort
+ # to put the input and indices in a form so that TensorIterator can
+ # take them. If we write a ref for this, probably that logic should
+ # get implemented
+ before_shape: List[int] = []
+ after_shape: List[int] = []
+ replacement_shape: List[int] = []
+ for dim, index in enumerate(indices):
+ if index is None:
+ if replacement_shape:
+ after_shape.append(self.shape[dim])
+ else:
+ before_shape.append(self.shape[dim])
+ else:
+ replacement_shape = list(index.shape)
+ return self.new_empty(before_shape + replacement_shape + after_shape)
+
+ # ============================== Embedding =========================================
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
+ @register_meta(aten.embedding_dense_backward.default)
+ def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
+ scale_grad_by_freq):
+ return new((num_weights, grad_output.size(-1)),
+ dtype=grad_output.dtype,
+ device=grad_output.device,
+ layout=grad_output.layout)
+
+ # ============================== Dropout ===========================================
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+ @register_meta(aten.native_dropout.default)
+ def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+ # notice that mask is bool
+ return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
+
+ # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+ @register_meta(aten.native_dropout_backward.default)
+ def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+ return new_like(grad) # (grad_in)
diff --git a/colossalai/_analyzer/_subclasses/_monkey_patch.py b/colossalai/_analyzer/_subclasses/_monkey_patch.py
new file mode 100644
index 000000000000..7c1c3d3d8cd4
--- /dev/null
+++ b/colossalai/_analyzer/_subclasses/_monkey_patch.py
@@ -0,0 +1,94 @@
+import torch
+import torch.distributed as dist
+from packaging import version
+
+aten = torch.ops.aten
+
+__all__ = [
+ "_TorchFactoryMethod",
+ "_TorchOverrideableFactoryMethod",
+ "_TorchNonOverrideableFactoryMethod",
+ "_TensorPropertyMethod",
+ "_DistCommMethod",
+ "_AliasATen",
+ "_InplaceATen",
+ "_MaybeInplaceATen",
+]
+
+_TorchOverrideableFactoryMethod = [
+ "empty",
+ "eye",
+ "full",
+ "ones",
+ "rand",
+ "randn",
+ "zeros",
+]
+
+_TorchNonOverrideableFactoryMethod = [
+ "arange",
+ "finfo",
+ "linspace",
+ "logspace",
+ "randint",
+ "randperm",
+ "tensor",
+]
+
+_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod
+
+_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
+
+_DistCommMethod = [
+ "all_gather",
+ "all_reduce",
+ "all_to_all",
+ "broadcast",
+ "gather",
+ "reduce",
+ "reduce_scatter",
+ "scatter",
+]
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+ # TODO: dive deep here
+ # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
+ _AliasATen = [
+ aten.detach.default,
+ aten.detach_.default,
+ aten.t.default,
+ aten.transpose.int,
+ aten.view.default,
+ aten._unsafe_view.default,
+ aten._reshape_alias.default,
+ ]
+
+ _InplaceATen = [
+ aten.add_.Tensor,
+ aten.add_.Scalar,
+ aten.sub_.Tensor,
+ aten.sub_.Scalar,
+ aten.mul_.Tensor,
+ aten.mul_.Scalar,
+ aten.div_.Tensor,
+ aten.div_.Scalar,
+ aten.pow_.Tensor,
+ aten.pow_.Scalar,
+ ]
+
+ # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
+ _MaybeInplaceATen = [
+ aten.diagonal.default,
+ aten.expand.default,
+ aten.select.int,
+ aten.slice.Tensor,
+ aten.split.Tensor,
+ aten.squeeze.default,
+ aten.permute.default,
+ aten.unsqueeze.default,
+ aten.as_strided.default,
+ ]
+else:
+ _AliasATen = []
+ _InplaceATen = []
+ _MaybeInplaceATen = []
diff --git a/colossalai/_analyzer/_subclasses/flop_tensor.py b/colossalai/_analyzer/_subclasses/flop_tensor.py
new file mode 100644
index 000000000000..dd35b00b3fab
--- /dev/null
+++ b/colossalai/_analyzer/_subclasses/flop_tensor.py
@@ -0,0 +1,542 @@
+# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
+# ideas from https://pastebin.com/AkvAyJBw
+# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
+
+import operator
+from collections import defaultdict
+from contextlib import contextmanager
+from enum import Enum, auto
+from functools import partial, reduce
+from numbers import Number
+from typing import Any, Callable, List, Optional, Union
+
+import torch
+from packaging import version
+from torch.utils._pytree import tree_map
+
+from .meta_tensor import MetaTensor
+
+aten = torch.ops.aten
+
+
+class Phase(Enum):
+ FWD = auto()
+ BWD = auto()
+
+
+def normalize_tuple(x):
+ if not isinstance(x, tuple):
+ return (x,)
+ return x
+
+
+def _format_flops(flop):
+ K = 1e3
+ M = 1e6
+ B = 1e9
+ T = 1e12
+ if flop < K:
+ return f'{flop:.2f}'
+ elif flop < M:
+ return f'{flop / K:.2f}K'
+ elif flop < B:
+ return f'{flop / M:.2f}M'
+ elif flop < T:
+ return f'{flop / B:.2f}B'
+ else:
+ return f'{flop / T:.2f}T'
+
+
+def flop_count(module: Union[torch.nn.Module, Callable] = None, *args, verbose: bool = False, **kwargs) -> Number:
+ """
+ Count the number of floating point operations in a model.
+ Ideas from https://pastebin.com/AkvAyJBw.
+ Args:
+ module (torch.nn.Module): A PyTorch model.
+ *args: Input arguments to the model.
+ verbose (bool): If True, print the number of flops for each module.
+ **kwargs: Input keyword arguments to the model.
+ Returns:
+ Number: The total number of floating point operations (FWD + BWD).
+ """
+ maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
+ or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))
+
+ class DummyModule(torch.nn.Module):
+
+ def __init__(self, func):
+ super().__init__()
+ self.func = func
+ self.__name__ = func.__name__
+
+ def forward(self, *args, **kwargs):
+ return self.func(*args, **kwargs)
+
+ total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
+ flop_counts = defaultdict(lambda: defaultdict(int))
+ parents = ['Global']
+ module = module if isinstance(module, torch.nn.Module) else DummyModule(module)
+
+ class FlopTensor(MetaTensor):
+ _tensor: torch.Tensor
+
+ def __repr__(self):
+ name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
+ if self.grad_fn:
+ return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
+ return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+
+ # no_dispatch is only needed if you use enable_python_mode.
+ # It prevents infinite recursion.
+ rs = super().__torch_dispatch__(func, types, args, kwargs)
+
+ outs = normalize_tuple(rs)
+
+ if func in flop_mapping:
+ nonlocal flop_counts, total_flop_count
+ flop_count = flop_mapping[func](args, outs)
+ for par in parents:
+ flop_counts[par][func.__name__] += flop_count
+ total_flop_count[cur_phase] += flop_count
+
+ def wrap(x):
+ if isinstance(x, MetaTensor):
+ x = FlopTensor(x)
+ return x
+
+ rs = tree_map(wrap, rs)
+
+ return rs
+
+ def is_autogradable(x):
+ return isinstance(x, torch.Tensor) and x.is_floating_point()
+
+ def create_backwards_push(name):
+
+ class PushState(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, *args):
+ args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
+ if len(args) == 1:
+ return args[0]
+ return args
+
+ @staticmethod
+ def backward(ctx, *grad_outs):
+ nonlocal parents
+ parents.append(name)
+ return grad_outs
+
+ return PushState.apply
+
+ def create_backwards_pop(name):
+
+ class PopState(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, *args):
+ args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
+ if len(args) == 1:
+ return args[0]
+ return args
+
+ @staticmethod
+ def backward(ctx, *grad_outs):
+ nonlocal parents
+ assert (parents[-1] == name)
+ parents.pop()
+ return grad_outs
+
+ return PopState.apply
+
+ def enter_module(name):
+
+ def f(module, inputs):
+ nonlocal parents
+ parents.append(name)
+ inputs = normalize_tuple(inputs)
+ out = create_backwards_pop(name)(*inputs)
+ return out
+
+ return f
+
+ def exit_module(name):
+
+ def f(module, inputs, outputs):
+ nonlocal parents
+ assert (parents[-1] == name)
+ parents.pop()
+ outputs = normalize_tuple(outputs)
+ return create_backwards_push(name)(*outputs)
+
+ return f
+
+ @contextmanager
+ def instrument_module(mod):
+ registered = []
+ for name, module in dict(mod.named_children()).items():
+ registered.append(module.register_forward_pre_hook(enter_module(name)))
+ registered.append(module.register_forward_hook(exit_module(name)))
+ yield
+ for handle in registered:
+ handle.remove()
+
+ def display_flops():
+ for mod in flop_counts.keys():
+ print(f"Module: ", mod)
+ for k, v in flop_counts[mod].items():
+ print('\t', k, _format_flops(v))
+ print()
+
+ def detach_variables(r):
+ if isinstance(r, torch.Tensor):
+ requires_grad = r.requires_grad
+ r = r.detach()
+ r.requires_grad = requires_grad
+ return r
+
+ def wrap(r):
+ if isinstance(r, torch.Tensor):
+ data_ptr_fn = getattr(r, '_tensor', r).data_ptr
+ r = FlopTensor(detach_variables(r))
+ if maybe_inplace:
+ r = r + 0
+ r._tensor.data_ptr = data_ptr_fn
+ return r
+
+ with instrument_module(module):
+ cur_phase = Phase.FWD
+ rst = module(*tree_map(wrap, args), **tree_map(wrap, kwargs))
+ rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad)
+ cur_phase = Phase.BWD
+
+ if rst:
+ grad = [torch.zeros_like(t) for t in rst]
+ torch.autograd.backward(
+ rst,
+ grad,
+ )
+
+ if verbose:
+ display_flops()
+
+ return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD]
+
+
+def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for matmul.
+ """
+ # Inputs should be a list of length 2.
+ # Inputs contains the shapes of two matrices.
+ input_shapes = [v.shape for v in inputs]
+ assert len(input_shapes) == 2, input_shapes
+ assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+ flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
+ return flops
+
+
+def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for fully connected layers.
+ """
+ # Count flop for nn.Linear
+ # inputs is a list of length 3.
+ input_shapes = [v.shape for v in inputs[1:3]]
+ # input_shapes[0]: [batch size, input feature dimension]
+ # input_shapes[1]: [input feature dimension, output feature dimension]
+ assert len(input_shapes[0]) == 2, input_shapes[0]
+ assert len(input_shapes[1]) == 2, input_shapes[1]
+ batch_size, input_dim = input_shapes[0]
+ output_dim = input_shapes[1][1]
+ flops = batch_size * input_dim * output_dim
+ return flops
+
+
+def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for the aten::linear operator.
+ """
+ # Inputs is a list of length 3; unlike aten::addmm, it is the first
+ # two elements that are relevant.
+ input_shapes = [v.shape for v in inputs[0:2]]
+ # input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
+ # input_shapes[1]: [output_feature_dim, input_feature_dim]
+ assert input_shapes[0][-1] == input_shapes[1][-1]
+ flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
+ return flops
+
+
+def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for the bmm operation.
+ """
+ # Inputs should be a list of length 2.
+ # Inputs contains the shapes of two tensor.
+ assert len(inputs) == 2, len(inputs)
+ input_shapes = [v.shape for v in inputs]
+ n, c, t = input_shapes[0]
+ d = input_shapes[-1][-1]
+ flops = n * c * t * d
+ return flops
+
+
+def conv_flop_count(
+ x_shape: List[int],
+ w_shape: List[int],
+ out_shape: List[int],
+ transposed: bool = False,
+) -> Number:
+ """
+ Count flops for convolution. Note only multiplication is
+ counted. Computation for addition and bias is ignored.
+ Flops for a transposed convolution are calculated as
+ flops = (x_shape[2:] * prod(w_shape) * batch_size).
+ Args:
+ x_shape (list(int)): The input shape before convolution.
+ w_shape (list(int)): The filter shape.
+ out_shape (list(int)): The output shape after convolution.
+ transposed (bool): is the convolution transposed
+ Returns:
+ int: the number of flops
+ """
+ batch_size = x_shape[0]
+ conv_shape = (x_shape if transposed else out_shape)[2:]
+ flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
+ return flops
+
+
+def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
+ """
+ Count flops for convolution.
+ """
+ x, w = inputs[:2]
+ x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
+ transposed = inputs[6]
+
+ return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
+
+
+def transpose_shape(shape):
+ return [shape[1], shape[0]] + list(shape[2:])
+
+
+def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
+ grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
+ output_mask = inputs[-1]
+ fwd_transposed = inputs[7]
+ flop_count = 0
+
+ if output_mask[0]:
+ grad_input_shape = outputs[0].shape
+ flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
+ if output_mask[1]:
+ grad_weight_shape = outputs[1].shape
+ flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
+
+ return flop_count
+
+
+def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
+ """
+ Args:
+ affine_arg_index: index of the affine argument in inputs
+ """
+
+ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for norm layers.
+ """
+ # Inputs[0] contains the shape of the input.
+ input_shape = inputs[input_arg_index].shape
+
+ has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
+ 'shape') else inputs[affine_arg_index]
+ assert 2 <= len(input_shape) <= 5, input_shape
+ # 5 is just a rough estimate
+ flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
+ return flop
+
+ return norm_flop_jit
+
+
+def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number:
+ if training is None:
+ training = inputs[-3]
+ assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
+ if training:
+ return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
+ has_affine = inputs[1].shape is not None
+ input_shape = reduce(operator.mul, inputs[0].shape)
+ return input_shape * (2 if has_affine else 1)
+
+
+def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
+ """
+ Count flops by
+ input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
+ Args:
+ input_scale: scale of the input tensor (first argument)
+ output_scale: scale of the output tensor (first element in outputs)
+ """
+
+ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
+ ret = 0
+ if input_scale != 0:
+ shape = inputs[0].shape
+ ret += input_scale * reduce(operator.mul, shape) if shape else 0
+ if output_scale != 0:
+ shape = outputs[0].shape
+ ret += output_scale * reduce(operator.mul, shape) if shape else 0
+ return ret
+
+ return ewise_flop
+
+
+def zero_flop_jit(*args):
+ """
+ Count flops for zero flop layers.
+ """
+ return 0
+
+
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
+ flop_mapping = {
+ # gemm
+ aten.mm.default: matmul_flop_jit,
+ aten.matmul.default: matmul_flop_jit,
+ aten.addmm.default: addmm_flop_jit,
+ aten.bmm.default: bmm_flop_jit,
+
+ # convolution
+ aten.convolution.default: conv_flop_jit,
+ aten._convolution.default: conv_flop_jit,
+ aten.convolution_backward.default: conv_backward_flop_jit,
+
+ # normalization
+ aten.native_batch_norm.default: batchnorm_flop_jit,
+ aten.native_batch_norm_backward.default: batchnorm_flop_jit,
+ aten.cudnn_batch_norm.default: batchnorm_flop_jit,
+ aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
+ aten.native_layer_norm.default: norm_flop_counter(2, 0),
+ aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+
+ # pooling
+ aten.avg_pool1d.default: ewise_flop_counter(1, 0),
+ aten.avg_pool2d.default: ewise_flop_counter(1, 0),
+ aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1),
+ aten.avg_pool3d.default: ewise_flop_counter(1, 0),
+ aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1),
+ aten.max_pool1d.default: ewise_flop_counter(1, 0),
+ aten.max_pool2d.default: ewise_flop_counter(1, 0),
+ aten.max_pool3d.default: ewise_flop_counter(1, 0),
+ aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0),
+ aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0),
+ aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1),
+ aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0),
+ aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1),
+ aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0),
+ aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1),
+ aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0),
+ aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1),
+ aten.embedding_dense_backward.default: ewise_flop_counter(0, 1),
+ aten.embedding.default: ewise_flop_counter(1, 0),
+ }
+
+ ewise_flop_aten = [
+ # basic op
+ aten.add.Tensor,
+ aten.add_.Tensor,
+ aten.div.Tensor,
+ aten.div_.Tensor,
+ aten.div.Scalar,
+ aten.div_.Scalar,
+ aten.mul.Tensor,
+ aten.mul.Scalar,
+ aten.mul_.Tensor,
+ aten.neg.default,
+ aten.pow.Tensor_Scalar,
+ aten.rsub.Scalar,
+ aten.sum.default,
+ aten.sum.dim_IntList,
+ aten.mean.dim,
+
+ # activation op
+ aten.hardswish.default,
+ aten.hardswish_.default,
+ aten.hardswish_backward.default,
+ aten.hardtanh.default,
+ aten.hardtanh_.default,
+ aten.hardtanh_backward.default,
+ aten.hardsigmoid_backward.default,
+ aten.hardsigmoid.default,
+ aten.gelu.default,
+ aten.gelu_backward.default,
+ aten.silu.default,
+ aten.silu_.default,
+ aten.silu_backward.default,
+ aten.sigmoid.default,
+ aten.sigmoid_backward.default,
+ aten._softmax.default,
+ aten._softmax_backward_data.default,
+ aten.relu_.default,
+ aten.relu.default,
+ aten.tanh.default,
+ aten.tanh_backward.default,
+ aten.threshold_backward.default,
+
+ # dropout
+ aten.native_dropout.default,
+ aten.native_dropout_backward.default,
+
+ # distribution
+ aten.bernoulli_.float,
+
+ # where
+ aten.where.self,
+ ]
+ for op in ewise_flop_aten:
+ flop_mapping[op] = ewise_flop_counter(1, 0)
+
+ # fix-me: this will be removed in future
+ zero_flop_aten = [
+ aten.as_strided.default,
+ aten.as_strided_.default,
+ aten.cat.default,
+ aten.clone.default,
+ aten.copy_.default,
+ aten.detach.default,
+ aten.expand.default,
+ aten.empty_like.default,
+ aten.new_empty.default,
+ aten.new_empty_strided.default,
+ aten.ones_like.default,
+ aten._reshape_alias.default,
+ aten.select.int,
+ aten.select_backward.default,
+ aten.squeeze.dim,
+ aten.slice.Tensor,
+ aten.slice_backward.default,
+ aten.split.Tensor,
+ aten.permute.default,
+ aten.t.default,
+ aten.transpose.int,
+ aten._to_copy.default,
+ aten.unsqueeze.default,
+ aten.unbind.int,
+ aten._unsafe_view.default,
+ aten.view.default,
+ aten.zero_.default,
+ aten.zeros_like.default,
+ ]
+
+ for op in zero_flop_aten:
+ flop_mapping[op] = zero_flop_jit
+else:
+ flop_mapping = {}
+ elementwise_flop_aten = {}
+ zero_flop_aten = {}
diff --git a/colossalai/_analyzer/_subclasses/meta_tensor.py b/colossalai/_analyzer/_subclasses/meta_tensor.py
new file mode 100644
index 000000000000..2bc212938ee0
--- /dev/null
+++ b/colossalai/_analyzer/_subclasses/meta_tensor.py
@@ -0,0 +1,207 @@
+import uuid
+from functools import partial
+
+import torch
+import torch.distributed as dist
+from torch.types import _bool, _device, _dtype
+from torch.utils._pytree import tree_flatten, tree_map
+
+from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
+
+__all__ = ['MetaTensor', 'MetaTensorMode']
+
+
+def register_storage(r, data_ptr_fn=None):
+ if isinstance(r, torch.Tensor):
+ if data_ptr_fn is not None:
+ r.data_ptr = data_ptr_fn
+ elif not r.data_ptr():
+ data_ptr = uuid.uuid1()
+ r.data_ptr = lambda: data_ptr
+
+
+def _normalize_tuple(x):
+ if not isinstance(x, tuple):
+ return (x,)
+ return x
+
+
+# a hack of inplace execution in PyTorch
+def _assert_alias(func):
+ return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen # TODO: check if should be this aggressive
+ )
+
+
+class MetaTensor(torch.Tensor):
+ """
+ A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.
+ `device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
+ ability to run PyTorch code without having to actually do computation through tensors
+ allocated on a `meta` device. Because the device is `meta`, meta tensors do not model
+ device propagation. ``MetaTensor`` extends its usage by carrying an additional `device`
+ which tracks devices that would have been used.
+
+ Reference:
+ https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
+ """
+
+ _tensor: torch.Tensor
+
+ @staticmethod
+ def __new__(cls, elem, device=None, data_ptr_fn=None):
+ requires_grad = elem.requires_grad
+ # Avoid multiple wrapping
+ while isinstance(elem, MetaTensor):
+ device = elem.device if device is None else device
+ elem = elem._tensor
+
+ # The wrapping tensor (MetaTensor) shouldn't hold any
+ # memory for the class in question, but it should still
+ # advertise the same device as before
+ r = torch.Tensor._make_wrapper_subclass(
+ cls,
+ elem.size(),
+ strides=elem.stride(),
+ storage_offset=elem.storage_offset(),
+ dtype=elem.dtype,
+ layout=elem.layout,
+ device=device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
+ requires_grad=requires_grad) # deceive the frontend for aten selections
+ r._tensor = elem
+ # ...the real tensor is held as an element on the tensor.
+ if not r._tensor.is_meta:
+ val = elem.data_ptr()
+ data_ptr_fn = lambda: val
+ r._tensor = r._tensor.to(torch.device('meta'))
+
+ # only tensor not on `meta` should be copied to `meta`
+ register_storage(r._tensor, data_ptr_fn)
+ if isinstance(elem, torch.nn.Parameter):
+ r = torch.nn.Parameter(r)
+ return r
+
+ def __repr__(self):
+ name = 'MetaParameter' if getattr(self, '_is_param', False) else 'MetaTensor'
+ if self.grad_fn:
+ return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
+ return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ device = None
+
+ def unwrap(x):
+ nonlocal device
+ if isinstance(x, MetaTensor):
+ device = x.device
+ x = x._tensor
+ elif isinstance(x, torch.Tensor):
+ device = x.device
+ x = x.to(torch.device('meta'))
+ return x
+
+ args = tree_map(unwrap, args)
+ kwargs = tree_map(unwrap, kwargs)
+
+ if 'device' in kwargs:
+ device = kwargs['device']
+ kwargs['device'] = torch.device('meta')
+
+ # run aten for backend=CPU but actually on backend=Meta
+ # here we detect whether or not the execution generates a physical copy
+ # of the input tensor
+ ret = func(*args, **kwargs)
+
+ if _assert_alias(func):
+ val = args[0].data_ptr()
+ tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))
+
+ # Now, we want to continue propagating this tensor, so we rewrap Tensors in
+ # our custom tensor subclass
+ def wrap(x):
+ return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x
+
+ return tree_map(wrap, ret)
+
+ def to(self, *args, **kwargs) -> torch.Tensor:
+ """An extension of `torch.Tensor.to()` to MetaTensor
+ Returns:
+ result (MetaTensor): MetaTensor
+ Usage:
+ >>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
+ >>> tensor.to(torch.uint8)
+ MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), device='cuda:100')
+ >>> tensor.to(torch.device('cuda:42'))
+ MetaTensor(tensor(..., device='meta', size=(10,)), device='cuda:42')
+ >>> tensor.to('vulkan')
+ MetaTensor(tensor(..., device='meta', size=(10,)), device='vulkan')
+ """
+ # this imitates c++ function in the way of @overload
+ device = None
+
+ def replace(x):
+ nonlocal device
+ if isinstance(x, str) or isinstance(x, _device):
+ device = x
+ return torch.device('meta')
+ return x
+
+ elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
+ return MetaTensor(elem, device=device)
+
+ def cpu(self, *args, **kwargs):
+ if self.device.type == 'cpu':
+ return self.to(*args, **kwargs)
+ return self.to(*args, device='cpu', **kwargs)
+
+ def cuda(self, device=None, non_blocking=False):
+ if device is not None:
+ return self.to(device=device, non_blocking=non_blocking)
+ return self.to(device='cuda:0', non_blocking=non_blocking)
+
+ def data_ptr(self):
+ return self._tensor.data_ptr()
+
+
+class MetaTensorMode(object):
+ """
+ A context manager that enables MetaTensor mode.
+
+ Usage:
+ >>> with MetaTensorMode():
+ >>> # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
+ >>> # and the actual execution will be on torch.device('meta')
+ >>> a = torch.rand(100000, 100000)
+ >>> b = torch.rand(100000, 100000)
+ >>> c = torch.mm(a, b)
+ """
+
+ def __init__(self):
+ self.torch_overrides = {} # override torch.xxx
+ self.dist_overrides = {} # override torch.distributed.xxx
+
+ def __enter__(self):
+
+ def _dummy(*args, **kwargs):
+ pass
+
+ def _new(*args, orig_new=torch.empty, **kwargs):
+ return MetaTensor(orig_new(*args, **{
+ **kwargs, 'device': 'meta'
+ }),
+ device=kwargs.get('device', torch.device('cpu')))
+
+ for func in _TorchOverrideableFactoryMethod:
+ self.torch_overrides[func] = getattr(torch, func)
+ setattr(torch, func, partial(_new, orig_new=getattr(torch, func)))
+
+ for func in _DistCommMethod:
+ self.dist_overrides[func] = getattr(dist, func)
+ setattr(dist, func, _dummy)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ for func, func_impl in self.torch_overrides.items():
+ setattr(torch, func, func_impl)
+
+ for func, func_impl in self.dist_overrides.items():
+ setattr(dist, func, func_impl)
diff --git a/colossalai/_analyzer/envs.py b/colossalai/_analyzer/envs.py
new file mode 100644
index 000000000000..b537747c57a8
--- /dev/null
+++ b/colossalai/_analyzer/envs.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class MeshConfig:
+ TFLOPS: float = 1.9e12
+ BANDWIDTH = 1.2e9
diff --git a/colossalai/_analyzer/fx/__init__.py b/colossalai/_analyzer/fx/__init__.py
new file mode 100644
index 000000000000..aa01de0bbe6c
--- /dev/null
+++ b/colossalai/_analyzer/fx/__init__.py
@@ -0,0 +1,3 @@
+from .node_util import MetaInfo
+from .symbolic_profile import symbolic_profile
+from .tracer.symbolic_trace import symbolic_trace
diff --git a/colossalai/_analyzer/fx/codegen.py b/colossalai/_analyzer/fx/codegen.py
new file mode 100644
index 000000000000..1117c0103166
--- /dev/null
+++ b/colossalai/_analyzer/fx/codegen.py
@@ -0,0 +1,456 @@
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+from torch.fx.graph import (
+ CodeGen,
+ PythonCode,
+ _custom_builtins,
+ _format_target,
+ _is_from_torch,
+ _Namespace,
+ _origin_type_map,
+ _register_custom_builtin,
+ inplace_methods,
+ magic_methods,
+)
+from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
+import colossalai
+from colossalai.fx._compatibility import compatibility
+
+_register_custom_builtin('colossalai', 'import colossalai', colossalai)
+
+
+def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
+ """
+ Generate the checkpoint function definition
+ """
+ return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
+
+
+def _gen_ckpt_output(output_vars: List[str]) -> str:
+ """
+ Generate the return statement for checkpoint region
+ """
+ return f"return {', '.join(output_vars)}"
+
+
+def _gen_ckpt_usage(label, input_vars, output_vars, use_reentrant=True):
+ """
+ Generate the checkpoint function call code text
+ """
+ outputs = ', '.join(output_vars)
+ inputs = ', '.join(input_vars)
+ return f'{outputs} = torch.utils.checkpoint.checkpoint(self.checkpoint_{label}, {inputs}, use_reentrant={use_reentrant})'
+
+
+def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
+ """
+ Check if the node could end the ckpt region at `ckpt_level`
+ """
+ if len(node.meta['info'].to_recompute) > ckpt_level:
+ return node.meta['info'].to_recompute[ckpt_level] is not None
+ return True
+
+
+def _find_input_and_output_nodes(nodes: List[Node]):
+ """
+ Find the input and output node names which are not found in the given list of nodes.
+ """
+ input_nodes = []
+ output_nodes = []
+
+ # if a node has an input node which is not in the node list
+ # we treat that input node as the input of the checkpoint function
+ for node in nodes:
+ for input_node in node._input_nodes.keys():
+ node_repr = repr(input_node)
+ if input_node not in nodes and node_repr not in input_nodes:
+ input_nodes.append(node_repr)
+
+ # if a node has a user node which is not in the node list
+ # we treat that user node as the node receiving the current node output
+ for node in nodes:
+ for output_node in node.users.keys():
+ node_repr = repr(node)
+ if output_node not in nodes and node_repr not in output_nodes:
+ output_nodes.append(node_repr)
+
+ return input_nodes, output_nodes
+
+
+def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
+ """
+ Find the nested checkpoint regions given a list of consecutive nodes. The outputs
+ will be list of tuples, each tuple is in the form of (start_index, end_index).
+ """
+ ckpt_regions = []
+ start = -1
+ end = -1
+ current_region = None
+
+ for idx, node in enumerate(node_list):
+ if len(node.meta['info'].to_recompute) > ckpt_level:
+ act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
+
+ # this activation checkpoint label is not set yet
+ # meaning this is the first node of the activation ckpt region
+ if current_region is None:
+ current_region = act_ckpt_label
+ start = idx
+
+ # if activation checkpoint has changed
+ # we restart the tracking
+ # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
+ if act_ckpt_label != current_region:
+ assert start != -1
+ ckpt_regions.append((start, idx - 1))
+ current_region = act_ckpt_label
+ start = idx
+ end = -1
+
+ elif current_region is not None and _end_of_ckpt(node, ckpt_level):
+ # used to check the case below
+ # node ckpt states = [ckpt, ckpt, non-ckpt]
+ end = idx - 1
+ assert start != -1 and end != -1
+ ckpt_regions.append((start, end))
+ start = end = -1
+ current_region = None
+
+ else:
+ pass
+
+ if current_region is not None:
+ end = len(node_list) - 1
+ ckpt_regions.append((start, end))
+ return ckpt_regions
+
+
+def emit_ckpt_func(body,
+ ckpt_func,
+ node_list: List[Node],
+ emit_node_func,
+ delete_unused_value_func,
+ ckpt_level=0,
+ in_ckpt=False):
+ """Emit ckpt fuction in nested way
+
+ Args:
+ body: forward code - in recursive calls, this part will be checkpoint
+ functions code
+ ckpt_func: checkpoint functions code - in recursive calls, this part
+ will be a buffer
+ node_list (List[Node]): list of torch.fx.Node
+ emit_node_func: function to emit a node
+ delete_unused_value_func: function to delete unused value
+ level (int, optional): checkpoint level. Defaults to 0.
+ in_ckpt (bool, optional): indicates wether the func is in recursive
+ call. Defaults to False.
+ """
+ inputs, outputs = _find_input_and_output_nodes(node_list)
+
+ # label given by each layer, e.g. if you are currently at level (0, 1, 1)
+ # the label will be '0_1_1'
+ label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
+ ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
+ ckpt_func.append(f'{ckpt_fn_def}\n')
+
+ # if there is more level to fetch
+ if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
+ ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
+ start_idx = [item[0] for item in ckpt_regions]
+ end_idx = [item[1] for item in ckpt_regions]
+
+ # use ckpt_func_buffer to store nested checkpoint functions
+ ckpt_func_buffer = []
+ node_idx = 0
+ while 1:
+ if node_idx >= len(node_list):
+ break
+
+ if node_idx in start_idx:
+ ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func, delete_unused_value_func,
+ ckpt_level + 1, True)
+ node_idx += len(ckpt_node_list)
+
+ else:
+ node = node_list[node_idx]
+ emit_node_func(node, ckpt_func)
+ ckpt_func[-1] = ' ' + ckpt_func[-1]
+ delete_unused_value_func(node, ckpt_func)
+ node_idx += 1
+
+ ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+ ckpt_func += ckpt_func_buffer
+
+ # last level
+ else:
+ for node in node_list:
+ emit_node_func(node, ckpt_func)
+ ckpt_func[-1] = ' ' + ckpt_func[-1]
+ delete_unused_value_func(node, ckpt_func)
+
+ ckpt_func.append(' ' + _gen_ckpt_output(outputs) + '\n\n')
+
+ usage = _gen_ckpt_usage(label, inputs, outputs, False) + '\n'
+ if in_ckpt:
+ usage = ' ' + usage
+ body.append(usage)
+
+
+def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+ """Emit code with nested activation checkpoint
+ When we detect some of the annotation is a , we will use
+ this function to emit the activation checkpoint codes.
+
+ Args:
+ body: forward code
+ ckpt_func: checkpoint functions code
+ nodes: graph.nodes
+ emit_node_func: function to emit node
+ delete_unused_value_func: function to remove the unused value
+ """
+ ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
+ start_idx = [item[0] for item in ckpt_regions]
+ end_idx = [item[1] for item in ckpt_regions]
+
+ node_list = list(nodes)
+
+ node_idx = 0
+ while 1:
+ # break if we finish the processing all the nodes
+ if node_idx >= len(node_list):
+ break
+
+ # process ckpt_regions
+ if node_idx in start_idx:
+ ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+ emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
+ node_idx += len(ckpt_node_list)
+
+ # process node in forward function
+ else:
+ node = node_list[node_idx]
+ emit_node_func(node, body)
+ delete_unused_value_func(node, body)
+ node_idx += 1
+
+
+@compatibility(is_backward_compatible=True)
+class ActivationCheckpointCodeGen(CodeGen):
+
+ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+ free_vars: List[str] = []
+ body: List[str] = []
+ globals_: Dict[str, Any] = {}
+ wrapped_fns: Dict[str, None] = {}
+
+ # Wrap string in list to pass by reference
+ maybe_return_annotation: List[str] = ['']
+
+ def add_global(name_hint: str, obj: Any):
+ """Add an obj to be tracked as a global.
+ We call this for names that reference objects external to the
+ Graph, like functions or types.
+ Returns: the global name that should be used to reference 'obj' in generated source.
+ """
+ if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device
+ # HACK: workaround for how torch custom ops are registered. We
+ # can't import them like normal modules so they must retain their
+ # fully qualified name.
+ return _get_qualified_name(obj)
+
+ # normalize the name hint to get a proper identifier
+ global_name = namespace.create_name(name_hint, obj)
+
+ if global_name in globals_:
+ assert globals_[global_name] is obj
+ return global_name
+ globals_[global_name] = obj
+ return global_name
+
+ # Pre-fill the globals table with registered builtins.
+ for name, (_, obj) in _custom_builtins.items():
+ add_global(name, obj)
+
+ def type_repr(o: Any):
+ if o == ():
+ # Empty tuple is used for empty tuple type annotation Tuple[()]
+ return '()'
+
+ typename = _type_repr(o)
+
+ if hasattr(o, '__origin__'):
+ # This is a generic type, e.g. typing.List[torch.Tensor]
+ origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+ origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+ if hasattr(o, '__args__'):
+ # Assign global names for each of the inner type variables.
+ args = [type_repr(arg) for arg in o.__args__]
+
+ if len(args) == 0:
+ # Bare type, such as `typing.Tuple` with no subscript
+ # This code-path used in Python < 3.9
+ return origin_typename
+
+ return f'{origin_typename}[{",".join(args)}]'
+ else:
+ # Bare type, such as `typing.Tuple` with no subscript
+ # This code-path used in Python 3.9+
+ return origin_typename
+
+ # Common case: this is a regular module name like 'foo.bar.baz'
+ return add_global(typename, o)
+
+ def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
+
+ def _get_repr(arg):
+ # Handle NamedTuples (if it has `_fields`) via add_global.
+ if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+ qualified_name = _get_qualified_name(type(arg))
+ global_name = add_global(qualified_name, type(arg))
+ return f"{global_name}{repr(tuple(arg))}"
+ return repr(arg)
+
+ args_s = ', '.join(_get_repr(a) for a in args)
+ kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+ if args_s and kwargs_s:
+ return f'{args_s}, {kwargs_s}'
+ return args_s or kwargs_s
+
+ # Run through reverse nodes and record the first instance of a use
+ # of a given node. This represents the *last* use of the node in the
+ # execution order of the program, which we will use to free unused
+ # values
+ node_to_last_use: Dict[Node, Node] = {}
+ user_to_last_uses: Dict[Node, List[Node]] = {}
+
+ def register_last_uses(n: Node, user: Node):
+ if n not in node_to_last_use:
+ node_to_last_use[n] = user
+ user_to_last_uses.setdefault(user, []).append(n)
+
+ for node in reversed(nodes):
+ map_arg(node.args, lambda n: register_last_uses(n, node))
+ map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+ # NOTE: we add a variable to distinguish body and ckpt_func
+ def delete_unused_values(user: Node, body):
+ """
+ Delete values after their last use. This ensures that values that are
+ not used in the remainder of the code are freed and the memory usage
+ of the code is optimal.
+ """
+ if user.op == 'placeholder':
+ return
+ if user.op == 'output':
+ body.append('\n')
+ return
+ nodes_to_delete = user_to_last_uses.get(user, [])
+ if len(nodes_to_delete):
+ to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
+ body.append(f'; {to_delete_str}\n')
+ else:
+ body.append('\n')
+
+ # NOTE: we add a variable to distinguish body and ckpt_func
+ def emit_node(node: Node, body):
+ maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+ if node.op == 'placeholder':
+ assert isinstance(node.target, str)
+ maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
+ free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
+ raw_name = node.target.replace('*', '')
+ if raw_name != repr(node):
+ body.append(f'{repr(node)} = {raw_name}\n')
+ return
+ elif node.op == 'call_method':
+ assert isinstance(node.target, str)
+ body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+ f'({_format_args(node.args[1:], node.kwargs)})')
+ return
+ elif node.op == 'call_function':
+ assert callable(node.target)
+ # pretty print operators
+ if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+ assert isinstance(node.args, tuple)
+ body.append(f'{repr(node)}{maybe_type_annotation} = '
+ f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+ return
+
+ # pretty print inplace operators; required for jit.script to work properly
+ # not currently supported in normal FX graphs, but generated by torchdynamo
+ if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
+ body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
+ f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+ return
+
+ qualified_name = _get_qualified_name(node.target)
+ global_name = add_global(qualified_name, node.target)
+ # special case for getattr: node.args could be 2-argument or 3-argument
+ # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+ if global_name == 'getattr' and \
+ isinstance(node.args, tuple) and \
+ isinstance(node.args[1], str) and \
+ node.args[1].isidentifier() and \
+ len(node.args) == 2:
+ body.append(
+ f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+ return
+ body.append(
+ f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+ if node.meta.get('is_wrapped', False):
+ wrapped_fns.setdefault(global_name)
+ return
+ elif node.op == 'call_module':
+ assert isinstance(node.target, str)
+ body.append(f'{repr(node)}{maybe_type_annotation} = '
+ f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+ return
+ elif node.op == 'get_attr':
+ assert isinstance(node.target, str)
+ body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+ return
+ elif node.op == 'output':
+ if node.type is not None:
+ maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+ body.append(self.generate_output(node.args[0]))
+ return
+ raise NotImplementedError(f'node: {node.op} {node.target}')
+
+ # Modified for activation checkpointing
+ ckpt_func = []
+ emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+
+ if len(body) == 0:
+ # If the Graph has no non-placeholder nodes, no lines for the body
+ # have been emitted. To continue to have valid Python code, emit a
+ # single pass statement
+ body.append('pass\n')
+
+ if len(wrapped_fns) > 0:
+ wrap_name = add_global('wrap', torch.fx.wrap)
+ wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ else:
+ wrap_stmts = ''
+
+ if self._body_transformer:
+ body = self._body_transformer(body)
+
+ for name, value in self.additional_globals():
+ add_global(name, value)
+
+ prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+ prologue = ''.join(ckpt_func) + prologue
+ prologue = prologue
+
+ code = ''.join(body)
+ code = '\n'.join(' ' + line for line in code.split('\n'))
+ fn_code = f"""
+{wrap_stmts}
+{prologue}
+{code}"""
+ return PythonCode(fn_code, globals_)
diff --git a/colossalai/_analyzer/fx/graph_module.py b/colossalai/_analyzer/fx/graph_module.py
new file mode 100644
index 000000000000..1fdedd758c01
--- /dev/null
+++ b/colossalai/_analyzer/fx/graph_module.py
@@ -0,0 +1,239 @@
+import linecache
+import os
+import sys
+import traceback
+import warnings
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import torch
+import torch.fx
+import torch.nn as nn
+from torch.fx.graph import PythonCode
+
+try:
+ from torch.fx.graph import _PyTreeCodeGen
+ SUPPORT_PT_CODEGEN = True
+except ImportError:
+ SUPPORT_PT_CODEGEN = False
+
+from torch.fx.graph_module import _exec_with_source, _forward_from_src
+from torch.nn.modules.module import _addindent
+
+
+# This is a copy of torch.fx.graph_module._WrappedCall.
+# It should be removed when we stop supporting torch < 1.12.0.
+class _WrappedCall:
+
+ def __init__(self, cls, cls_call):
+ self.cls = cls
+ self.cls_call = cls_call
+
+ # Previously, if an error occurred when valid
+ # symbolically-traced code was run with an invalid input, the
+ # user would see the source of the error as coming from
+ # `File "`, where N is some number. We use
+ # this function to generate a more informative error message. We
+ # return the traceback itself, a message explaining that the
+ # error occurred in a traced Module's generated forward
+ # function, and five lines of context surrounding the faulty
+ # line
+ @staticmethod
+ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+ # auxiliary variables (for readability)
+ err_lineno = frame_summary.lineno
+ assert err_lineno is not None
+ line = frame_summary.line
+ assert line is not None
+ err_line_len = len(line)
+ all_src_lines = linecache.getlines(frame_summary.filename)
+
+ # constituent substrings of the error message
+ tb_repr = traceback.format_exc()
+ custom_msg = ("Call using an FX-traced Module, "
+ f"line {err_lineno} of the traced Module's "
+ "generated forward function:")
+ before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+ marker = "~" * err_line_len + "~~~ <--- HERE"
+ err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+
+ # joined message
+ return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+ def __call__(self, obj, *args, **kwargs):
+ try:
+ if self.cls_call is not None:
+ return self.cls_call(obj, *args, **kwargs)
+ else:
+ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
+ except Exception as e:
+ assert e.__traceback__
+ topmost_framesummary: traceback.FrameSummary = \
+ traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
+ if "eval_with_key" in topmost_framesummary.filename:
+ print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
+ raise e.with_traceback(None)
+ else:
+ raise e
+
+
+class ColoGraphModule(torch.fx.GraphModule):
+ """
+ ColoGraphGraphModule is an nn.Module generated from an fx.Graph.
+ ColoGraphmodule has a ``graph`` attribute, as well as ``code`` and ``forward``
+ attributes generated from that ``graph``.
+
+ The difference between ``ColoGraphModule`` and ``torch.fx.GraphModule`` is that
+ ``ColoGraphModule`` has a ``bind()`` function to bind customized functions
+ (i.e. activation checkpoint) to ``code`` of ``nn.Module``. If you want to use
+ specific features in Colossal-AI that are not supported by ``torch.fx.GraphModule``,
+ you can use ``ColoGraphModule`` instead.
+
+ ``colossalai.fx.symbolic_trace()`` will return a ``ColoGraphModule`` as default.
+
+ .. warning::
+
+ When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
+ regenerated. However, if you edit the contents of the ``graph`` without reassigning
+ the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
+ code.
+ """
+
+ def __init__(self,
+ root: Union[torch.nn.Module, Dict[str, Any]],
+ graph: torch.fx.Graph,
+ class_name: str = 'GraphModule'):
+ super().__init__(root, graph, class_name)
+
+ def bind(self, ckpt_def, globals):
+ """Bind function needed for correctly execute ``GraphModule.forward()``
+
+ We need to bind checkpoint functions to ``ColoGraphModule`` so that we could
+ correctly execute ``GraphModule.forward()``
+
+ Args:
+ ckpt_def (List[str]): definition before the forward function
+ globals (Dict[str, Any]): global variables
+ """
+
+ ckpt_code = "\n".join(ckpt_def)
+ globals_copy = globals.copy()
+ _exec_with_source(ckpt_code, globals_copy)
+ func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
+ for func in func_list:
+ tmp_func = globals_copy[func]
+ setattr(self, func, tmp_func.__get__(self, self.__class__))
+ del globals_copy[func]
+
+ def recompile(self) -> PythonCode:
+ """
+ Recompile this GraphModule from its ``graph`` attribute. This should be
+ called after editing the contained ``graph``, otherwise the generated
+ code of this ``GraphModule`` will be out of date.
+ """
+ if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
+ self._in_spec = self._graph._codegen.pytree_info.in_spec
+ self._out_spec = self._graph._codegen.pytree_info.out_spec
+ python_code = self._graph.python_code(root_module='self')
+ self._code = python_code.src
+
+ # To split ckpt functions code and forward code
+ _code_list = self._code.split("\n")
+ _fwd_def = [item for item in _code_list if "def forward" in item][0]
+ _fwd_idx = _code_list.index(_fwd_def)
+ ckpt_def = _code_list[:_fwd_idx]
+ self._code = "\n".join(_code_list[_fwd_idx:])
+
+ self.bind(ckpt_def, python_code.globals)
+
+ cls = type(self)
+ cls.forward = _forward_from_src(self._code, python_code.globals)
+
+ # Determine whether this class explicitly defines a __call__ implementation
+ # to wrap. If it does, save it in order to have wrapped_call invoke it.
+ # If it does not, wrapped_call can use a dynamic call to super() instead.
+ # In most cases, super().__call__ should be torch.nn.Module.__call__.
+ # We do not want to hold a reference to Module.__call__ here; doing so will
+ # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
+ cls_call = cls.__call__ if "__call__" in vars(cls) else None
+
+ if '_wrapped_call' not in vars(cls):
+ cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
+
+ def call_wrapped(self, *args, **kwargs):
+ return self._wrapped_call(self, *args, **kwargs)
+
+ cls.__call__ = call_wrapped
+
+ # reset self._code to original src, otherwise to_folder will be wrong
+ self._code = python_code.src
+ return python_code
+
+ def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"):
+ """Dumps out module to ``folder`` with ``module_name`` so that it can be
+ imported with ``from import ``
+
+ Args:
+
+ folder (Union[str, os.PathLike]): The folder to write the code out to
+
+ module_name (str): Top-level name to use for the ``Module`` while
+ writing out the code
+ """
+ folder = Path(folder)
+ Path(folder).mkdir(exist_ok=True)
+ torch.save(self.state_dict(), folder / 'state_dict.pt')
+ tab = " " * 4
+
+ # we add import colossalai here
+ model_str = f"""
+import torch
+from torch.nn import *
+import colossalai
+
+
+class {module_name}(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+"""
+
+ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
+ safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
+ if type(module) in safe_reprs:
+ return f"{module.__repr__()}"
+ else:
+ return None
+
+ blobified_modules = []
+ for module_name, module in self.named_children():
+ module_str = _gen_model_repr(module_name, module)
+ if module_str is None:
+ module_file = folder / f'{module_name}.pt'
+ torch.save(module, module_file)
+ blobified_modules.append(module_name)
+ module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
+ module_str = f"torch.load(r'{module_file}') # {module_repr}"
+ model_str += f"{tab*2}self.{module_name} = {module_str}\n"
+
+ for buffer_name, buffer in self._buffers.items():
+ if buffer is None:
+ continue
+ model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
+
+ for param_name, param in self._parameters.items():
+ if param is None:
+ continue
+ model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
+
+ model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
+ model_str += f"{_addindent(self.code, 4)}\n"
+
+ module_file = folder / 'module.py'
+ module_file.write_text(model_str)
+
+ init_file = folder / '__init__.py'
+ init_file.write_text('from .module import *')
+
+ if len(blobified_modules) > 0:
+ warnings.warn("Was not able to save the following children modules as reprs -"
+ f"saved as pickled files instead: {blobified_modules}")
diff --git a/colossalai/_analyzer/fx/node_util.py b/colossalai/_analyzer/fx/node_util.py
new file mode 100644
index 000000000000..8c8956d8ea7c
--- /dev/null
+++ b/colossalai/_analyzer/fx/node_util.py
@@ -0,0 +1,211 @@
+from dataclasses import dataclass, field
+from typing import Callable, ClassVar, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.fx import Graph, GraphModule, Node
+
+from colossalai._analyzer.envs import MeshConfig
+
+
+def intersect(a, b):
+ return {k: a[k] for k in a if k in b}
+
+
+def subtract(a, b):
+ return {k: a[k] for k in a if k not in b}
+
+
+def union(a, b):
+ return {**a, **b}
+
+
+def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
+ """Compute the size of a tensor or a collection of tensors in bytes.
+
+ Args:
+ elem (torch.Tensor | Dict | List | Tuple | int): Arbitrary nested ``torch.Tensor`` data structure.
+
+ Returns:
+ int: The size of the tensor or the collection of tensors in bytes.
+ """
+ nbytes = 0
+ if isinstance(elem, torch.Tensor):
+ if elem.is_quantized:
+ nbytes += elem.numel() * torch._empty_affine_quantized([], dtype=elem.dtype).element_size()
+ else:
+ nbytes += elem.numel() * torch.tensor([], dtype=elem.dtype).element_size()
+ elif isinstance(elem, dict):
+ value_list = [v for _, v in elem.items()]
+ nbytes += compute_size_in_bytes(value_list)
+ elif isinstance(elem, tuple) or isinstance(elem, list) or isinstance(elem, set):
+ for e in elem:
+ nbytes += compute_size_in_bytes(e)
+ return nbytes
+
+
+@dataclass
+class MetaInfo:
+ r"""
+ The base class to store all profiling and static graph analysis information
+ needed for auto-parallel system in Colossal-AI.
+ ============================================================================
+ -------------------------------
+ | FX.Node | <-----
+ [input/param] are ---> |[input/param] [grad_inp]| [grad_inp] contributes to the
+ placeholders (might be | | \__________ | | profiled peak memory in backward
+ saved for backward. | | \ | | pass. [grad_param] is calculated
+ | | \ | | separately.
+ | [interm] -------> [grad_int]| <-----
+ | | \_________ | | [grad_interm] marks the peak
+ | / \ \ | | memory in backward pass.
+ [x] is not counted ---> | [x] [interm] --> [grad_int]| <-----
+ in [interm] because | | \_____ | |
+ it is not saved for | | \ | |
+ backward. | [output] \ | | <----- [output] is potentially
+ ------------------------------- [input] for the next node.
+ ============================================================================
+
+ Accumulate Size = ALL_PREVIOUS_CTX U {Interm Size + Output Size}
+ Output Size = ([output] in global_ctx and not is_alias)
+ Temp Size = ([output] not in global_ctx and not is_alias)
+ Backward Size = ([grad_inp])
+
+ Usage:
+ >>> for node in graph.nodes:
+ >>> n_info = MetaInfo(node) # will create a new MetaInfo instance and store in node.meta['info']
+ >>> # if not exist, otherwise return the existing one
+ >>> n_info.to_recompute = ... # set the to_recompute attribute
+
+ Remarks:
+ This feature is experimental and all the entries are subject to change.
+ """
+
+ # reference
+ node: Node
+
+ # directory
+ mod_dir: str = ''
+
+ # ctx[data_ptr] = Tensor
+ # mark the storage for ctx.save_for_backward
+ global_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # globally shared
+ curr_ctx: Dict[str, torch.Tensor] = field(default_factory=lambda: {}) # global_ctx till this node
+
+ # should be updated after each graph manipulation
+ # ============================== Update ====================================
+ # parameter and buffer within ``Node``
+ parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {})
+ buffers: Dict[str, torch.Tensor] = field(default_factory=lambda: {})
+
+ inputs: Tuple[torch.Tensor] = ()
+ outputs: Tuple[torch.Tensor] = ()
+ is_alias: Tuple[bool] = () # whether the output is an alias of input
+
+ # compute cost
+ fwd_flop: Optional[int] = 0
+ bwd_flop: Optional[int] = 0
+
+ # communication cost (should be the size in bytes of communication)
+ fwd_comm: Optional[int] = 0
+ bwd_comm: Optional[int] = 0
+
+ # should keep the same whenever manipulated
+ # ============================= Invariant ==================================
+ to_recompute: Tuple[torch.Tensor] = () # (region_0, region_1, ...) support nested codegen
+ to_offload: Optional[bool] = False
+ sharding_spec: str = 'RR'
+
+ def __new__(cls, node: Node, **kwargs):
+ orig_init = cls.__init__
+
+ # if initialized, return the existing one
+ # should disable the __init__ function
+ if node.meta.get('info', None) is not None:
+
+ def _dummy(self, *args, **kwargs):
+ if getattr(self, '_is_init', False):
+ self._is_init = True
+ orig_init(self, *args, **kwargs)
+ cls.__init__ = orig_init
+
+ cls.__init__ = _dummy
+ return node.meta['info']
+ return super().__new__(cls)
+
+ def __post_init__(self):
+ self.node.meta['info'] = self
+
+ @property
+ def fwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
+ return self.fwd_flop / tflops + self.fwd_comm / bandwidth
+
+ @property
+ def bwd_time(self, tflops: float = MeshConfig.TFLOPS, bandwidth: float = MeshConfig.BANDWIDTH):
+ return self.bwd_flop / tflops + self.bwd_comm / bandwidth
+
+ @property
+ def param_size(self):
+ return compute_size_in_bytes(self.parameters)
+
+ @property
+ def buffer_size(self):
+ return compute_size_in_bytes(self.buffers)
+
+ @property
+ def output_size(self):
+ """Used in CheckpointSolver"""
+ output_ctx = {
+ o.data_ptr(): o
+ for o, is_alias in zip(self.outputs, self.is_alias)
+ if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
+ }
+ return compute_size_in_bytes(intersect(self.global_ctx, output_ctx))
+
+ @property
+ def accumulate_size(self):
+ """Used in CheckpointSolver"""
+ output_ctx = {
+ o.data_ptr(): o
+ for o, is_alias in zip(self.outputs, self.is_alias)
+ if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
+ }
+ return compute_size_in_bytes(union(self.curr_ctx, intersect(self.global_ctx, output_ctx)))
+
+ @property
+ def temp_size(self):
+ """Used in CheckpointSolver"""
+ output_ctx = {
+ o.data_ptr(): o
+ for o, is_alias in zip(self.outputs, self.is_alias)
+ if not is_alias and isinstance(o, torch.Tensor) and not isinstance(o, torch.nn.Parameter)
+ }
+ return compute_size_in_bytes(subtract(output_ctx, self.global_ctx))
+
+ @property
+ def backward_size(self):
+ """Used in CheckpointSolver"""
+ return compute_size_in_bytes(self.inputs)
+
+ def __repr__(self):
+ s = f'Node {self.node.name}'
+ if self.parameters:
+ s += f'\n\thas parameter of size {_format_memory(self.param_size)}'
+ if self.buffers:
+ s += f'\n\thas buffer of size {_format_memory(self.buffer_size)}'
+ if self.output_size:
+ s += f'\n\thas output activation of size {_format_memory(self.output_size)}'
+ # if self.total_size:
+ # s += f'\n\thas total activation of size {_format_memory(self.total_size)}'
+ if self.temp_size:
+ s += f'\n\thas temp activation of size {_format_memory(self.temp_size)}'
+ if self.backward_size:
+ s += f'\n\thas backward activation of size {_format_memory(self.backward_size)}'
+ s += f'\n\tfwd_flop = {self.fwd_flop}'\
+ f'\n\tbwd_flop = {self.bwd_flop}'\
+ f'\n\tfwd_comm = {self.fwd_comm}'\
+ f'\n\tbwd_comm = {self.bwd_comm}'\
+ f'\n\tto_recompute = {self.to_recompute}'\
+ f'\n\tto_offload = {self.to_offload}'\
+ f'\n\tsharding_spec = {self.sharding_spec}'
+ return s
diff --git a/colossalai/_analyzer/fx/passes/__init__.py b/colossalai/_analyzer/fx/passes/__init__.py
new file mode 100644
index 000000000000..ae02d90a236c
--- /dev/null
+++ b/colossalai/_analyzer/fx/passes/__init__.py
@@ -0,0 +1,2 @@
+from .graph_profile import graph_profile_pass
+from .shape_prop import ShapeProp, shape_prop_pass, sim_env
diff --git a/colossalai/_analyzer/fx/passes/graph_profile.py b/colossalai/_analyzer/fx/passes/graph_profile.py
new file mode 100644
index 000000000000..c3e760b31e96
--- /dev/null
+++ b/colossalai/_analyzer/fx/passes/graph_profile.py
@@ -0,0 +1,347 @@
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+import torch.fx
+from torch.autograd.profiler_util import _format_memory, _format_time
+from torch.fx import GraphModule
+from torch.fx.node import Argument, Node, Target
+
+from colossalai._analyzer._subclasses import flop_count
+from colossalai._analyzer.fx.node_util import MetaInfo
+
+
+def _format_flops(flops: float) -> str:
+ """Returns a formatted FLOP size string"""
+ if flops > 1e12:
+ return f'{flops / 1e12:.2f} TFLOPs'
+ elif flops > 1e9:
+ return f'{flops / 1e9:.2f} GFLOPs'
+ elif flops > 1e6:
+ return f'{flops / 1e6:.2f} MFLOPs'
+ elif flops > 1e3:
+ return f'{flops / 1e3:.2f} kFLOPs'
+ return f'{flops} FLOPs'
+
+
+def _denormalize_tuple(t: Tuple[int, ...]) -> Tuple[int, ...]:
+ return t[0] if len(t) == 1 else t
+
+
+def _normalize_tuple(x):
+ if not isinstance(x, tuple):
+ return (x,)
+ return x
+
+
+def _current_device(module):
+ return next(module.parameters()).device
+
+
+class GraphProfiler(torch.fx.Interpreter):
+ """
+ Fetch shape argument from ``ShapeProp`` without re-executing
+ the ``GraphModule`` from scratch.
+ """
+ _profileable = [
+ 'call_function',
+ 'call_module',
+ 'call_method',
+ ]
+
+ def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
+ super().__init__(module, garbage_collect_values)
+
+ def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
+ """
+ Run `module` via interpretation and return the result.
+
+ Args:
+ *args: The arguments to the Module to run, in positional order
+ initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
+ This is a dict mapping `Node` to any value. This can be used, for example, to
+ pre-populate results for certain `Nodes` so as to do only partial evaluation within
+ the interpreter.
+ enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
+ process_outputs function first before using them.
+
+ Returns:
+ Any: The value returned from executing the Module
+ """
+ self.env = initial_env if initial_env else {}
+
+ # Positional function args are consumed left-to-right by
+ # `placeholder` nodes. Use an iterator to keep track of
+ # position and extract those values.
+ if enable_io_processing:
+ args = self.module.graph.process_inputs(*args)
+ self.args_iter: Iterator[Any] = iter(args)
+
+ for node in self.module.graph.nodes:
+
+ self.run_node(node) # No need to store.
+
+ if self.garbage_collect_values:
+ for to_delete in self.user_to_last_uses.get(node, []):
+ del self.env[to_delete]
+
+ if node.op == 'output':
+ output_val = self.env[node]
+ return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
+
+ def fetch_initial_env(self, device=None) -> Dict[Node, Any]:
+ """
+ Fetch ``initial_env`` for execution. This is because ``ShapeProp``
+ has already attached outputs of each ``Node`` to its ``MetaInfo``.
+
+ Args:
+ device (torch.device): The device to place the execution, default to ``None``
+
+ Returns:
+ Dict[Node, Any]: The initial environment for execution
+ """
+ initial_env = {}
+ for n in self.module.graph.nodes:
+ initial_env[n] = _denormalize_tuple(MetaInfo(n).outputs)
+ return initial_env
+
+ def propagate(self, *args, device=None):
+ """
+ Run `module` via interpretation and profile the execution
+ of each ``Node``.
+
+ Args:
+ *args (Tensor): The sample input, not used
+ device (torch.device): The device to place the execution, default to ``None``
+
+ Returns:
+ Any: The value returned from executing the Module
+ """
+ initial_env = self.fetch_initial_env(device)
+
+ return self.run(initial_env=initial_env)
+
+ def summary(self) -> str:
+ """
+ Summarizes the profiled statistics of the `GraphModule` in
+ tabular format. Note that this API requires the ``tabulate`` module
+ to be installed.
+
+ Returns:
+ str: The summary of the profiled statistics
+ """
+ # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
+ try:
+ from tabulate import tabulate
+ except ImportError:
+ print("`summary` relies on the library `tabulate`, "
+ "which could not be found on this machine. Run `pip "
+ "install tabulate` to install the library.")
+
+ # Build up a list of summary information for each node
+ node_summaries: List[List[Any]] = []
+ last_n_info = None
+
+ for node in self.module.graph.nodes:
+ node: Node
+ n_info = MetaInfo(node)
+ last_n_info = last_n_info or n_info
+ node_summaries.append([
+ node.op,
+ str(node),
+ _format_memory(n_info.accumulate_size),
+ _format_memory(n_info.accumulate_size - last_n_info.accumulate_size),
+ _format_memory(n_info.output_size),
+ _format_memory(n_info.temp_size),
+ _format_memory(n_info.param_size),
+ _format_memory(n_info.backward_size),
+ _format_flops(n_info.fwd_flop),
+ _format_flops(n_info.bwd_flop),
+ ])
+ last_n_info = n_info
+
+ # Use the ``tabulate`` library to create a well-formatted table
+ # presenting our summary information
+ headers: List[str] = [
+ 'Op type',
+ 'Op',
+ 'Accumulate size',
+ 'Incremental size',
+ 'Output size',
+ 'Temp size',
+ 'Param size',
+ 'Backward size',
+ 'Fwd FLOPs',
+ 'Bwd FLOPs',
+ ]
+
+ return tabulate(node_summaries, headers=headers, stralign='right')
+
+
+class CommunicationProfiler(GraphProfiler):
+ """
+ TODO(lyl): Add this for all comm nodes
+ """
+
+ def __init__(self, module: GraphModule, garbage_collect_values: bool = True):
+ raise NotImplementedError()
+
+
+class FlopProfiler(GraphProfiler):
+ """
+ Execute an FX graph Node-by-Node and record the meta data of the result
+ into the corresponding node.
+
+ Usage:
+ >>> model = MyModule()
+ >>> x = torch.rand(10, 10)
+ >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x}})
+ >>> shape_interp = ShapeProp(gm) # must do this first
+ >>> shape_interp.propagate(x)
+ >>> profiler = FlopProfiler(gm)
+ >>> profiler.propagate(x)
+
+ Args:
+ module (GraphModule): The module to be executed
+
+ Hints:
+ If you want to add a new flop count rule, you can first
+ check the existing files in ``../_subclasses/flop_tensor.py``.
+ If your flop count rules are incompatible with the existing
+ ones, you can do so by adding a new method to this class
+ with the ``@register_flop_count_impl`` decorator. The method
+ should take (*args, **kwargs) instance as its input and
+ generate flop count for both forward and backward as its
+ output.
+
+ For example, if you want to add a flop count rule for
+ ``my_fn``, which is a hand-written operand not detected by
+ PyTorch, you can do so by adding a new method to this
+ class with the ``@register_flop_count_impl`` decorator:
+
+ >>> @register_flop_count_impl(my_fn)
+ >>> def my_fn_flop_count_impl(*args, **kwargs):
+ >>> return 0, 0
+ """
+ _custom_flop_count_impl = {}
+
+ def run_node(self, n: torch.fx.Node) -> Any:
+ """
+ Run a specific node ``n`` and profile its execution time and memory usage.
+ Calls into call_function, call_method, and call_module only.
+
+ Args:
+ n (Node): The Node to profile
+
+ Returns:
+ Any: The output of the node
+
+ Raises:
+ RuntimeError: If the node is not profileable.
+ """
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ n_info = MetaInfo(n)
+
+ if n.op in self._profileable:
+ try:
+ (
+ n_info.fwd_flop,
+ n_info.bwd_flop,
+ ) = getattr(self, n.op)(n.target, args, kwargs)
+ except Exception as e:
+ raise RuntimeError(
+ f'Error {str(e)} occurred when profiling node {n}, node.target = {n.target}. '
+ f'Please refer to function\'s docstring to register the relevant profile_impl for this node!'
+ ) from e
+
+ # retain the autograd graph
+ for param in self.module.parameters():
+ param.grad = None
+
+ return _denormalize_tuple(n_info.outputs)
+
+ def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ """
+ Execute a ``call_function`` node and return the profiling result.
+ Dispatch to ``_custom_flop_count_impl`` if ``call_function`` should be
+ profiled in a user-defined behavior.
+
+ Args:
+ target (Target): The call target for this node. See
+ `Node `__ for
+ details on semantics
+ args (Tuple): Tuple of positional args for this invocation
+ kwargs (Dict): Dict of keyword arguments for this invocation
+
+ Return
+ flop_count (Tuple[int]): (fwd_flop, bwd_flop)
+ """
+ assert not isinstance(target, str)
+
+ # Dispatch the impl for profiling, default will be ``flop_count``
+ if target in self._custom_flop_count_impl:
+ return self._custom_flop_count_impl[target](*args, **kwargs)
+ else:
+ return flop_count(target, *args, **kwargs)
+
+ def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ """
+ Execute a ``call_method`` node and return the profiling result.
+
+ Args:
+ target (Target): The call target for this node. See
+ `Node `__ for
+ details on semantics
+ args (Tuple): Tuple of positional args for this invocation
+ kwargs (Dict): Dict of keyword arguments for this invocation
+
+ Return
+ flop_count (Tuple[int]): (fwd_flop, bwd_flop)
+ """
+ # Execute the method and return the result
+ assert isinstance(target, str)
+ return flop_count(getattr(torch.Tensor, target), *args, **kwargs)
+
+ def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
+ """
+ Execute a ``call_module`` node and return the profiling result.
+
+ Args:
+ target (Target): The call target for this node. See
+ `Node `__ for
+ details on semantics
+ args (Tuple): Tuple of positional args for this invocation
+ kwargs (Dict): Dict of keyword arguments for this invocation
+
+ Return
+ flop_count (Tuple[int]): (fwd_flop, bwd_flop)
+ """
+ # Retrieve executed args and kwargs values from the environment
+
+ # Execute the method and return the result
+ assert isinstance(target, str)
+ submod = self.fetch_attr(target)
+ return flop_count(submod, *args, **kwargs)
+
+
+def graph_profile_pass(module: GraphModule, *args, verbose=False) -> GraphModule:
+ """
+ Run ``module`` via interpretation and profile the execution
+ of each ``Node``.
+
+ Args:
+ module (GraphModule): The GraphModule to profile
+ *args (Any): The sample input, not used
+ verbose (bool): Whether to print the profiling summary
+
+ Returns:
+ GraphModule: The same GraphModule with profiling information
+ """
+ for profiler_cls in (FlopProfiler,
+ # CommunicationProfiler, # TODO: add communication profiling
+ ):
+ profiler = profiler_cls(module)
+ profiler.propagate(*args, device=_current_device(module))
+
+ if verbose:
+ print(profiler.summary())
+ return module
diff --git a/colossalai/_analyzer/fx/passes/shape_prop.py b/colossalai/_analyzer/fx/passes/shape_prop.py
new file mode 100644
index 000000000000..ab3e1a4d6a3d
--- /dev/null
+++ b/colossalai/_analyzer/fx/passes/shape_prop.py
@@ -0,0 +1,211 @@
+"""``torch.fx.ShapeProp``, but with ``MetaTensor``"""
+
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+import torch.fx
+from torch.autograd.graph import saved_tensors_hooks
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
+from colossalai._analyzer.fx.node_util import MetaInfo
+from colossalai.fx._compatibility import compatibility
+
+Target = Union[Callable[..., Any], str]
+
+
+class sim_env(saved_tensors_hooks):
+ """
+ A simulation of memory allocation and deallocation in the forward pass
+ using ``saved_tensor_hooks``.
+
+ Attributes:
+ ctx (Dict[int, torch.Tensor]): A dictionary that maps the
+ data pointer of a tensor to the tensor itself. This is used
+ to track the memory allocation and deallocation.
+
+ param_ctx (Dict[int, torch.Tensor]): A dictionary that maps the
+ data pointer of all model parameters to the parameter itself.
+ This avoids overestimating the memory usage of the intermediate activations.
+ """
+
+ def __init__(self, module: Optional[torch.nn.Module] = None):
+ super().__init__(self.pack_hook, self.unpack_hook)
+ self.ctx = {}
+ self.param_ctx = {param.data_ptr(): param for param in module.parameters()}
+ self.buffer_ctx = {buffer.data_ptr(): buffer for buffer in module.buffers()} if module else {}
+
+ def pack_hook(self, tensor: torch.Tensor):
+ if tensor.data_ptr() not in self.param_ctx and tensor.data_ptr() not in self.buffer_ctx:
+ self.ctx[tensor.data_ptr()] = tensor
+ return tensor
+
+ def unpack_hook(self, tensor):
+ return tensor
+
+
+def _normalize_tuple(x):
+ if not isinstance(x, tuple):
+ return (x,)
+ return x
+
+
+def _current_device(module):
+ return next(module.parameters()).device
+
+
+@compatibility(is_backward_compatible=False)
+class ShapeProp(torch.fx.Interpreter):
+ """
+ Execute an FX graph Node-by-Node and record the meta data of the result
+ into the corresponding node.
+
+ Usage:
+ >>> model = MyModule()
+ >>> x = torch.rand(10, 10)
+ >>> gm = colossalai.fx.symbolic_trace(model, meta_args = {'x': x})
+ >>> interp = ShapeProp(gm)
+ >>> interp.propagate(x)
+
+ Args:
+ module (GraphModule): The module to be executed
+
+ Hints:
+ If you want to add a new shape propagation rule, you can do so by
+ adding a new method to this class with the ``@register_shape_impl``
+ decorator. The method should take (*args, **kwargs) instance as its
+ input and generate output.
+
+ For example, if you want to add a shape propagation rule for
+ ``torch.nn.functional.linear``, you can do so by adding a new method
+ to this class with the ``@register_shape_impl`` decorator (Since the
+ ``MetaTensorMode`` is compatible with ``torch.nn.functional.linear``,
+ in practice you don't have to do as follows):
+
+ >>> @register_shape_impl(torch.nn.functional.linear)
+ >>> def linear_shape_impl(*args, **kwargs):
+ >>> # do something here
+ >>> return torch.empty(output_shape, device=output_device)
+ """
+ _custom_dispatch_func = {}
+ _mode = MetaTensorMode()
+
+ def __init__(self, module: torch.fx.GraphModule, garbage_collect_values: bool = True):
+ super().__init__(module, garbage_collect_values)
+ self.global_hook = sim_env(module=self.module)
+
+ def run_node(self, n: torch.fx.Node) -> Any:
+ """
+ Run a specific node ``n`` and return the result. Attach
+ (
+ ``inputs``, ``outputs``, ``parameters``, ``buffers``
+ ) to ``n``.
+
+ Args:
+ n (Node): The ``Node`` to execute
+
+ Returns:
+ Any: The result of executing ``n``
+ """
+ args, kwargs = self.fetch_args_kwargs_from_env(n)
+ with self.global_hook:
+ r = getattr(self, n.op)(n.target, args, kwargs)
+
+ def unwrap_fn(elem):
+
+ def _convert_meta(t: torch.Tensor):
+ if t.device == 'meta':
+ return t
+ else:
+ return t.to('meta')
+
+ if isinstance(elem, MetaTensor):
+ return _convert_meta(elem._tensor)
+
+ elif isinstance(elem, torch.Tensor):
+ return _convert_meta(elem)
+
+ else:
+ return elem
+
+ # unwrap_fn = lambda elem: elem._tensor if isinstance(elem, MetaTensor) else elem
+ is_pure_tensor = lambda elem: isinstance(elem, MetaTensor) and not isinstance(elem, torch.nn.Parameter)
+ n_info = MetaInfo(n)
+ n_info.outputs = _normalize_tuple(r)
+
+ if n.op == 'call_module':
+ submod = self.fetch_attr(n.target)
+ n_info.parameters.update({k: MetaTensor(v) for k, v in submod.named_parameters()})
+ n_info.buffers.update({k: MetaTensor(v) for k, v in submod.named_buffers()})
+
+ else:
+ n_info.parameters.update({
+ k.name: MetaTensor(v)
+ for k, v in zip(n.args, args)
+ if isinstance(k, torch.fx.Node) and isinstance(v, torch.nn.Parameter)
+ })
+ n_info.parameters.update({k: MetaTensor(v) for k, v in kwargs.items() if isinstance(v, torch.nn.Parameter)})
+
+ n_info.inputs = tuple(v for v in args if is_pure_tensor(v)) + \
+ tuple(v for v in kwargs.values() if is_pure_tensor(v))
+
+ n._meta_data = tree_map(unwrap_fn, _normalize_tuple(r)) # align with SPMD
+
+ n_info.global_ctx = self.global_hook.ctx
+ n_info.curr_ctx = self.global_hook.ctx.copy()
+
+ crit = lambda x: x.data_ptr() in self.global_hook.ctx if isinstance(x, torch.Tensor) else False
+ n_info.is_alias = _normalize_tuple(tree_map(crit, n_info.outputs))
+ return r
+
+ def call_function(self, target: 'Target', args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ """
+ Execute a ``call_function`` node and return the result.
+ If the target of ``Node`` is registered with ``@register_shape_impl``,
+ the registered function will be used to execute the node. This is common
+ if we insert some customized kernels.
+
+ Args:
+ target (Target): The call target for this node. See
+ `Node `__ for
+ details on semantics
+ args (Tuple): Tuple of positional args for this invocation
+ kwargs (Dict): Dict of keyword arguments for this invocation
+
+ Return
+ Any: The value returned by the function invocation
+ """
+ if target in self._custom_dispatch_func:
+ return self._custom_dispatch_func[target](*args, **kwargs)
+ else:
+ return super().call_function(target, args, kwargs)
+
+ def propagate(self, *args, device=None):
+ """
+ Run `module` via interpretation and return the result and record the
+ shape of each node.
+ Args:
+ *args (Tensor): The sample input.
+ Returns:
+ Any: The value returned from executing the Module
+ """
+ wrap_fn = lambda elem: MetaTensor(elem, device=device)
+ with self._mode:
+ return super().run(*tree_map(wrap_fn, args))
+
+
+def shape_prop_pass(module: torch.fx.GraphModule, *args) -> torch.fx.GraphModule:
+ """
+ Run ``module`` via interpretation and return the result and record the
+ shape of each ``Node``.
+
+ Args:
+ module (GraphModule): The GraphModule to profile
+ *args (Any): The sample input
+
+ Returns:
+ GraphModule: The same GraphModule with shape information
+ """
+
+ ShapeProp(module).propagate(*args, device=_current_device(module))
+ return module
diff --git a/colossalai/_analyzer/fx/symbolic_profile.py b/colossalai/_analyzer/fx/symbolic_profile.py
new file mode 100644
index 000000000000..dd7f22c6c98a
--- /dev/null
+++ b/colossalai/_analyzer/fx/symbolic_profile.py
@@ -0,0 +1,40 @@
+import torch
+import torch.fx
+from torch.fx import GraphModule
+
+from .passes import ShapeProp, graph_profile_pass, shape_prop_pass
+from .passes.graph_profile import FlopProfiler
+
+
+def register_flop_count_impl(func):
+
+ def wrapper(impl):
+ FlopProfiler._custom_flop_count_impl[func] = impl
+ return impl
+
+ return wrapper
+
+
+def register_shape_impl(func):
+
+ def wrapper(impl):
+ ShapeProp._custom_dispatch_func[func] = impl
+ return impl
+
+ return wrapper
+
+
+def symbolic_profile(module: GraphModule, *args, verbose=False) -> GraphModule:
+ """Symbolically profile a model with sample inputs.
+
+ Args:
+ module (GraphModule): The module to be profiled
+ args (Tuple): The sample inputs
+ verbose (bool): Whether to print the profiling result
+
+ Returns:
+ GraphModule: The profiled module
+ """
+ module = shape_prop_pass(module, *args)
+ module = graph_profile_pass(module, *args, verbose=verbose)
+ return module
diff --git a/colossalai/_analyzer/fx/tracer/__init__.py b/colossalai/_analyzer/fx/tracer/__init__.py
new file mode 100644
index 000000000000..6b1b2256aa44
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/__init__.py
@@ -0,0 +1,2 @@
+from .bias_addition import *
+from .custom_leaf_module import *
diff --git a/colossalai/_analyzer/fx/tracer/bias_addition.py b/colossalai/_analyzer/fx/tracer/bias_addition.py
new file mode 100644
index 000000000000..1e75b47ca5b0
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/bias_addition.py
@@ -0,0 +1,154 @@
+"""
+If FX.Graph is traced for auto-parallel module, some extra node will be added during
+graph construction to deal with the compatibility between bias-addition and all-reduce.
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair, _single, _triple
+
+from .tracer import register_tracer_impl
+
+__all__ = []
+
+
+@register_tracer_impl(F.linear, name='_bias_addition_impl')
+def linear_impl(input, weight, bias=None):
+ if bias is None:
+ return F.linear(input, weight)
+ else:
+ return F.linear(input, weight) + bias
+
+
+@register_tracer_impl(F.conv1d, name='_bias_addition_impl')
+def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
+ if bias is None:
+ return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ else:
+ return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+ (-1, 1))
+
+
+@register_tracer_impl(F.conv2d, name='_bias_addition_impl')
+def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
+ if bias is None:
+ return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ else:
+ return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+ (-1, 1, 1))
+
+
+@register_tracer_impl(F.conv3d, name='_bias_addition_impl')
+def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
+ if bias is None:
+ return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ else:
+ return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+ (-1, 1, 1, 1))
+
+
+@register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
+def conv_transpose1d_impl(input,
+ weight,
+ bias=None,
+ stride=_single(1),
+ padding=_single(0),
+ output_padding=_single(0),
+ groups=1,
+ dilation=_single(1)):
+ if bias is None:
+ return F.conv_transpose1d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation)
+ else:
+ return F.conv_transpose1d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation) + bias.reshape((-1, 1))
+
+
+@register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
+def conv_transpose2d_impl(input,
+ weight,
+ bias=None,
+ stride=_pair(1),
+ padding=_pair(0),
+ output_padding=_pair(0),
+ groups=1,
+ dilation=_pair(1)):
+ if bias is None:
+ return F.conv_transpose2d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation)
+ else:
+ return F.conv_transpose2d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation) + bias.reshape((-1, 1, 1))
+
+
+@register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
+def conv_transpose3d_impl(input,
+ weight,
+ bias=None,
+ stride=_triple(1),
+ padding=_triple(0),
+ output_padding=_triple(0),
+ groups=1,
+ dilation=_triple(1)):
+ if bias is None:
+ return F.conv_transpose3d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation)
+ else:
+ return F.conv_transpose3d(input,
+ weight,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups,
+ dilation=dilation) + bias.reshape((-1, 1, 1, 1))
+
+
+@register_tracer_impl(torch.addmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.Tensor.addmm, name='_bias_addition_impl')
+def addmm_impl(input, mat1, mat2, beta=1, alpha=1):
+ if alpha != 1 and beta != 1:
+ return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input * beta
+ elif alpha != 1:
+ return F.linear(mat1, mat2.transpose(0, 1)) * alpha + input
+ elif beta != 1:
+ return F.linear(mat1, mat2.transpose(0, 1)) + input * beta
+ else:
+ return F.linear(mat1, mat2.transpose(0, 1)) + input
+
+
+@register_tracer_impl(torch.addbmm, name='_bias_addition_impl')
+@register_tracer_impl(torch.Tensor.addbmm, name='_bias_addition_impl')
+def addbmm_impl(input, batch1, batch2, beta=1, alpha=1):
+ if alpha != 1 and beta != 1:
+ return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input * beta
+ elif alpha != 1:
+ return torch.bmm(batch1, batch2.transpose(1, 2)) * alpha + input
+ elif beta != 1:
+ return torch.bmm(batch1, batch2.transpose(1, 2)) + input * beta
+ else:
+ return torch.bmm(batch1, batch2.transpose(1, 2)) + input
diff --git a/colossalai/_analyzer/fx/tracer/custom_leaf_module.py b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
new file mode 100644
index 000000000000..112c7c9637d2
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/custom_leaf_module.py
@@ -0,0 +1,29 @@
+import torch
+
+from .tracer import register_leaf_module, register_leaf_module_impl
+
+try:
+ import apex
+ register_leaf_module(apex.normalization.FusedLayerNorm)
+ register_leaf_module(apex.normalization.FusedRMSNorm)
+ register_leaf_module(apex.normalization.MixedFusedLayerNorm)
+ register_leaf_module(apex.normalization.MixedFusedRMSNorm)
+
+ @register_leaf_module_impl(apex.normalization.FusedLayerNorm)
+ @register_leaf_module_impl(apex.normalization.FusedRMSNorm)
+ @register_leaf_module_impl(apex.normalization.MixedFusedLayerNorm)
+ @register_leaf_module_impl(apex.normalization.MixedFusedRMSNorm)
+ def torch_nn_normalize(self, input: torch.Tensor):
+ # check shape
+ if isinstance(self, torch.nn.BatchNorm1d):
+ assert input.dim() in [2, 3]
+ elif isinstance(self, torch.nn.BatchNorm2d):
+ assert input.dim() == 4
+ elif isinstance(self, torch.nn.BatchNorm3d):
+ assert input.dim() == 5
+
+ # normalization maintain the same shape as the input
+ return input.clone()
+
+except (ImportError, AttributeError):
+ pass
diff --git a/colossalai/_analyzer/fx/tracer/proxy.py b/colossalai/_analyzer/fx/tracer/proxy.py
new file mode 100644
index 000000000000..ce379efdcf0d
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/proxy.py
@@ -0,0 +1,112 @@
+import operator
+from typing import Any, Callable, Dict, Optional, Set, Union
+
+import torch
+import torch.nn as nn
+from torch.fx import Graph, Node, Proxy, Tracer
+from torch.fx.graph import _Namespace
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+Target = Union[Callable[..., Any], str]
+
+
+class ColoProxy(Proxy):
+ _func_dispatch: Dict[Target, Callable[..., Any]] = {}
+
+ def __init__(self, *args, data=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._meta_data = data
+
+ @property
+ def meta_data(self):
+ return self._meta_data
+
+ @meta_data.setter
+ def meta_data(self, args):
+ wrap_fn = lambda x: MetaTensor(x) if isinstance(x, torch.Tensor) else x
+ self._meta_data = tree_map(wrap_fn, args)
+
+ @classmethod
+ def __torch_function__(cls, orig_method, types, args=(), kwargs=None):
+ kwargs = {} if kwargs is None else kwargs
+ if orig_method in cls._func_dispatch:
+ impl = cls._func_dispatch.pop(orig_method) # avoid recursion
+ proxy = impl(*args, **kwargs)
+ cls._func_dispatch[orig_method] = impl
+ return proxy
+ else:
+ proxy = cls.from_torch_proxy(super().__torch_function__(orig_method, types, args, kwargs))
+ unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
+ if proxy.meta_data is None:
+ proxy.meta_data = orig_method(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+ return proxy
+
+ @classmethod
+ def from_torch_proxy(cls, proxy: Proxy):
+ return cls(proxy.node, proxy.tracer)
+
+ def __repr__(self):
+ return f"ColoProxy({self.node.name}, meta_data={self.meta_data})"
+
+ def __len__(self):
+ return len(self.meta_data)
+
+ def __int__(self):
+ return int(self.meta_data)
+
+ def __index__(self):
+ try:
+ return int(self.meta_data)
+ except:
+ return torch.zeros(self.meta_data.shape, dtype=torch.bool).numpy().__index__()
+
+ def __float__(self):
+ return float(self.meta_data)
+
+ def __bool__(self):
+ return self.meta_data
+
+ def __getattr__(self, k):
+ return ColoAttribute(self, k, getattr(self._meta_data, k, None))
+
+ def __setitem__(self, key, value):
+ proxy = self.tracer.create_proxy('call_function', operator.setitem, (self, key, value), {})
+ proxy.meta_data = self._meta_data
+ return proxy
+
+ def __contains__(self, key):
+ if self.node.op == "placeholder":
+ # this is used to handle like
+ # if x in kwargs
+ # we don't handle this case for now
+ return False
+ return super().__contains__(key)
+
+ def __isinstancecheck__(self, type):
+ return isinstance(self.meta_data, type)
+
+
+class ColoAttribute(ColoProxy):
+
+ def __init__(self, root, attr: str, data=None):
+ self.root = root
+ self.attr = attr
+ self.tracer = root.tracer
+ self._meta_data = data
+ self._node: Optional[Node] = None
+
+ @property
+ def node(self):
+ # the node for attributes is added lazily, since most will just be method calls
+ # which do not rely on the getitem call
+ if self._node is None:
+ self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
+ return self._node
+
+ def __call__(self, *args, **kwargs):
+ return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
+
+ def __repr__(self):
+ return f"ColoAttribute({self.node.name}, attr={self.attr})"
diff --git a/colossalai/_analyzer/fx/tracer/symbolic_trace.py b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
new file mode 100644
index 000000000000..2018863f6f5f
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/symbolic_trace.py
@@ -0,0 +1,157 @@
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+
+import torch
+from torch.fx import Tracer
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import MetaTensor
+
+try:
+ from ..codegen import ActivationCheckpointCodeGen
+ SUPPORT_ACTIVATION = True
+except:
+ SUPPORT_ACTIVATION = False
+from ..graph_module import ColoGraphModule
+from .tracer import ColoTracer
+
+
+def _default_device():
+ return torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
+
+
+def _current_device(module: torch.nn.Module):
+ try:
+ return next(module.parameters()).device
+ except:
+ return _default_device()
+
+
+def symbolic_trace(
+ root: Union[torch.nn.Module, Callable[..., Any]],
+ concrete_args: Optional[Dict[str, Any]] = None,
+ meta_args: Optional[Dict[str, Any]] = None,
+ trace_act_ckpt: bool = False,
+ bias_addition_split: bool = False,
+) -> ColoGraphModule:
+ """
+ Traces a ``torch.nn.Module`` or a function and returns a ``GraphModule`` with ``Node``s and ``MetaInfo``
+ attached to the ``Node``s.
+
+ Can be used to trace the usage of ``torch.utils.checkpoint`` and the path of module
+ (https://github.com/pytorch/examples/blob/main/fx/module_tracer.py).
+
+ This tracer is able to trace basic control flow and for loops.
+
+ It will split the bias addition into two parts if ``bias_addition_split`` is set to be ``True``.
+ (See ./bias_addition.py for more details).
+
+ Examples:
+ 1. Tracing a ``torch.nn.Module`` with control flow.
+
+ .. code-block:: python
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, x):
+ if x.size(0) > 1:
+ x = x.sum(dim=0)
+ return self.linear(x)
+
+ traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)})
+
+ # traced code like:
+ # def forward(self, x):
+ # linear_1 = self.linear(x)
+ # return linear_1
+
+ traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(2, 2, 2)})
+
+ # traced code like:
+ # def forward(self, x):
+ # sum = x.sum(dim=0); x = None
+ # linear = self.linear(sum); sum = None
+ # return linear
+
+ 2. Tracing a ``torch.nn.Module`` with ``torch.utils.checkpoint``.
+
+ .. code-block:: python
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(2, 2)
+
+ def forward(self, x):
+ def custom_forward(x):
+ return self.linear(x)
+ return torch.utils.checkpoint.checkpoint(custom_forward, x)
+
+ traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, trace_act_ckpt=True)
+
+ # traced code like:
+ # def checkpoint_0(self, x):
+ # linear = self.linear(x); x = None
+ # return linear
+ #
+ # def forward(self, x):
+ # linear = torch.utils.checkpoint.checkpoint(checkpoint_0, x); x = None
+ # return linear
+
+ 3. Tracing a ``torch.nn.Module`` with ``bias_addition_split``.
+
+ .. code-block:: python
+
+ class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(2, 2, bias=True)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ traced = symbolic_trace(MyModule(), meta_args={'x': torch.randn(1, 2, 2)}, bias_addition_split=True)
+
+ # traced code like:
+ # def forward(self, x):
+ # linear_bias = self.linear.bias
+ # linear_weight = self.linear.weight
+ # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
+ # add = linear + linear_bias; linear = linear_bias = None
+ # return add
+
+ Args:
+ root (Union[torch.nn.Module, Callable[..., Any]]): The ``torch.nn.Module`` or function to be traced.
+ concrete_args (Optional[Dict[str, Any]], optional): Concrete arguments to be passed to the ``root``.
+ Defaults to {}.
+ meta_args (Optional[Dict[str, Any]], optional): Meta arguments to be passed to the ``root``. Mostly used
+ for tracing control flow. Defaults to {}.
+ trace_act_ckpt (bool, optional): Whether to trace the usage of ``torch.utils.checkpoint``.
+ Defaults to False.
+ bias_addition_split (bool, optional): Whether to split the bias addition into two parts. Defaults to False.
+
+ Returns:
+ ColoGraphModule: A traced ``GraphModule`` that is ready for activation checkpoint ``CodeGen``.
+
+ Remarks:
+ This part of ``symbolic_trace()`` is maintained by Colossal-AI team. If you encountered
+ any unexpected error during tracing, feel free to raise an issue on Colossal-AI GitHub
+ repo. We welcome any feedback and contributions to enhance the extensibility of
+ Colossal-AI.
+ """
+ if meta_args:
+ device, orig_device = _default_device(), _current_device(root)
+ wrap_fn = lambda elem: MetaTensor(elem, device=device) if isinstance(elem, torch.Tensor) else elem
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt,
+ bias_addition_split=bias_addition_split).trace(root.to(device),
+ concrete_args=concrete_args,
+ meta_args=tree_map(wrap_fn, meta_args))
+ if trace_act_ckpt and SUPPORT_ACTIVATION:
+ graph.set_codegen(ActivationCheckpointCodeGen())
+ root.to(orig_device)
+ else:
+ graph = Tracer().trace(root, concrete_args=concrete_args)
+ name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
+ return ColoGraphModule(root, graph, name)
diff --git a/colossalai/_analyzer/fx/tracer/tracer.py b/colossalai/_analyzer/fx/tracer/tracer.py
new file mode 100644
index 000000000000..1a247449f3d8
--- /dev/null
+++ b/colossalai/_analyzer/fx/tracer/tracer.py
@@ -0,0 +1,363 @@
+import functools
+import inspect
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Iterable, Optional, Set, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+from torch.fx import Graph, Node, Proxy, Tracer
+from torch.utils._pytree import tree_map
+
+from colossalai._analyzer._subclasses import _TensorPropertyMethod, _TorchFactoryMethod
+
+from ..node_util import MetaInfo
+from .proxy import ColoProxy
+
+Target = Union[Callable[..., Any], str]
+
+
+def _truncate_suffix(s: str):
+ import re
+
+ # FIXME: don't know why but torch.fx always gets a suffix like '_1' in the name
+ return re.sub(r'_\d+$', '', s)
+
+
+def register_tracer_impl(func: Callable[..., Any], name: Optional[str] = '_custom_impl'):
+
+ def wrapper(impl):
+ assert hasattr(ColoTracer, name), f"Cannot register {func.__name__} in ColoTracer.{name}"
+ getattr(ColoTracer, name)[func] = impl
+ return impl
+
+ return wrapper
+
+
+def register_leaf_module_impl(module: nn.Module):
+
+ def wrapper(impl):
+ ColoTracer._custom_leaf_module_impl[module] = impl
+ return impl
+
+ return wrapper
+
+
+def register_leaf_module(module: nn.Module):
+ ColoTracer._custom_leaf_module.add(module)
+
+
+def register_non_leaf_module(module: nn.Module):
+ ColoTracer._custom_non_leaf_module.add(module)
+
+
+class ColoTracer(Tracer):
+ _custom_leaf_module: Set[Type[nn.Module]] = set()
+ _custom_leaf_module_impl: Dict[Type[nn.Module], Callable[..., Any]] = {}
+ _custom_non_leaf_module: Set[Type[nn.Module]] = set()
+ _custom_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
+ _bias_addition_impl: Dict[Callable[..., Any], Callable[..., Any]] = {}
+ _bias_addition_module = [
+ torch.nn.Linear,
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
+ ]
+
+ def __init__(self, trace_act_ckpt: bool = False, bias_addition_split: bool = False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.disable_module_getattr = False
+ self.proxy_buffer_attributes = True
+
+ # whether the tracer will record the usage of torch.utils.checkpoint
+ self.trace_act_ckpt = trace_act_ckpt
+ self.ckpt_regions = []
+ self.ckpt_idx = 0
+
+ self.mod_dir = ''
+
+ # whether the tracer should split the bias_add ops into two ops
+ self.bias_addition_split = bias_addition_split
+
+ def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool:
+ # if bias-addiction split is enabled, and module has bias, then it is not a leaf module
+ # we will enter the module and split the bias-addition ops
+ if self.bias_addition_split and type(m) in self._bias_addition_module and m.bias is not None:
+ return False
+ # user can specify which modules are leaf modules and which are not
+ return (type(m) not in self._custom_non_leaf_module
+ and (type(m) in self._custom_leaf_module or super().is_leaf_module(m, module_qualified_name)))
+
+ def call_module(self, m: torch.nn.Module, forward: Callable[..., Any], args: Tuple[Any, ...],
+ kwargs: Dict[str, Any]) -> Any:
+ curr_dir = self.mod_dir
+ self.mod_dir = 'self.' + self.path_of_module(m)
+ rst = super().call_module(m, forward, args, kwargs)
+ self.mod_dir = curr_dir
+ return rst
+
+ def proxy(self, node: Node) -> 'ColoProxy':
+ return ColoProxy(node, self)
+
+ def create_proxy(self,
+ kind: str,
+ target: Target,
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ name: Optional[str] = None,
+ type_expr: Optional[Any] = None,
+ proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
+
+ proxy: ColoProxy = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
+ unwrap_fn = lambda p: p.meta_data if isinstance(p, ColoProxy) else p
+ if kind == 'placeholder':
+ proxy.meta_data = self.meta_args[target] if target in self.meta_args else self.concrete_args.get(
+ _truncate_suffix(target), None)
+ elif kind == 'get_attr':
+ self.disable_module_getattr = True
+ try:
+ attr_itr = self.root
+ atoms = target.split(".")
+ for atom in atoms:
+ attr_itr = getattr(attr_itr, atom)
+ proxy.meta_data = attr_itr
+ finally:
+ self.disable_module_getattr = False
+ elif kind == 'call_function':
+ proxy.meta_data = target(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
+ elif kind == 'call_method':
+ self.disable_module_getattr = True
+ try:
+ if target == '__call__':
+ proxy.meta_data = unwrap_fn(args[0])(*tree_map(unwrap_fn, args[1:]), **tree_map(unwrap_fn, kwargs))
+ else:
+ if target not in _TensorPropertyMethod:
+ proxy._meta_data = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
+ **tree_map(unwrap_fn, kwargs))
+ finally:
+ self.disable_module_getattr = False
+ elif kind == 'call_module':
+ mod = self.root.get_submodule(target)
+ self.disable_module_getattr = True
+ try:
+ args = tree_map(unwrap_fn, args)
+ kwargs = tree_map(unwrap_fn, kwargs)
+ if type(mod) in self._custom_leaf_module:
+ target = self._custom_leaf_module_impl[type(mod)]
+ proxy.meta_data = target(mod, *args, **kwargs)
+ else:
+ proxy.meta_data = mod.forward(*args, **kwargs)
+ finally:
+ self.disable_module_getattr = False
+ return proxy
+
+ def create_node(self, *args, **kwargs) -> Node:
+ node = super().create_node(*args, **kwargs)
+ n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
+ return node
+
+ def trace(self,
+ root: torch.nn.Module,
+ concrete_args: Optional[Dict[str, torch.Tensor]] = None,
+ meta_args: Optional[Dict[str, torch.Tensor]] = None) -> Graph:
+
+ if meta_args is None:
+ meta_args = {}
+
+ if concrete_args is None:
+ concrete_args = {}
+
+ # check concrete and meta args have valid names
+ sig = inspect.signature(root.forward)
+ sig_names = set(sig.parameters.keys())
+ meta_arg_names = set(meta_args.keys())
+ concrete_arg_names = set(concrete_args.keys())
+ non_concrete_arg_names = sig_names - concrete_arg_names
+ # update concrete args with default values
+ for k, v in sig.parameters.items():
+ if k in sig_names - meta_arg_names and \
+ k not in concrete_args and \
+ v.default is not inspect.Parameter.empty:
+ concrete_args[k] = v.default
+
+ def _check_arg_name_valid(names: Iterable[str]):
+ for name in names:
+ if name not in sig_names:
+ raise ValueError(f"Argument {name} is not in the signature of {root.__class__.__name__}.forward")
+
+ _check_arg_name_valid(meta_arg_names)
+ _check_arg_name_valid(concrete_arg_names)
+
+ self.concrete_args = concrete_args
+ self.meta_args = meta_args
+
+ with self._torch_factory_override(), self._tracer_override(), torch.no_grad():
+ self.mod_dir = 'self'
+ self.graph = super().trace(root, concrete_args=concrete_args)
+ self.mod_dir = ''
+ self.graph.lint()
+
+ for node in self.graph.nodes:
+ if node.op == "placeholder":
+ # Removing default values for inputs as the forward pass will fail with them.
+ if node.target in non_concrete_arg_names:
+ node.args = ()
+ # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+ # It cannot infer on the attributes and methods the input should have, and fails.
+ node.type = torch.Tensor
+ # It is a concrete arg so it is not used and should be removed.
+ else:
+ if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+ # Newer versions of torch.fx emit an assert statement
+ # for concrete arguments; delete those before we delete
+ # the concrete arg.
+ to_delete = []
+ for user in node.users:
+ if user.target == torch.fx._symbolic_trace._assert_is_none:
+ to_delete.append(user)
+ for user in to_delete:
+ self.graph.erase_node(user)
+
+ self.graph.erase_node(node)
+
+ # TODO: solves GraphModule creation.
+ # Without this, return type annotation "Tuple" is causing code execution failure.
+ if node.op == "output":
+ node.type = None
+ return self.graph
+
+ @contextmanager
+ def _tracer_override(self):
+ # override the tracer to support custom modules and checkpointing
+ if self.trace_act_ckpt:
+ orig_ckpt_func_apply = torch.utils.checkpoint.CheckpointFunction.apply
+ orig_ckpt_func_without_reentrant = torch.utils.checkpoint._checkpoint_without_reentrant
+
+ def checkpoint(run_function, preserve_rng_state=False, *args):
+ self.ckpt_regions.append(self.ckpt_idx)
+ out = run_function(*args)
+ self.ckpt_idx = self.ckpt_regions.pop(-1) + 1
+ return out
+
+ # override the checkpoint function
+ torch.utils.checkpoint.CheckpointFunction.apply = checkpoint
+ torch.utils.checkpoint._checkpoint_without_reentrant = checkpoint
+
+ # override the custom functions
+ ColoProxy._func_dispatch.update({k: v for k, v in self._custom_impl.items()})
+
+ # override the bias addition functions
+ if self.bias_addition_split:
+ ColoProxy._func_dispatch.update({k: v for k, v in self._bias_addition_impl.items()})
+
+ yield
+
+ if self.trace_act_ckpt:
+ # recover the checkpoint function upon exit
+ torch.utils.checkpoint.CheckpointFunction.apply = orig_ckpt_func_apply
+ torch.utils.checkpoint._checkpoint_reentrant = orig_ckpt_func_without_reentrant
+
+ ColoProxy._func_dispatch = {}
+
+ @contextmanager
+ def _torch_factory_override(self):
+ # override the torch factory functions to create a proxy when the method
+ # is called during ``symbolic_trace()``.
+ def wrap_factory_method(target):
+
+ @functools.wraps(target)
+ def wrapper(*args, **kwargs):
+ is_proxy = any(isinstance(p, ColoProxy) for p in args) | any(
+ isinstance(p, ColoProxy) for p in kwargs.values())
+ if is_proxy:
+ # if the arg is a proxy, then need to record this function called on this proxy
+ # e.g. torch.ones(size) where size is an input proxy
+ self.disable_module_getattr = True
+ try:
+ proxy = self.create_proxy('call_function', target, args, kwargs)
+ finally:
+ self.disable_module_getattr = False
+ return proxy
+ else:
+ return target(*args, **kwargs)
+
+ return wrapper, target
+
+ overrides = {
+ target: wrap_factory_method(getattr(torch, target))
+ for target in _TorchFactoryMethod
+ if callable(getattr(torch, target))
+ }
+ for name, (wrapper, orig) in overrides.items():
+ setattr(torch, name, wrapper)
+
+ yield
+
+ # recover the torch factory functions upon exit
+ for name, (wrapper, orig) in overrides.items():
+ setattr(torch, name, orig)
+
+ def _post_check(self, non_concrete_arg_names: Set[str]):
+ # This is necessary because concrete args are added as input to the traced module since
+ # https://github.com/pytorch/pytorch/pull/55888.
+ for node in self.graph.nodes:
+ if node.op == "placeholder":
+ # Removing default values for inputs as the forward pass will fail with them.
+ if node.target in non_concrete_arg_names:
+ node.args = ()
+ # Without this, torch.jit.script fails because the inputs type is Optional[torch.Tensor].
+ # It cannot infer on the attributes and methods the input should have, and fails.
+ node.type = torch.Tensor
+ # It is a concrete arg so it is not used and should be removed.
+ else:
+ if hasattr(torch.fx._symbolic_trace, "_assert_is_none"):
+ # Newer versions of torch.fx emit an assert statement
+ # for concrete arguments; delete those before we delete
+ # the concrete arg.
+ to_delete = []
+ for user in node.users:
+ if user.target == torch.fx._symbolic_trace._assert_is_none:
+ to_delete.append(user)
+ for user in to_delete:
+ self.graph.erase_node(user)
+
+ self.graph.erase_node(node)
+
+ if node.op == "output":
+ node.type = None
+ self.graph.lint()
+
+ def getattr(self, attr, attr_val, parameter_proxy_cache):
+ return self._module_getattr(attr, attr_val, parameter_proxy_cache)
+
+ def _module_getattr(self, attr, attr_val, parameter_proxy_cache):
+ if getattr(self, "disable_module_getattr", False):
+ return attr_val
+
+ def maybe_get_proxy_for_attr(attr_val, collection_to_search, parameter_proxy_cache):
+ for n, p in collection_to_search:
+ if attr_val is p:
+ if n not in parameter_proxy_cache:
+ kwargs = {}
+ if 'proxy_factory_fn' in inspect.signature(self.create_proxy).parameters:
+ kwargs['proxy_factory_fn'] = (None if not self.param_shapes_constant else
+ lambda node: ColoProxy(self, node, n, attr_val))
+ val_proxy = self.create_proxy('get_attr', n, (), {}, **kwargs) # type: ignore[arg-type]
+ parameter_proxy_cache[n] = val_proxy
+ return parameter_proxy_cache[n]
+ return None
+
+ if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor):
+ maybe_buffer_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_buffers(), parameter_proxy_cache)
+ if maybe_buffer_proxy is not None:
+ return maybe_buffer_proxy
+
+ if isinstance(attr_val, torch.nn.Parameter):
+ maybe_parameter_proxy = maybe_get_proxy_for_attr(attr_val, self.root.named_parameters(),
+ parameter_proxy_cache)
+ if maybe_parameter_proxy is not None:
+ return maybe_parameter_proxy
+
+ return attr_val
diff --git a/colossalai/amp/apex_amp/apex_amp.py b/colossalai/amp/apex_amp/apex_amp.py
index 69a4e348e5a7..e6bdbe4520f9 100644
--- a/colossalai/amp/apex_amp/apex_amp.py
+++ b/colossalai/amp/apex_amp/apex_amp.py
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-
import torch.nn as nn
+
try:
import apex.amp as apex_amp
except ImportError:
diff --git a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
index 6d6f2f287e32..e899b9ca4c89 100644
--- a/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
+++ b/colossalai/amp/naive_amp/grad_scaler/dynamic_grad_scaler.py
@@ -58,10 +58,12 @@ def _sanity_checks(self) -> None:
if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
+ assert self._min_scale <= self._scale, 'The minimum gradient scale cannot be greater than the current scale'
if self._max_scale:
- assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+ assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
+ assert self._max_scale >= self._scale, 'The maximum gradient scale cannot be smaller than the current scale'
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
- assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1'
+ assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
def update(self, overflow: bool) -> None:
@@ -103,3 +105,17 @@ def _grow_scale(self) -> None:
self._scale = self._scale * self._growth_factor
if self._max_scale:
self._scale = torch.min(self._scale, self._max_scale)
+
+ def state_dict(self):
+ state_dict = dict()
+ state_dict['scale'] = self._scale
+ state_dict['growth_factor'] = self._growth_factor
+ state_dict['backoff_factor'] = self._backoff_factor
+ state_dict['hysteresis'] = self._hysteresis
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
+ self._growth_factor = state_dict['growth_factor']
+ self._backoff_factor = state_dict['backoff_factor']
+ self._hysteresis = state_dict['hysteresis']
diff --git a/colossalai/auto_parallel/README.md b/colossalai/auto_parallel/README.md
new file mode 100644
index 000000000000..8e47e1bb0b4a
--- /dev/null
+++ b/colossalai/auto_parallel/README.md
@@ -0,0 +1,23 @@
+# Colossal-AUTO
+
+## Challenges
+Recently, large models have achieved the state of the art performances in various fields. In order to support large model training, we have to use distributed training techniques. However, finding an efficient distributed execution plan not only requires fine-grained model statistics, such as memory and computing overhead of each operator but also is a labor-intensive task even for an expert in the field of distributed training.
+
+## Our solution
+To simplify the process of distributed training for foundational models, recent advancements in machine learning systems have led to the emergence of automatic parallel systems. We investigate and research a number of current automatic parallel systems( Tofu , Flexflow , Alpa ) and some auto activation checkpoint algorithms( Rotor , Sublinear ). Inspired from these advanced systems, we build an automatic parallel system upon PyTorch framework. The input of the system is the serial PyTorch code, and the output is a PyTorch program with an optimized distributed execution plan. It is worth emphasizing that the output is a regular PyTorch program, so it is compatible with runtime optimization methods, such as ZeRO-Offload and PatrickStar.
+
+## Key modules
+
+### Analyzer
+
+**Analyzer** is a static analysis system consisting of three parts:
+A *symbolic profiler* for collecting computing and memory overhead related to static computation graph, a *cluster detector* for collecting hardware characteristics and detecting cluster topology and a *tensor layout manager* to find efficient tensor layout conversion path from different sharding spec and record conversion cost.
+
+### Solver
+
+**Solver** is designed to find the optimal execution plan for a given computation graph and cluster in two stages:
+1) *Intra-op parallelism stage* is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelsim ILP solver.
+2) *Activation checkpoint stage* is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling.
+
+### Generator
+**Generator** applies the searched execution plan to the computation graph and recompiles the computation graph to optimized PyTorch code. It has *a series compile pass* to insert a communication node or do the kernel substitution as the intra-op parallelism solver required. Additionally, we implement a *code generation* feature to recognize the annotation from the activation checkpoint solver and inject the activation checkpoint block following annotation instructions.
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
index 0fdcfd58a399..8dad074bc894 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.c
@@ -1,6 +1,12 @@
#define PY_SSIZE_T_CLEAN
#include
+/*
+Rotor solver for checkpointing problem in C. We follow the modeling mentioned in
+paper `Optimal checkpointing for heterogeneous chains: how to train deep neural
+networks with limited memory` https://hal.inria.fr/hal-02352969. Some lines of
+the code are adapted from https://gitlab.inria.fr/hiepacs/rotor.
+*/
long* PySequenceToLongArray(PyObject* pylist) {
if (!(pylist && PySequence_Check(pylist))) return NULL;
Py_ssize_t len = PySequence_Size(pylist);
@@ -81,14 +87,16 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
(mmax + 1) * (chainLength + 1) * (chainLength + 1), sizeof(long));
for (long m = 0; m <= mmax; ++m)
- for (long i = 0; i <= chainLength; ++i)
+ for (long i = 0; i <= chainLength; ++i) {
if ((m >= x[i + 1] + xbar[i + 1] + btmp[i]) &&
- (m >= x[i + 1] + xbar[i + 1] + ftmp[i]))
+ (m >= x[i + 1] + xbar[i + 1] + ftmp[i])) {
COST_TABLE(m, i, i) = ftime[i] + btime[i];
- else
+ } else {
COST_TABLE(m, i, i) = INFINITY;
+ }
+ }
- for (long m = 0; m <= mmax; ++m)
+ for (long m = 0; m <= mmax; ++m) {
for (long d = 1; d <= chainLength; ++d) {
for (long i = 0; i <= chainLength - d; ++i) {
long idx = i + d;
@@ -116,9 +124,10 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
}
}
double chainCost = INFINITY;
- if (m >= xbar[i + 1])
+ if (m >= xbar[i + 1]) {
chainCost =
COST_TABLE(m, i, i) + COST_TABLE(m - xbar[i + 1], i + 1, idx);
+ }
if (bestLeafCost <= chainCost) {
COST_TABLE(m, i, idx) = bestLeafCost;
BACK_PTR(m, i, idx) = bestLeaf;
@@ -126,10 +135,12 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
COST_TABLE(m, i, idx) = chainCost;
BACK_PTR(m, i, idx) = -1;
}
- } else
+ } else {
COST_TABLE(m, i, idx) = INFINITY;
+ }
}
}
+ }
free(ftime);
free(btime);
@@ -158,10 +169,11 @@ static PyObject* computeTable(PyObject* self, PyObject* args) {
PyDict_SetItem(pyCostTable_m_i, pyVar_l, pyCostTable_m_i_l);
Py_DECREF(pyCostTable_m_i_l);
PyObject* pyBackPtr_m_i_l;
- if (BACK_PTR(m, i, l) < 0)
+ if (BACK_PTR(m, i, l) < 0) {
pyBackPtr_m_i_l = Py_BuildValue("(O)", Py_True);
- else
+ } else {
pyBackPtr_m_i_l = Py_BuildValue("(Ol)", Py_False, BACK_PTR(m, i, l));
+ }
PyDict_SetItem(pyBackPtr_m_i, pyVar_l, pyBackPtr_m_i_l);
Py_DECREF(pyBackPtr_m_i_l);
Py_DECREF(pyVar_l);
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index 41d23be5c952..21c3bf0da758 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -207,11 +207,10 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple:
mmax (int): Maximum number of memory slots.
Returns:
- cost_table (List): cost_table[m][lhs][rhs] with lhs = 0...chain.length
- and rhs = lhs...chain.length (lhs is not included) and m = 0...mmax
- back_ptr (List): back_ptr[m][lhs][rhs] is (True,) if the optimal choice
- is a chain checkpoint (False, j) if the optimal choice is a leaf checkpoint
- of length j
+ cost_table (List): cost_table[m][lhs][rhs] indicates the optimal cost of the subproblem from lhs to rhs
+ with m memory slots.
+ back_ptr (List): back_ptr[m][lhs][rhs] indicates the best operation at this point. It is (True,) if the optimal choice
+ is a chain checkpoint, it is (False, j) if the optimal choice is a leaf checkpoint of length j
"""
ftime = chain.ftime + [0.0]
@@ -224,18 +223,17 @@ def _compute_table(chain: Chain, mmax: int) -> Tuple:
# Build table
cost_table = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
back_ptr = [[{} for _ in range(len(chain) + 1)] for _ in range(mmax + 1)]
- # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
- # Initialize borders of the tables for lmax-lmin = 0
+ # Initialize corner cases where length of sequence equals to 1, i.e. lhs == rhs
for m in range(mmax + 1):
for i in range(len(chain) + 1):
limit = max(x[i + 1] + xbar[i + 1] + ftmp[i], x[i + 1] + xbar[i + 1] + btmp[i])
- if m >= limit: # Equation (1)
+ if m >= limit:
cost_table[m][i][i] = ftime[i] + btime[i]
else:
cost_table[m][i][i] = float("inf")
- # Compute everything
+ # Compute tables
for m in range(mmax + 1):
for d in range(1, len(chain) + 1):
for i in range(len(chain) + 1 - d):
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
index aa5f77f6591e..4d8b656e17e1 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/__init__.py
@@ -1,6 +1,10 @@
from .activation import *
from .binary_elementwise_ops import *
from .conv import *
+from .embedding import *
from .linear import *
+from .non_spmd import *
from .norm import *
from .pooling import *
+from .tensor import *
+from .where import *
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
index 774457f7d3b6..faeed9f29e61 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/activation.py
@@ -1,74 +1,85 @@
-from typing import List, Tuple
+from typing import Callable, List, Tuple
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
+from colossalai.fx.profiler.opcount import elementwise_flop_counter
from ..registry import meta_register
-__all__ = ["relu_meta_info"]
+__all__ = ["elementwise_meta_info"]
-@meta_register.register(torch.nn.ReLU)
-def relu_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
- """torch.nn.ReLU metainfo generator
- The aten graph of torch.nn.ReLU is
- graph():
- %input_2 : [#users=1] = placeholder[target=placeholder](default=)
- %relu_default : [#users=2] = call_function[target=torch.ops.aten.relu.default](args = (%input_2,), kwargs = {})
- %zeros_like_default : [#users=1] = call_function[target=torch.ops.aten.zeros_like.default](args = (%relu_default,), kwargs = {dtype: None, layout: None, device: None, pin_memory: None})
- %detach_default : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%relu_default,), kwargs = {})
- %threshold_backward_default : [#users=1] = call_function[target=torch.ops.aten.threshold_backward.default](args = (%zeros_like_default, %detach_default, None), kwargs = {})
- %detach_default_1 : [#users=1] = call_function[target=torch.ops.aten.detach.default](args = (%threshold_backward_default,), kwargs = {})
- %detach_default_2 : [#users=0] = call_function[target=torch.ops.aten.detach.default](args = (%detach_default_1,), kwargs = {})
+def elementwise_meta_info(temp_mem_scale: float = 0, buffer_mem_scale: float = 0) -> Callable:
+ """This is a function to create the meta information generator for elementwise operations
+
+ Args:
+ temp_mem_scale (float, optional): temp memory scaling factor for backward. Defaults to 0.
+ buffer_mem_scale (float, optional): buffer memory scaling factor for forward. Defaults to 0.
Returns:
- Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ Callable: meta information generator
"""
- input_tensor = args[0].data
- output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
- is_inplace = kwargs.get("inplace", False)
-
- # construct input args for forward
- fwd_in_args = [input_tensor]
-
- # construct input args for backward
- bwd_in_args = [output_tensor]
-
- # calculate cost
- # the fwd op with compute cost is relu.default
- # the bwd op with compute cost is threshold_backward
-
- # calculate compute cost
- fwd_compute_cost = flop_mapping[torch.ops.aten.relu.default](fwd_in_args, (output_tensor,))
- bwd_compute_cost = flop_mapping[torch.ops.aten.threshold_backward.default](bwd_in_args, (input_tensor,))
- compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
-
- # calculate memory cost
- # NOTE: the inplace ReLU don't have forward memory cost
- # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
- fwd_memory_cost = MemoryCost(
- activation=activation_size(input_tensor) if is_inplace else activation_size([output_tensor, input_tensor]),
- parameter=0,
- temp=0,
- buffer=0)
-
- bwd_memory_cost = MemoryCost(activation=activation_size(input_tensor), parameter=0, temp=0, buffer=0)
-
- # total cost is the sum of forward and backward cost
- total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
- parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter)
-
- memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
-
- # store fwd_in, fwd_buffer, fwd_out
- # NOTE: It might seems a little bit weird here, we just want to align it with the older version
- # of MetaInfoProp. In the future we might modify this part to make it clearer.
- fwd_in = []
- fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
- fwd_out = [torch.zeros_like(output_tensor, device='meta')]
-
- return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ input_tensor = next(
+ filter(
+ lambda x:
+ (x.type == OperationDataType.ARG or x.type == OperationDataType.PARAM) and x.name != 'softmax_dim',
+ args)).data
+ output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+ is_inplace = 1 if kwargs.get('inplace', False) else 0
+
+ flop_counter = elementwise_flop_counter(1, 0)
+ # calculate compute cost
+ fwd_compute_cost = flop_counter([input_tensor], [output_tensor])
+ bwd_compute_cost = flop_counter([output_tensor], [input_tensor])
+
+ compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
+ bwd=bwd_compute_cost,
+ total=fwd_compute_cost + bwd_compute_cost)
+
+ # calculate memory cost
+ # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+ # NOTE: if in_place is True, we will not create a new tensor in forward
+ fwd_memory_cost = MemoryCost(activation=activation_size(input_tensor) * (2 - is_inplace),
+ parameter=0,
+ temp=0,
+ buffer=activation_size(input_tensor) * buffer_mem_scale)
+
+ # temp_mem_scale is for situation like softmax backward
+ # the buffer will be removed during backward phase
+ bwd_memory_cost = MemoryCost(
+ activation=activation_size(input_tensor) - activation_size(input_tensor) * buffer_mem_scale,
+ parameter=0,
+ temp=activation_size(input_tensor) * temp_mem_scale + activation_size(input_tensor) * buffer_mem_scale,
+ buffer=0)
+
+ # total cost is the sum of forward and backward cost
+ total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+
+ memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = []
+ fwd_buffer = [torch.zeros_like(output_tensor, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+ return meta_func
+
+
+# register meta information
+# (0, 0)
+meta_register.register([torch.nn.ReLU, torch.nn.functional.relu, torch.tanh])(elementwise_meta_info(0, 0))
+
+# (1, 0)
+meta_register.register([torch.nn.Softmax, torch.nn.functional.softmax])(elementwise_meta_info(1, 0))
+
+# (0, 0.25) for dropout, the buffer is in bool type so that the buffer memory cost is 0.25 times of input tensor
+meta_register.register([torch.nn.Dropout, torch.nn.functional.dropout])(elementwise_meta_info(0, 0.25))
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
new file mode 100644
index 000000000000..2997f31adff8
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/embedding.py
@@ -0,0 +1,52 @@
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["embedding_meta_info"]
+
+
+@meta_register.register(torch.nn.Embedding)
+def embedding_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """torch.nn.Embedding metainfo generator
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ """
+ input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+ weight_tensor = next(filter(lambda x: x.type == OperationDataType.PARAM, args)).data
+ output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+
+ # compute cost
+ fwd_compute_cost = flop_mapping[torch.ops.aten.embedding.default]([weight_tensor, input_tensor], [output_tensor])
+ bwd_compute_cost = flop_mapping[torch.ops.aten.embedding_dense_backward.default]([output_tensor, weight_tensor],
+ [weight_tensor])
+
+ compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+ # memory cost
+ # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+ # NOTE: during the backward phase of torch.nn.Embedding, it seems when the input is large enough, it will
+ # have a temp memory which is kind of weird and we don't know the reason yet, so currently we just assume
+ # that there will be no temp memory, as the temp memory is significantly smaller than the gradient memory
+ fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor]),
+ parameter=0,
+ temp=0,
+ buffer=0)
+ bwd_memory_cost = MemoryCost(activation=activation_size([weight_tensor]), parameter=0, temp=0, buffer=0)
+
+ total_memory_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation)
+
+ memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_memory_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = [torch.zeros_like(input_tensor)]
+ fwd_buffer = []
+ fwd_out = [torch.zeros_like(output_tensor)]
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
index 61f8fdff33a1..617375721222 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
@@ -1,3 +1,4 @@
+from functools import reduce
from typing import Callable, Dict, List, Tuple, Union
import torch
@@ -16,7 +17,7 @@
from ..registry import meta_register
-__all__ = ['linear_meta_info']
+__all__ = ['linear_meta_info', 'matmul_meta_info']
@meta_register.register(torch.nn.functional.linear)
@@ -170,3 +171,235 @@ def linear_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+
+@meta_register.register(torch.matmul)
+def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """torch.matmul meta info generator
+ There are several cases for torch.matmul:
+ 1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same
+ as two input vectors.
+ 2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward
+ phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if
+ the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory
+ the same size as the input matrix, and allocate memory for the gradient of two inputs.
+ 3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of
+ output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for
+ the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is
+ the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will
+ allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched
+ matrix will be stored in the memory allocated during the forward phase.
+ 3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs
+ 4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two
+ inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate
+ memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it
+ will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input.
+ 5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size
+ of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate
+ memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of
+ two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase.
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs
+
+ """
+ # Get input and output tensors
+ input_tensors = [args[0].data, args[1].data]
+ output_tensors = [args[-1].data]
+
+ # Check dimension
+ if all(len(tensor.shape) == 1 for tensor in input_tensors):
+ # Dot
+ fwd_compute_cost = flop_mapping[torch.ops.aten.dot.default](input_tensors, output_tensors)
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](input_tensors[0], output_tensors) * 2
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+ elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 1:
+ # gemv case 1: matrix-vector multiplication
+ # &
+ # batched gemv case 1: batched matrix-vector multiplication
+
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]], output_tensors)
+
+ # combine the dimensions of output
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+ [output_tensors[0].reshape(-1), input_tensors[1]],
+ output_tensors) + \
+ flop_mapping[torch.ops.aten.mv.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1)],
+ output_tensors)
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+ elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) == 2:
+ # gemv case 2: vector-matrix multiplication
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](input_tensors, output_tensors)
+
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor]([output_tensors[0], input_tensors[0]], output_tensors) + \
+ flop_mapping[torch.ops.aten.mv.default]([input_tensors[1], output_tensors[0]], output_tensors)
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors),
+ parameter=0,
+ temp=activation_size(input_tensors[1]),
+ buffer=0)
+
+ elif len(input_tensors[0].shape) == 1 and len(input_tensors[1].shape) >= 3:
+ # batched gemv case 2: vector-batched matrix multiplication
+
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mv.default](
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0]],
+ [output_tensors[0].reshape(-1)])
+
+ # combine the dimensions of output
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mul.Tensor](
+ [output_tensors[0].reshape(-1), input_tensors[0]],
+ output_tensors
+ ) + \
+ flop_mapping[torch.ops.aten.mv.default](
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]).transpose(0, 1), output_tensors[0].reshape(-1)],
+ output_tensors
+ )
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors + [input_tensors[1]]))
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+ parameter=0,
+ temp=activation_size(input_tensors[1]),
+ buffer=0)
+
+ elif len(input_tensors[0].shape) >= 2 and len(input_tensors[1].shape) == 2:
+ # gemm & batched gemm case 1: batched matrix-matrix multiplication
+
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]), input_tensors[1]],
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1])])
+
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1]).transpose(0, 1), output_tensors[0].reshape(-1, output_tensors[0].shape[-1])],
+ [input_tensors[1]]
+ ) + \
+ flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].reshape(-1, output_tensors[0].shape[-1]), input_tensors[1].transpose(0, 1)],
+ [input_tensors[0].reshape(-1, input_tensors[0].shape[-1])]
+ )
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors), parameter=0, temp=0, buffer=0)
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors), parameter=0, temp=0, buffer=0)
+
+ elif len(input_tensors[0].shape) == 2 and len(input_tensors[1].shape) >= 3:
+ # batched gemm case 2: matrix-batched matrix multiplication
+ fwd_compute_cost = flop_mapping[torch.ops.aten.mm.default]([
+ input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2]), input_tensors[0].transpose(
+ 0, 1)
+ ], [output_tensors[0].transpose(-2, -1)])
+
+ bwd_compute_cost = flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]).transpose(0, 1), input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])],
+ [input_tensors[0]]
+ ) + \
+ flop_mapping[torch.ops.aten.mm.default](
+ [output_tensors[0].transpose(-2, -1).reshape(-1, output_tensors[0].shape[-2]), input_tensors[0]],
+ [input_tensors[1].transpose(-2, -1).reshape(-1, input_tensors[1].shape[-2])]
+ )
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors) + activation_size(input_tensors[1]),
+ temp=activation_size(output_tensors))
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors[0]),
+ parameter=0,
+ temp=activation_size(input_tensors[1]) + activation_size(output_tensors))
+
+ elif all(len(tensor.shape) >= 3 for tensor in input_tensors):
+ # Batched matrix-batched matrix multiplication
+ # Fetch shape of the two inputs and see if the batch dimensions are the same
+ _is_batch_dims_same = True
+ if len(input_tensors[0].shape) == len(input_tensors[1].shape):
+ for (shape_0, shape_1) in zip(input_tensors[0].shape[:-2], input_tensors[1].shape[:-2]):
+ if shape_0 != shape_1:
+ _is_batch_dims_same = False
+ break
+ else:
+ _is_batch_dims_same = False
+
+ # retireve dimensions
+ input_dim_00 = input_tensors[0].shape[-2]
+ input_dim_01 = input_tensors[0].shape[-1]
+ input_dim_10 = input_tensors[1].shape[-2]
+ input_dim_11 = input_tensors[1].shape[-1]
+ output_dim_0 = output_tensors[0].shape[-2]
+ output_dim_1 = output_tensors[0].shape[-1]
+
+ if _is_batch_dims_same:
+ # Case 1: batch dimensions are the same
+
+ # Forward compute cost: C = A * B
+ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default]([
+ input_tensors[0].reshape(-1, input_dim_00, input_dim_01), input_tensors[1].reshape(
+ -1, input_dim_10, input_dim_11)
+ ], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+
+ # Backward compute cost: dB = A^T * dC, dA = dC * B^T
+ bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [input_tensors[0].transpose(-2, -1).reshape(-1, input_dim_01, input_dim_00), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ [input_tensors[1].reshape(-1, input_dim_11, input_dim_10)]
+ ) + \
+ flop_mapping[torch.ops.aten.bmm.default](
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), input_tensors[1].transpose(-2, -1).reshape(-1, input_dim_11, input_dim_10)],
+ [input_tensors[0].reshape(-1, input_dim_00, input_dim_01)]
+ )
+
+ fwd_mem_cost = MemoryCost(activation=activation_size(output_tensors))
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors))
+
+ else:
+ # Case 2: batch dimensions are different
+ batch_dims = output_tensors[0].shape[:-2]
+ extended_input_0 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
+ input_dim_00,
+ input_dim_01,
+ device="meta")
+ extended_input_1 = torch.rand(reduce(lambda x, y: x * y, batch_dims),
+ input_dim_10,
+ input_dim_11,
+ device="meta")
+
+ # Forward compute cost: C = A * B
+ fwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [extended_input_0, extended_input_1], [output_tensors[0].reshape(-1, output_dim_0, output_dim_1)])
+
+ # Backward compute cost: dB = A^T * dC, dA = dC * B^T
+ bwd_compute_cost = flop_mapping[torch.ops.aten.bmm.default](
+ [extended_input_0.transpose(-2, -1), output_tensors[0].reshape(-1, output_dim_0, output_dim_1)],
+ [extended_input_1]
+ ) + \
+ flop_mapping[torch.ops.aten.bmm.default](
+ [output_tensors[0].reshape(-1, output_dim_0, output_dim_1), extended_input_1.transpose(-2, -1)],
+ [extended_input_0]
+ )
+
+ fwd_mem_cost = MemoryCost(
+ activation=activation_size([output_tensors[0], extended_input_0, extended_input_1]))
+ bwd_mem_cost = MemoryCost(activation=activation_size(input_tensors) -
+ activation_size([extended_input_0, extended_input_1]),
+ temp=activation_size([extended_input_0, extended_input_1]))
+
+ # compute cost
+ compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+ # memory cost
+ total_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+
+ memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = input_tensors
+ fwd_buffer = []
+ fwd_out = output_tensors
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
new file mode 100644
index 000000000000..4634d3ccdcfd
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/non_spmd.py
@@ -0,0 +1,29 @@
+import operator
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["non_spmd_meta_info"]
+
+
+@meta_register.register(torch.Size)
+@meta_register.register(torch.Tensor.size)
+@meta_register.register(torch.finfo)
+@meta_register.register(operator.le)
+def non_spmd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """Non-SPMD node meta information generator
+ Those nodes will not be handled by SPMD solver, so we just return all zero meta information for it
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ """
+ compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+ memory_cost = TrainCycleItem(fwd=MemoryCost(), bwd=MemoryCost(), total=MemoryCost())
+ fwd_in, fwd_buffer, fwd_out = [], [], []
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
index 9b34332db1b5..3a1db396e188 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/norm.py
@@ -16,7 +16,7 @@
from ..registry import meta_register
-__all__ = ['batchnormnd_meta_info']
+__all__ = ['batchnormnd_meta_info', 'layernorm_meta_info']
@meta_register.register(torch.nn.BatchNorm1d)
@@ -101,3 +101,56 @@ def batchnormnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleIt
fwd_out = [torch.zeros_like(output_tensor, device='meta')]
return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+
+@meta_register.register(torch.nn.LayerNorm)
+def layernorm_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """LayerNorm meta information
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ """
+ # construct needed tensors
+ input_tensor = next(filter(lambda x: x.type == OperationDataType.ARG, args)).data
+ output_tensor = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+ weight_tensor = next(filter(lambda x: x.name == "weight", args)).data
+ bias_tensor = next(filter(lambda x: x.name == "bias", args)).data
+ running_mean = torch.rand(input_tensor.shape[0], 1, device='meta')
+ running_var = torch.rand(input_tensor.shape[0], 1, device='meta')
+
+ # construct args
+ fwd_in_args = [input_tensor, [input_tensor.shape[0]], weight_tensor]
+ fwd_out_args = [output_tensor]
+ bwd_in_args = [input_tensor, output_tensor, [input_tensor.shape[0]]]
+ bwd_out_args = [weight_tensor, bias_tensor]
+
+ # compute cost
+ fwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm.default](fwd_in_args, fwd_out_args)
+ bwd_compute_cost = flop_mapping[torch.ops.aten.native_layer_norm_backward.default](bwd_in_args, bwd_out_args)
+ compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+ # memory cost
+ # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+ fwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, output_tensor, weight_tensor, bias_tensor]),
+ parameter=activation_size([weight_tensor, bias_tensor]),
+ temp=0,
+ buffer=activation_size([running_mean, running_var]))
+
+ bwd_memory_cost = MemoryCost(activation=activation_size([input_tensor, weight_tensor, bias_tensor]),
+ parameter=activation_size([weight_tensor, bias_tensor]),
+ temp=activation_size([running_mean, running_var]),
+ buffer=activation_size([running_mean, running_var]))
+
+ total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
+ parameter=fwd_memory_cost.parameter + bwd_memory_cost.parameter,
+ temp=fwd_memory_cost.temp + bwd_memory_cost.temp,
+ buffer=fwd_memory_cost.buffer + bwd_memory_cost.buffer)
+
+ memory_cost = TrainCycleItem(fwd=fwd_memory_cost, bwd=bwd_memory_cost, total=total_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = [torch.zeros_like(input_tensor, device='meta')]
+ fwd_buffer = [torch.zeros_like(running_mean, device='meta'), torch.zeros_like(running_var, device='meta')]
+ fwd_out = [torch.zeros_like(output_tensor, device='meta')]
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
index 79780c92eed4..21272ea09ac1 100644
--- a/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/pooling.py
@@ -14,7 +14,6 @@
@meta_register.register(torch.nn.AdaptiveAvgPool1d)
@meta_register.register(torch.nn.AdaptiveAvgPool2d)
@meta_register.register(torch.nn.AdaptiveAvgPool3d)
-@meta_register.register(torch.flatten)
def avgpool_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
"""Meta info for AdaptiveAvgPool
The aten graph of AdaptiveAvgPool is
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
new file mode 100644
index 000000000000..332e649d2d7e
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/tensor.py
@@ -0,0 +1,79 @@
+from typing import Callable, List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["tensor_related_metainfo"]
+
+
+def tensor_related_metainfo(bwd_mem_out_factor: float = 1, bwd_mem_tmp_factor: float = 0) -> Callable:
+ """torch.Tensor related metainfo generator template
+
+ Args:
+ bwd_mem_out_factor (float, optional): backward activation memory cost factor. Defaults to 1.
+ bwd_mem_tmp_factor (float, optional): backward temp memory cost factor. Defaults to 0.
+
+ Returns:
+ Callable: torch.Tensor related metainfo generator
+ """
+
+ def meta_func(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """torch.Tensor related metainfo generator
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ """
+ outputs = next(filter(lambda x: x.type == OperationDataType.OUTPUT, args)).data
+
+ # compute costs are all zero
+ compute_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
+
+ # memory costs
+ # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
+ fwd_mem_cost = MemoryCost(activation=activation_size(outputs) * 2, parameter=0, temp=0, buffer=0)
+
+ bwd_mem_cost = MemoryCost(activation=activation_size(outputs) * bwd_mem_out_factor,
+ parameter=0,
+ temp=activation_size(outputs) * bwd_mem_tmp_factor,
+ buffer=0)
+
+ total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+
+ memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = []
+ fwd_buffer = []
+ if isinstance(outputs, tuple) or isinstance(outputs, list) or isinstance(outputs, dict):
+ # tuple of tensors
+ fwd_out = [torch.zeros_like(tensor) for tensor in outputs]
+ else:
+ # enaged_tensors is a single tensor
+ fwd_out = [torch.zeros_like(outputs)]
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
+
+ return meta_func
+
+
+# register torch.Tensor related metainfo
+# (0, 0)
+meta_register.register([torch.tensor, torch.Tensor.to, torch.Tensor.unsqueeze, torch.unsqueeze,
+ torch.arange])(tensor_related_metainfo(0, 0))
+
+# (1, 0)
+meta_register.register([
+ torch.Tensor.flatten, torch.flatten, torch.Tensor.transpose, torch.transpose, torch.Tensor.permute, torch.permute,
+ torch.Tensor.split, torch.split, torch.Tensor.view
+])(tensor_related_metainfo(1, 0))
+
+# (1, 1)
+meta_register.register([torch.Tensor.type, torch.Tensor.contiguous])(tensor_related_metainfo(1, 1))
diff --git a/colossalai/auto_parallel/meta_profiler/meta_registry/where.py b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
new file mode 100644
index 000000000000..c67eb40bc80e
--- /dev/null
+++ b/colossalai/auto_parallel/meta_profiler/meta_registry/where.py
@@ -0,0 +1,60 @@
+from typing import List, Tuple
+
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
+from colossalai.fx.profiler.memory_utils import activation_size
+from colossalai.fx.profiler.opcount import flop_mapping
+
+from ..registry import meta_register
+
+__all__ = ["where_meta_info"]
+
+
+@meta_register.register(torch.where)
+def where_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]:
+ """torch.where meta information generator
+
+ Returns:
+ Tuple[TrainCycleItem, TrainCycleItem, List[torch.Tensor]]: compute cost, memory cost and forward inputs
+ """
+
+ condition_tensor, x_tensor, y_tensor, output_tensor = [arg.data for arg in args]
+
+ # compute cost
+ fwd_compute_cost = 0
+
+ # if we need to broadcast the condition tensor, during backward we need to do a reduce_sum
+ bwd_compute_cost = 0
+ if x_tensor.shape != output_tensor.shape:
+ bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [x_tensor])
+ if y_tensor.shape != output_tensor.shape:
+ bwd_compute_cost += flop_mapping[torch.ops.aten.sum.dim_IntList]([output_tensor], [y_tensor])
+
+ compute_cost = TrainCycleItem(fwd=fwd_compute_cost, bwd=bwd_compute_cost, total=fwd_compute_cost + bwd_compute_cost)
+
+ # memory cost
+ # during the forward phase, torch.where will allocate memory for output tensor and condition tensor
+ # during the backward phase, torch.where will allocate temp memory which is 3 times as output tensor, then generate
+ # gradient matrix for input x and input y, remove the temp memory and condition tensor generated in forward phase
+ # NOTE: currently in SPMD solver we always believe that there will be a new input tensor created in forward
+ fwd_mem_cost = MemoryCost(activation=activation_size([condition_tensor, x_tensor, y_tensor, output_tensor]))
+ bwd_mem_cost = MemoryCost(activation=activation_size([x_tensor, y_tensor]) - activation_size([condition_tensor]),
+ parameter=0,
+ temp=activation_size([output_tensor]) * 3 + activation_size([condition_tensor]) -
+ activation_size([x_tensor, y_tensor]),
+ buffer=0)
+
+ total_mem_cost = MemoryCost(activation=fwd_mem_cost.activation + bwd_mem_cost.activation,
+ parameter=fwd_mem_cost.parameter + bwd_mem_cost.parameter,
+ temp=fwd_mem_cost.temp + bwd_mem_cost.temp,
+ buffer=fwd_mem_cost.buffer + bwd_mem_cost.buffer)
+
+ memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+
+ # store fwd_in, fwd_buffer, fwd_out
+ fwd_in = [condition_tensor]
+ fwd_buffer = []
+ fwd_out = [output_tensor]
+
+ return compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py b/colossalai/auto_parallel/offload/__init__.py
similarity index 100%
rename from examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/__init__.py
rename to colossalai/auto_parallel/offload/__init__.py
diff --git a/colossalai/auto_parallel/offload/amp_optimizer.py b/colossalai/auto_parallel/offload/amp_optimizer.py
new file mode 100644
index 000000000000..a79e5006e7d2
--- /dev/null
+++ b/colossalai/auto_parallel/offload/amp_optimizer.py
@@ -0,0 +1,177 @@
+from typing import Dict, Tuple
+from enum import Enum
+import torch
+from torch.optim import Optimizer
+
+from colossalai.logging import get_dist_logger
+from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.utils import get_current_device
+
+from .base_offload_module import BaseOffloadModule
+from .region_manager import RegionManager
+from .region import Region
+
+
+class OptimState(Enum):
+ SCALED = 0
+ UNSCALED = 1
+
+class AMPOptimizer(ColossalaiOptimizer):
+
+ """
+ A wrapper for Optimizer.
+ Code reference: https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/optimizer/zero_optimizer.py
+
+ Args:
+ optimizer (Optimizer): An Optimizer instance.
+ module (BaseOffloadModule): A ``BaseOffloadModule`` instance.
+ initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
+ growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
+ backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
+ growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
+ hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
+ min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
+ max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
+ norm_type (float, optional): norm_type used for `clip_grad_norm`.
+ """
+
+ def __init__(self,
+ optimizer: Optimizer,
+ module: BaseOffloadModule,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ clipping_norm: float = 0.0,
+ norm_type: float = 2.0):
+
+ super().__init__(optimizer)
+
+ self.module = module
+ self.optim_state = OptimState.UNSCALED
+ self.clipping_flag = clipping_norm > 0.0
+ self.max_norm = clipping_norm
+
+ self.region_manager: RegionManager = self.module.region_manager
+ self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
+ self.param_to_region: Dict[torch.nn.Parameter, Region] = dict()
+
+ self.fp32_to_fp16_params: Dict[torch.Tensor, torch.nn.Parameter] = dict()
+
+ if self.clipping_flag:
+ assert norm_type == 2.0, "AMPOptimizer only supports L2 norm now"
+
+ self.__init__optimizer()
+
+ # Grad scaler
+ self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
+ min_scale=min_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval,
+ hysteresis=hysteresis,
+ max_scale=max_scale)
+ self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+ self._logger = get_dist_logger()
+
+ def _set_grad_ptr(self):
+ for group in self.param_groups:
+ for fake_param in group['params']:
+ region = self.param_to_region[fake_param]
+ begin, end = self.param_to_range[fake_param]
+
+ fake_param.data = region.cpu_grad[begin:end]
+ fake_param.grad = fake_param.data
+ fake_param.data = region.fp32_data[begin:end]
+
+ def _update_fp16_params(self):
+ none_tensor = torch.empty([0])
+ for group in self.param_groups:
+ for fake_param in group['params']:
+ assert fake_param.grad is None
+ fake_param.data = none_tensor
+ self.param_to_region[fake_param].cpu_grad = None
+
+ def _check_overflow(self):
+ # clear previous overflow record
+ self._found_overflow.fill_(self.module.overflow_counter.item())
+ return self._found_overflow.item() > 0
+
+ def _get_combined_scale(self):
+ loss_scale = 1
+
+ if self.optim_state == OptimState.SCALED:
+ loss_scale = self.loss_scale
+ self.optim_state = OptimState.UNSCALED
+
+ combined_scale = loss_scale
+
+ if combined_scale == 1:
+ return -1
+ else:
+ return combined_scale
+
+ @property
+ def loss_scale(self):
+ return self.grad_scaler.scale.item()
+
+ def zero_grad(self, *args, **kwargs):
+ self.module.overflow_counter = torch.cuda.IntTensor([0])
+ return self.optim.zero_grad(set_to_none=True)
+
+ def step(self, *args, **kwargs):
+ # Copy gradients from model params to main params.
+ self._set_grad_ptr()
+
+ found_inf = self._check_overflow()
+ if found_inf:
+ self.optim_state = OptimState.UNSCALED # no need to unscale grad
+ self.grad_scaler.update(found_inf) # update gradient scaler
+ self._logger.info(f'Found overflow. Skip step')
+ self.zero_grad() # reset all gradients
+ self._update_fp16_params()
+ return
+
+ # get combined scale. combined scale = loss scale * clipping norm
+ # so that gradient = gradient / combined scale
+ combined_scale = self._get_combined_scale()
+ self.grad_scaler.update(found_inf)
+
+ ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
+ self.zero_grad()
+ self._update_fp16_params()
+ return ret
+
+ def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
+ raise NotImplementedError
+
+ def backward(self, loss: torch.Tensor):
+ loss = self.loss_scale * loss
+ self.optim_state = OptimState.SCALED
+ self.module.backward(loss)
+
+ def __init__optimizer(self):
+
+ for group in self.optim.param_groups:
+ fake_params_list = list()
+
+ for param in group['params']:
+ region = self.region_manager.get_region(param)
+ fake_param = torch.nn.Parameter(torch.empty([0]))
+ self.param_to_range[fake_param] = region.param_to_range[param]
+ self.param_to_region[fake_param] = region
+ fake_params_list.append(fake_param)
+
+ # Reset existing state dict key to the new main param.
+ if param in self.optim.state:
+ self.optim.state[fake_param] = self.optim.state.pop(param)
+
+ group['params'] = fake_params_list
+
+ # Leverage state_dict() and load_state_dict() to
+ # recast preexisting per-param state tensors
+ self.optim.load_state_dict(self.optim.state_dict())
\ No newline at end of file
diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py
new file mode 100644
index 000000000000..59cea4ece266
--- /dev/null
+++ b/colossalai/auto_parallel/offload/base_offload_module.py
@@ -0,0 +1,109 @@
+from typing import Optional, Set
+from functools import partial
+import torch
+import torch.nn as nn
+
+from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.gemini.tensor_utils import free_storage
+
+from .region_manager import RegionManager
+from .util import GlobalRuntimeInfo
+
+
+class BaseOffloadModule:
+ """
+ BaseOffloadModule: A model wrapper for parameter offloading.
+
+ Args:
+ model (nn.Module): model to apply offloading.
+ region_manager (RegionManager): a ``RegionManager`` instance.
+ is_sync (bool): synchronous mode or not.
+ """
+
+ def __init__(self,
+ model: nn.Module,
+ region_manager: RegionManager,
+ is_sync=True):
+
+ self.model = model
+ self.region_manager = region_manager
+ self.grad_hook_list = []
+ self.overflow_counter = torch.cuda.IntTensor([0])
+
+ self.grad_offload_stream = torch.cuda.current_stream() if is_sync else GlobalRuntimeInfo.d2h_stream
+
+ self._cast_buffers()
+
+ def register_grad_hook(self):
+ for p in self.model.parameters():
+ if p.requires_grad:
+ self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
+
+ def remove_grad_hook(self):
+ for hook in self.grad_hook_list:
+ hook.remove()
+
+ def __call__(self, *args, **kwargs):
+ return self.forward(*args, **kwargs)
+
+ def _pre_forward(self):
+ self.register_grad_hook()
+ for region in self.region_manager.region_list:
+ region.cpu_grad = None
+
+ def forward(self, *args, **kwargs):
+ args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+ self.model.zero_grad(set_to_none=True)
+ self._pre_forward()
+ outputs = self.model(*args, **kwargs)
+ return outputs
+
+ def backward(self, loss):
+ loss.backward()
+ self._post_backward()
+
+ def _post_backward(self):
+ torch.cuda.synchronize()
+ self.remove_grad_hook()
+
+ for p in self.model.parameters():
+ p.grad = None
+
+ GlobalRuntimeInfo.fwd_prefetch_event_map.clear()
+ GlobalRuntimeInfo.bwd_prefetch_event_map.clear()
+
+ def grad_handle(self, p, grad):
+ empty_grad = torch.empty_like(grad)
+ free_storage(empty_grad)
+ with torch._C.DisableTorchFunction():
+ region = self.region_manager.get_region(p)
+ region.copy_grad_to_region_slice(p, grad)
+ if region.can_release:
+ self.overflow_counter += region.has_inf_or_nan
+ master_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(self.grad_offload_stream):
+ GlobalRuntimeInfo.d2h_stream.wait_stream(master_stream)
+ region.move_grad_to_cpu()
+ return empty_grad
+
+ def _cast_buffers(self):
+ for buffer in self.model.buffers():
+ buffer.data = buffer.cuda()
+
+ def parameters(self, recurse: bool = True):
+ return self.model.parameters(recurse)
+
+ def named_parameters(self, prefix: str = '', recurse: bool = True):
+ return self.model.named_parameters(prefix, recurse)
+
+ def named_buffers(self, prefix: str = '', recurse: bool = True):
+ return self.model.named_buffers(prefix, recurse)
+
+ def named_children(self):
+ return self.model.named_children()
+
+ def named_modules(self,
+ memo: Optional[Set[torch.nn.Module]] = None,
+ prefix: str = '',
+ remove_duplicate: bool = True):
+ return self.model.named_modules(memo, prefix, remove_duplicate)
diff --git a/colossalai/auto_parallel/offload/mem_optimize.py b/colossalai/auto_parallel/offload/mem_optimize.py
new file mode 100644
index 000000000000..02778696a106
--- /dev/null
+++ b/colossalai/auto_parallel/offload/mem_optimize.py
@@ -0,0 +1,49 @@
+from typing import Dict
+import torch
+import torch.fx
+from torch.fx import GraphModule
+from torch.utils._pytree import tree_map
+
+from colossalai.fx import ColoTracer, is_compatible_with_meta
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+
+from .region_manager import RegionManager
+from .runtime import runtime_syn_offload_apply_pass, runtime_asyn_offload_apply_pass
+from .base_offload_module import BaseOffloadModule
+from .util import compute_max_param_mem, compute_total_param_mem, compute_act_peak_mem, GlobalRuntimeInfo
+
+def memory_optimize(model: torch.nn.Module,
+ inps: Dict[str, torch.Tensor],
+ memory_budget: float = -1.0,
+ solver_name: str = 'asyn'):
+
+ model = model.cpu().half()
+ tracer = ColoTracer()
+ assert is_compatible_with_meta()
+ wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
+ meta_args = tree_map(wrap_fn, inps)
+ graph = tracer.trace(model, meta_args=meta_args)
+ gm = GraphModule(model, graph, model.__class__.__name__)
+ interp = MetaInfoProp(gm)
+ interp.propagate(*meta_args.values())
+
+ region_manager = RegionManager(graph, solver_name=solver_name, memory_budget=memory_budget)
+ region_manager._build_regions()
+ GlobalRuntimeInfo.region_list = region_manager.region_list
+
+ act_peak_mem = compute_act_peak_mem(region_manager.region_list) / 1024 ** 2
+ max_param_mem = compute_max_param_mem(region_manager.region_list) / 1024 ** 2
+ total_param_mem = compute_total_param_mem(region_manager.region_list) / 1024 ** 2
+ print(
+ f"act_peak_mem={act_peak_mem:.3f} MB | max_param_mem={max_param_mem:.3f} MB | total_param_mem={total_param_mem:.3f}")
+
+ if solver_name == 'syn':
+ gm = runtime_syn_offload_apply_pass(gm, region_manager.region_list)
+ elif solver_name == 'asyn':
+ gm = runtime_asyn_offload_apply_pass(gm, region_manager.region_list)
+ else:
+ raise TypeError(f"Unknown solver name {solver_name}!")
+
+ gm.recompile()
+ optimized_model = BaseOffloadModule(gm, region_manager, solver_name=='syn')
+ return optimized_model
diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py
new file mode 100644
index 000000000000..e6907cc4b81d
--- /dev/null
+++ b/colossalai/auto_parallel/offload/region.py
@@ -0,0 +1,144 @@
+from typing import List, Dict, Tuple
+import torch
+from torch.fx import Node
+from colossalai.gemini.tensor_utils import alloc_storage, free_storage
+
+class Region:
+ """
+ Region: A container owning a piece of contiguous nodes in the DNN computing graph.
+
+ Args:
+ r_id (int): the index of the region in the computing graph.
+ """
+
+ def __init__(self, r_id: int = 0) -> None:
+ self.r_id: int = r_id
+ self.fp16_params: List[torch.nn.Parameter] = []
+ self.param_size: int = 0
+ self.shared_rid: int = self.r_id
+
+ self.param_num: int = 0
+ self.grad_num: int = 0
+ self.fp16_data = None
+ self.fp32_data = None
+ self.cpu_grad = None
+ self.temp_fp32_data = None
+ self.param_to_range: Dict[torch.nn.Parameter, Tuple[int, int]] = dict()
+
+ self.need_offload: bool = False
+ self.is_syn: bool = False
+ self.nodes: List[Node] = []
+ self.fwd_prefetch_region = None
+ self.bwd_prefetch_region = None
+
+ self.in_mem_pool_flag: bool = False
+
+ @property
+ def can_release(self) -> bool:
+ """
+ Check if the region can be released.
+ """
+ return self.grad_num == self.param_num
+
+ @property
+ def has_inf_or_nan(self) -> bool:
+ """
+ Check if the grad of the region has inf or nan values on CUDA.
+ """
+ return torch.isinf(self.fp16_data).any() | torch.isnan(self.fp16_data).any()
+
+ def init_param_data(self, pre_alloc_tensor: torch.Tensor = None):
+ """
+ Map the parameters in the region to a contiguous memory space.
+ """
+
+ self.fp16_data = torch.zeros(
+ self.param_num, dtype=torch.half, device='cuda')
+ offset = 0
+ for param in self.fp16_params:
+ param.data = param.data.cuda()
+ p_num = param.data.numel()
+ self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
+ param.data = self.fp16_data[offset:offset +
+ p_num].view(param.data.shape)
+ self.param_to_range[param] = (offset, offset + p_num)
+ offset += p_num
+
+ self.fp32_data = self.fp16_data.float().cpu().pin_memory()
+ free_storage(self.fp16_data)
+ if self.in_mem_pool_flag and pre_alloc_tensor is not None:
+ self.fp16_data = pre_alloc_tensor
+
+ def move_param_to_cuda(self):
+ """
+ Move parameters from CPU to GPU.
+ It first moves float32 parameters to GPU and
+ then transforms float32 parameters to half-precision on the GPU.
+ The reason is that the performance of precision conversion on the CPU
+ is much slower than the data transfer overhead.
+ """
+
+ self.temp_fp32_data.copy_(self.fp32_data, non_blocking=True)
+ self.temp_fp32_data.record_stream(torch.cuda.current_stream())
+ if not self.in_mem_pool_flag:
+ alloc_storage(self.fp16_data)
+ self.fp16_data[:self.param_num].copy_(self.temp_fp32_data)
+ self.fp16_data.record_stream(torch.cuda.current_stream())
+
+ self.__update_params_ptr()
+
+ def move_grad_to_cpu(self):
+ """
+ Move gradients from GPU to CPU.
+ """
+
+ self.cpu_grad = torch.empty(self.param_num, dtype=torch.half, pin_memory=True)
+ self.cpu_grad.copy_(self.fp16_data[:self.param_num], non_blocking=True)
+ self.fp16_data.record_stream(torch.cuda.current_stream())
+ if not self.in_mem_pool_flag:
+ self.free_cuda_data()
+
+ self.grad_num = 0
+
+ def free_cuda_data(self):
+ free_storage(self.fp16_data)
+
+ # torch.cuda.empty_cache()
+
+ def copy_grad_to_region_slice(self, param: torch.nn.Parameter, data_slice: torch.Tensor) -> None:
+ """
+ Copy data slice to the memory space indexed by the input tensor in the region.
+
+ Args:
+ param (torch.nn.Parameter): the param used to retrive meta information
+ data_slice (torch.Tensor): the tensor to be copied to the region
+ """
+
+ begin, end = self.param_to_range[param]
+ self.fp16_data[begin:end].copy_(data_slice.data.flatten())
+ param.data = self.fp16_data[begin:end].view(param.data.shape)
+
+ self.grad_num += data_slice.numel()
+
+ def split(self, cut_node_idx: int, cut_param_idx: int):
+ """
+ Split the region into two and return the latter.
+ """
+ new_reg = Region(r_id=self.r_id + 1)
+ new_reg.nodes = self.nodes[cut_node_idx:]
+ new_reg.fp16_params = self.fp16_params[cut_param_idx:]
+ for p in new_reg.fp16_params:
+ new_reg.param_size += p.data.numel() * p.data.element_size()
+ new_reg.param_num += p.data.numel()
+
+ self.nodes = self.nodes[:cut_node_idx]
+ self.fp16_params = self.fp16_params[:cut_param_idx]
+ self.param_size -= new_reg.param_size
+ self.param_num -= new_reg.param_num
+
+ return new_reg
+
+ def __update_params_ptr(self) -> None:
+ for param in self.fp16_params:
+ begin, end = self.param_to_range[param]
+ param.data = self.fp16_data[begin:end].view(param.data.shape)
\ No newline at end of file
diff --git a/colossalai/auto_parallel/offload/region_manager.py b/colossalai/auto_parallel/offload/region_manager.py
new file mode 100644
index 000000000000..30bfaf00d493
--- /dev/null
+++ b/colossalai/auto_parallel/offload/region_manager.py
@@ -0,0 +1,526 @@
+from typing import List, Any, Dict, Tuple
+import torch
+from torch.fx import Graph, Node
+
+from .solver import SolverFactory
+from .training_simulator import TrainingSimulator
+from .region import Region
+from .util import NodeInfo
+
+
+class RegionManager:
+ """
+ RegionManager is used to construct and manage the offload plan for the model execution.
+
+ Args:
+ graph (Graph): a Graph object used for analysis and strategy generation.
+ solver_name (str): a solver name which specifies the preferences for plan searching.
+ memory_budget (float): the given memory budget.
+ cnode (List[str], optional): Common node List, should be the subset of input.
+ """
+
+ def __init__(self,
+ graph: Graph,
+ solver_name: str = 'asyn',
+ memory_budget: float = -1.0,
+ cnode: List[str] = None):
+
+ self.graph = graph
+ assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
+ self.root_module = self.graph.owning_module
+ self.nodes = list(graph.nodes)
+ self.cnode = cnode
+ self.only_param_ops = []
+ self.param_region_map: Dict[torch.nn.Parameter, Region] = dict()
+ self.shared_region_pairs: List[Tuple[Region, Region]] = list()
+ self.region_list: List[Region] = list()
+ self.rid_in_pool: List[int] = list()
+ self.mem_block_size: int = 0
+ self.memory_budget = memory_budget
+
+ self.solver_name = solver_name
+ self.require_pool: bool = solver_name == 'asyn'
+
+ self.reg_to_block: Dict[int, int] = dict()
+
+ def _build_regions(self):
+ """
+ 1. Pre-processing, mainly contains linearized computing graph and
+ merge smaller regions into larger ones.
+ 2. Construct a solver to search for an efficient offload strategy.
+ 3. Post-processing, mainly contains early region placement if using asynchronous mode,
+ and initialize region data.
+ """
+
+ self._pre_process()
+
+ solver_cls = SolverFactory.create(self.solver_name)
+ solver = solver_cls(self.region_list, self.memory_budget)
+ solver._call_solver()
+
+ self._post_process(solver.best_ts)
+
+ def _pre_process(self):
+
+ init_region_list = self._linearize_graph()
+
+ if len(self.shared_region_pairs) > 1:
+ raise NotImplementedError(
+ 'The current version only considers at most one pair of parameter sharing.')
+
+ elif len(self.shared_region_pairs) == 1:
+ shared_regs = self.shared_region_pairs[0]
+ assert shared_regs[0].shared_rid == shared_regs[1].r_id \
+ and shared_regs[1].shared_rid == shared_regs[0].r_id
+ fst_id = shared_regs[0].r_id
+ lst_id = shared_regs[1].r_id
+ regs_left_out = init_region_list[:fst_id + 1]
+ regs_right_out = init_region_list[lst_id:]
+ hold_regs = init_region_list[fst_id + 1:lst_id]
+ else:
+ regs_left_out = []
+ regs_right_out = []
+ hold_regs = init_region_list
+
+ self.mem_block_size = self._search_block_size(hold_regs)
+ hold_regs = self._merge_small_regions(hold_regs)
+
+ if self.require_pool:
+ for reg in hold_regs:
+ reg.in_mem_pool_flag = True
+ self.rid_in_pool.append(reg.r_id)
+
+ self.region_list.extend(regs_left_out)
+ self.region_list.extend(hold_regs)
+
+ for reg in regs_right_out:
+ reg.r_id = self.region_list[-1].r_id + 1
+ self.region_list[reg.shared_rid].shared_rid = reg.r_id
+ self.region_list.append(reg)
+
+ self._process_shared_region()
+
+ self.max_param_num = max([reg.param_num for reg in self.region_list])
+ self.memory_budget -= self.max_param_num * torch.tensor([], dtype=torch.float32).element_size()
+
+ def _post_process(self, ts: TrainingSimulator = None):
+ if self.require_pool:
+ self._early_region_placement(ts)
+ self._init_region_data()
+
+ def _early_region_placement(self, ts: TrainingSimulator):
+ """
+ Implemented the early region placement strategy to avoid GPU memory fragmentation.
+ It maps all region data into a contiguous memory space and
+ reuses the same memory space for regions that do not coexist.
+
+ Args:
+ ts (TrainingSimulator): the best training simulator, which records region execution flow.
+
+ Raises:
+ NotImplementedError: due to the naive implementation,
+ it may not find a suitable region placement strategy for the given execution flow.
+ """
+
+ reg_flow = torch.cat(
+ [ts.fwd_reg_flow, ts.bwd_reg_flow], dim=0)
+ mem_block_num = torch.max(
+ torch.sum(reg_flow[:, self.rid_in_pool], dim=1))
+ coexist_matrix = torch.logical_or(
+ ts.fwd_reg_flow, ts.bwd_reg_flow)
+
+ block_to_regs = {}
+ for block_idx in range(mem_block_num):
+ block_to_regs[block_idx] = []
+ for reg in self.region_list:
+ if reg.r_id in self.rid_in_pool:
+ cur_reg_appears = coexist_matrix[:, reg.r_id]
+ cur_reg_coexists = torch.sum(
+ coexist_matrix[cur_reg_appears], dim=0).bool()
+ for block_idx in range(mem_block_num):
+ if not any(cur_reg_coexists[block_to_regs[block_idx]]):
+ block_to_regs[block_idx].append(reg.r_id)
+ self.reg_to_block[reg.r_id] = block_idx
+ break
+
+ if reg.r_id not in self.reg_to_block:
+ raise NotImplementedError(
+ f'can not find a block from the memory pool to store parameters of the region')
+ self.memory_pool = torch.chunk(torch.zeros(int(
+ mem_block_num * self.mem_block_size / 2), dtype=torch.half, device='cuda'), chunks=int(mem_block_num))
+
+ def _merge_small_regions(self, orig_reg_list: List[Region]) -> List[Region]:
+ """
+ Merge smaller regions into larger ones for better bandwidth utilization and easier management.
+ It is inspired by Gemini.
+
+ Args:
+ orig_reg_list (List[Region]): original region list.
+
+ Returns:
+ List[Region]: region list after merging.
+ """
+
+ r_id = orig_reg_list[0].r_id
+ region = Region(r_id=r_id)
+ region_list = [region]
+
+ for orig_reg in orig_reg_list:
+ if region_list[-1].param_size + orig_reg.param_size > self.mem_block_size:
+ r_id += 1
+ region = Region(r_id=r_id)
+ region_list.append(region)
+ region.param_size += orig_reg.param_size
+ region.param_num += orig_reg.param_num
+ region.nodes.extend(orig_reg.nodes)
+ region.fp16_params.extend(orig_reg.fp16_params)
+ self.__update_param_region_map(orig_reg.fp16_params, region)
+
+ return region_list
+
+ def _search_block_size(self,
+ region_list: List[Region],
+ search_interval_byte: int = 1024,
+ search_range_byte: int = 128 * 1024 ** 2) -> int:
+ """
+ Search for a suitable memory block size.
+
+ Args:
+ region_list (List[Region]): region list.
+ search_interval_byte (int): searching interval in byte.
+ search_range_byte (int): searching range in byte.
+
+ Returns:
+ int: the best memory block size.
+ """
+
+ def _get_wasted_mem(size_list: List[int], blk_size: int):
+ """
+ Get wasted byte for a certain block size.
+ """
+ acc_wasted = 0
+ left = 0
+ for s in size_list:
+ if left + s > blk_size:
+ acc_wasted += blk_size - left
+ left = s
+ left += s
+ acc_wasted += blk_size - left
+ return acc_wasted
+
+ param_size_list = [
+ region.param_size for region in region_list if region.r_id == region.shared_rid]
+
+ start_size = max(param_size_list)
+ min_mem_waste = float('+inf')
+ best_block_size = start_size
+
+ for block_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+ temp_waste = 0
+ temp_waste += _get_wasted_mem(param_size_list, block_size)
+ if temp_waste < min_mem_waste:
+ min_mem_waste = temp_waste
+ best_block_size = block_size
+
+ return best_block_size
+
+ def _init_region_data(self):
+ """
+ Initialize region data, which maps the parameters in the region to a contiguous memory space.
+ """
+
+ self.temp_fp32_data = torch.zeros(self.max_param_num, device='cuda', dtype=torch.float32)
+
+ for region in self.region_list:
+ pre_alloc_tensor = None
+ if self.require_pool and region.r_id in self.rid_in_pool:
+ block_idx = self.reg_to_block[region.r_id]
+ pre_alloc_tensor = self.memory_pool[block_idx]
+
+ if region.r_id <= region.shared_rid:
+ region.init_param_data(pre_alloc_tensor)
+ else:
+ shared_region = self.region_list[region.shared_rid]
+ region.fp16_data = shared_region.fp16_data
+ region.fp32_data = shared_region.fp32_data
+ region.param_to_range = shared_region.param_to_range
+ region.temp_fp32_data = self.temp_fp32_data[:region.param_num].detach(
+ )
+
+ torch.cuda.empty_cache()
+
+ def _process_shared_region(self):
+ """
+ Special processing for the shared region, which uses GPT2 and Bert case as a priori knowledge.
+ """
+
+ if len(self.shared_region_pairs):
+ assert len(self.shared_region_pairs) <= 1
+ former_reg, latter_reg = self.shared_region_pairs[0]
+ assert latter_reg.param_num >= former_reg.param_num
+ embedding_node = former_reg.nodes[-1]
+ assert embedding_node.op == 'call_module' and isinstance(
+ self.root_module.get_submodule(embedding_node.target), torch.nn.Embedding)
+ if latter_reg.param_num > former_reg.param_num:
+ for idx, n in enumerate(latter_reg.nodes):
+ if (n.op == 'call_module' and isinstance(self.root_module.get_submodule(n.target),
+ torch.nn.Linear)) or \
+ (n.op == 'call_function' and n.target is torch.nn.functional.linear):
+ cut_node_idx = idx + 1
+ break
+ assert len(latter_reg.fp16_params) == 2
+ new_reg = latter_reg.split(cut_node_idx, 1)
+ for p in new_reg.fp16_params:
+ self.param_region_map[p] = new_reg
+ self.region_list.insert(new_reg.r_id, new_reg)
+ for reg in self.region_list[new_reg.r_id + 1:]:
+ reg.r_id += 1
+ latter_reg.shared_rid = former_reg.r_id
+ former_reg.shared_rid = latter_reg.r_id
+
+ def _linearize_graph(self) -> List[Region]:
+ """Linearizing the graph
+
+ Args:
+ graph (Graph): The computing graph to be optimized.
+
+ Returns:
+ List[Region]: each region contains the actual 'node' in linearized manner.
+
+ Remarks:
+ Do merge the inplace ops and shape-consistency ops into the previous node.
+ """
+
+ # List of target name that could be seen as common node
+ common_ops = ["getattr", "getitem", "size"]
+
+ def _is_cop(target: Any) -> bool:
+ """Check if an op could be seen as common node
+
+ Args:
+ target (Any): node target
+
+ Returns:
+ bool
+ """
+
+ if isinstance(target, str):
+ return target in common_ops
+ else:
+ return target.__name__ in common_ops
+
+ def _is_act(data: Any) -> bool:
+ """Check if an op could be seen as parameter computation start
+
+ Args:
+ data (Any): meta_data
+
+ Returns:
+ bool
+ """
+
+ label = False
+ if isinstance(data, torch.Tensor):
+ return True
+ elif isinstance(data, (tuple, list)):
+ for d in data:
+ label = label or _is_act(d)
+ return label
+
+ def _maybe_param_comp_start() -> bool:
+ """Check if an op could be seen as parameter computation start
+
+ Args:
+ n (Node): node
+
+ Returns:
+ bool
+ """
+
+ label = False
+ if n.op == "get_attr":
+ label = True
+ elif n.op == "call_module":
+ target = n.target
+ submod = self.root_module.get_submodule(target)
+ if (
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
+ ):
+ label = True
+
+ return label and not sum([v for _, v in param_op_deps.items()])
+
+ def _is_param_comp_end() -> bool:
+ """Check if an op could be seen as parameter computation end
+
+ Args:
+ n (Node): node
+
+ Returns:
+ bool
+ """
+
+ def _is_inplace(n: Node):
+ """Get the inplace argument from ``torch.fx.Node``
+ """
+ inplace = False
+ if n.op == "call_function":
+ inplace = n.kwargs.get("inplace", False)
+ elif n.op == "call_module":
+ inplace = getattr(n.graph.owning_module.get_submodule(
+ n.target), "inplace", False)
+ return inplace
+
+ label = False
+
+ if n.op == "call_module":
+ target = n.target
+ submod = self.root_module.get_submodule(target)
+ if (
+ len(list(submod.named_parameters(recurse=False))) != 0
+ or len(list(submod.named_buffers(recurse=False))) != 0
+ ):
+ label = True
+
+ elif n.op == "call_function":
+ label = any(map(lambda x: x.name in self.only_param_ops, n.all_input_nodes)) and any(
+ map(lambda x: x.name not in self.only_param_ops and not _is_cop(n.target), n.all_input_nodes))
+
+ return label and not sum([v for _, v in param_op_deps.items()]) and not any(map(_is_inplace, n.users))
+
+ def _exception_node_handling():
+ # TODO meta info prop bug
+ if n.name.__contains__("transpose") and n.meta['fwd_out'][0].dim() <= 2:
+ n.meta['fwd_out'] = []
+
+ # make sure that item in cnode is valid
+ if self.cnode:
+ for name in self.cnode:
+ try:
+ assert next(node for node in self.graph.nodes if node.name == name).op == "placeholder", \
+ f"Common node {name} is not an input of the model."
+ except StopIteration:
+ raise ValueError(f"Common node name {name} not in graph.")
+ else:
+ self.cnode = []
+
+ node_id = 0
+ region_id = 0
+
+ param_op_deps = {}
+
+ deps = {}
+ region_list = []
+ region = Region(r_id=region_id)
+
+ act_n = None
+
+ for n in self.graph.nodes:
+ if n.op != "placeholder" and n.op != "output":
+ for n_par in n.all_input_nodes:
+ if n_par.op != "placeholder" and n_par.name not in self.cnode:
+ deps[n_par] -= 1
+ if n_par.op != "placeholder" and n_par.name in self.only_param_ops:
+ param_op_deps[n_par] -= 1
+
+ if act_n in region.nodes and _maybe_param_comp_start():
+ ns = []
+ border_n_idx = region.nodes.index(act_n)
+ if border_n_idx < len(region.nodes):
+ ns = region.nodes[border_n_idx + 1:]
+ region.nodes = region.nodes[:border_n_idx + 1]
+ region_list.append(region)
+ region_id += 1
+ region = Region(r_id=region_id)
+ region.nodes = ns
+
+ _exception_node_handling()
+ region.nodes.append(n)
+ self._set_node_and_region_info(node_id, n, region)
+ node_id += 1
+
+ # if the node could free all dependencies in graph
+ # we could begin a new region
+ if _is_param_comp_end():
+ region_list.append(region)
+ region_id += 1
+ region = Region(r_id=region_id)
+
+ # propagate common node attr if possible
+ if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.cnode
+ ]) or _is_cop(n.target):
+ self.cnode.append(n.name)
+ else:
+ deps[n] = len(
+ [user for user in n.users if user.op != "output"])
+
+ # propagate param node attr if possible
+ if len(n.all_input_nodes) == len([node for node in n.all_input_nodes if node.name in self.only_param_ops
+ ]) or n.op == "get_attr":
+ self.only_param_ops.append(n.name)
+ param_op_deps[n] = len(
+ [user for user in n.users if user.op != "output"])
+
+ # record last activation node
+ if _is_act(n._meta_data):
+ act_n = n
+
+ if len(region.nodes):
+ region_list.append(region)
+
+ return region_list
+
+ def _set_node_and_region_info(self, node_id: int, cur_n: Node, cur_reg: Region):
+
+ cur_n.node_info = NodeInfo(node_id)
+
+ if cur_n.op == 'call_module':
+ target = cur_n.target
+ submod = self.root_module.get_submodule(target)
+ for p in list(submod.parameters(recurse=False)):
+
+ if p in self.param_region_map:
+ cur_reg.shared_rid = self.param_region_map[p].r_id
+ self.param_region_map[p].shared_rid = cur_reg.r_id
+ self.shared_region_pairs.append(
+ (self.param_region_map[p], cur_reg))
+ else:
+ self.param_region_map[p] = cur_reg
+
+ cur_reg.fp16_params.append(p)
+ cur_reg.param_num += p.data.numel()
+ cur_reg.param_size += p.data.numel() * p.data.element_size()
+
+ elif cur_n.op == "get_attr":
+ attr_itr = self.root_module
+ atoms = cur_n.target.split(".")
+ for atom in atoms:
+ attr_itr = getattr(attr_itr, atom)
+
+ if isinstance(attr_itr, torch.nn.Parameter):
+
+ if attr_itr in self.param_region_map:
+ cur_reg.shared_rid = self.param_region_map[attr_itr].r_id
+ self.param_region_map[attr_itr].shared_rid = cur_reg.r_id
+ self.shared_region_pairs.append(
+ (self.param_region_map[attr_itr], cur_reg))
+ else:
+ self.param_region_map[attr_itr] = cur_reg
+
+ cur_reg.fp16_params.append(attr_itr)
+ cur_reg.param_num += attr_itr.data.numel()
+ cur_reg.param_size += attr_itr.data.numel() * attr_itr.data.element_size()
+
+ def get_region(self, param: torch.nn.Parameter) -> Region:
+ """
+ Return the region owning the parameter.
+
+ Args:
+ param (torch.nn.Parameter): a torch parameter object
+ """
+ return self.param_region_map[param]
+
+ def __update_param_region_map(self, params: List[torch.nn.Parameter], region: Region):
+ for p in params:
+ self.param_region_map[p] = region
diff --git a/colossalai/auto_parallel/offload/runtime.py b/colossalai/auto_parallel/offload/runtime.py
new file mode 100644
index 000000000000..91c7945bd65f
--- /dev/null
+++ b/colossalai/auto_parallel/offload/runtime.py
@@ -0,0 +1,253 @@
+from typing import List
+import torch
+from torch.fx.node import Node
+
+from .region import Region
+from .util import GlobalRuntimeInfo, requires_upload_p_in_fwd
+
+
+class SynPreFwdPostBwdOP(torch.autograd.Function):
+ """
+ A customized prefetch and offload operation.
+
+ Args:
+ input_: input tensor.
+ fwd_info: information dict, which contains region indices
+ that need to be uploaded or freed during forward pass.
+ bwd_info: information dict, which contains region indices
+ that need to be uploaded during backward pass.
+ """
+
+ @staticmethod
+ def forward(ctx, input_, fwd_info, bwd_info):
+ ctx.bwd_info = bwd_info
+ d2h_rid = fwd_info.get('d2h_rid', None)
+ if d2h_rid is not None:
+ free_region = GlobalRuntimeInfo.region_list[d2h_rid]
+ assert isinstance(free_region, Region)
+ free_region.free_cuda_data()
+
+ h2d_rid = fwd_info.get('h2d_rid', None)
+ if h2d_rid is not None:
+ h2d_region = GlobalRuntimeInfo.region_list[h2d_rid]
+ assert isinstance(h2d_region, Region)
+ h2d_region.move_param_to_cuda()
+
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ if h2d_rid is not None:
+ pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+ assert isinstance(pref_region, Region)
+ pref_region.move_param_to_cuda()
+
+ return grad_output, None, None
+
+
+class AsynPreFwdPostBwdOP(torch.autograd.Function):
+ """
+ A customized prefetch and offload operation.
+
+ Args:
+ input_: input tensor.
+ fwd_info: information dict, which contains region indices
+ that need to be prefetched, waited, or freed during forward pass.
+ bwd_info: information dict, which contains region indices
+ that need to be prefetched or waited during backward pass.
+ """
+
+ @staticmethod
+ def forward(ctx, input_, fwd_info, bwd_info):
+ ctx.bwd_info = bwd_info
+
+ sync_rid = fwd_info.get('sync_rid', None)
+ if sync_rid is not None:
+ prefetch_event = GlobalRuntimeInfo.fwd_prefetch_event_map.get(
+ sync_rid, None)
+ if prefetch_event:
+ prefetch_event.wait()
+
+ h2d_rid = fwd_info.get('h2d_rid', None)
+ if h2d_rid is not None:
+ pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+ assert isinstance(pref_region, Region)
+ master_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
+ GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+ pref_region.move_param_to_cuda()
+
+ prefetch_event = torch.cuda.Event()
+ prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
+ GlobalRuntimeInfo.fwd_prefetch_event_map[h2d_rid] = prefetch_event
+
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ sync_rid = ctx.bwd_info.get('sync_rid', None)
+ if sync_rid is not None:
+ wait_region = GlobalRuntimeInfo.region_list[sync_rid]
+ assert isinstance(wait_region, Region)
+ prefetch_event = GlobalRuntimeInfo.bwd_prefetch_event_map.get(
+ sync_rid, None)
+ if prefetch_event:
+ prefetch_event.wait()
+ else:
+ wait_region.move_param_to_cuda()
+
+ h2d_rid = ctx.bwd_info.get('h2d_rid', None)
+ if h2d_rid is not None:
+ pref_region = GlobalRuntimeInfo.region_list[h2d_rid]
+ assert isinstance(pref_region, Region)
+ master_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(GlobalRuntimeInfo.h2d_stream):
+ GlobalRuntimeInfo.h2d_stream.wait_stream(master_stream)
+ pref_region.move_param_to_cuda()
+
+ prefetch_event = torch.cuda.Event()
+ prefetch_event.record(GlobalRuntimeInfo.h2d_stream)
+ GlobalRuntimeInfo.bwd_prefetch_event_map[h2d_rid] = prefetch_event
+ return grad_output, None, None
+
+
+def convert_fwd_upload_bwd_offload_to_action(tensor, fwd_info, bwd_info):
+ '''
+ Convert Upload and Offload operation into runtime action.
+
+ Argument:
+ tensor(torch.Tensor): input tensor.
+ fwd_info(dict): information dict, which contains region indices
+ that need to be uploaded, or freed during forward pass.
+ bwd_info(dict): information dict, which contains region indices
+ that need to be uploaded during backward pass.
+ '''
+ with torch._C.DisableTorchFunction():
+ ret = SynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
+ return ret
+
+def convert_fwd_prefetch_bwd_offload_to_action(tensor, fwd_info, bwd_info):
+ '''
+ Convert Prefetch and Offload operation into runtime action.
+
+ Argument:
+ tensor(torch.Tensor): input tensor.
+ fwd_info(dict): information dict, which contains region indices
+ that need to be prefetched, waited, or freed during forward pass.
+ bwd_info(dict): information dict, which contains region indices
+ that need to be prefetched or waited during backward pass.
+ '''
+ with torch._C.DisableTorchFunction():
+ ret = AsynPreFwdPostBwdOP.apply(tensor, fwd_info, bwd_info)
+ return ret
+
+
+def replace_node_users(orig_node: Node, inserted_node: Node, rep_user_nodes: List[Node] = None):
+ user_list = list(orig_node.users.keys())
+ if rep_user_nodes is not None:
+ user_list = rep_user_nodes
+ for user in user_list:
+ if user == inserted_node:
+ continue
+ new_args = list(user.args)
+ new_kwargs = dict(user.kwargs)
+ # the origin node may be a positional argument or key word argument of user node
+ if orig_node in new_args:
+ # substitute the origin node with offload_apply_node
+ new_args[new_args.index(orig_node)] = inserted_node
+ user.args = tuple(new_args)
+ elif str(orig_node) in new_kwargs:
+ # substitute the origin node with offload_apply_node
+ new_kwargs[str(orig_node)] = inserted_node
+ user.kwargs = new_kwargs
+
+
+def runtime_syn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
+ """
+ This pass is used to add the synchronous upload and offload spec apply node to the origin graph.
+ """
+ mod_graph = gm.graph
+ last_inp_node = tuple(mod_graph.nodes)[0]
+
+ for r_idx, region in enumerate(region_list):
+ # forward upload
+ fwd_info = {}
+ if requires_upload_p_in_fwd(region_list[region.shared_rid]):
+ fwd_info['h2d_rid'] = region.r_id
+
+ # forward offload
+ if r_idx > 0 and region_list[r_idx - 1].need_offload:
+ fwd_info['d2h_rid'] = r_idx - 1
+
+ bwd_info = {}
+ # backward upload
+ if r_idx > 0 and region_list[r_idx - 1].need_offload:
+ bwd_info['h2d_rid'] = region_list[r_idx - 1].r_id
+
+ if fwd_info or bwd_info:
+ with mod_graph.inserting_after(last_inp_node):
+ new_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, bwd_info))
+ replace_node_users(last_inp_node, new_node)
+
+ last_inp_node = region.nodes[-1]
+
+ return gm
+
+
+def runtime_asyn_offload_apply_pass(gm: torch.fx.GraphModule, region_list: List[Region]):
+ """
+ This pass is used to add the asynchronous prefetch and offload spec apply node to the origin graph.
+ """
+ mod_graph = gm.graph
+
+ # upload parameters of the first region
+ last_inp_node = tuple(mod_graph.nodes)[0]
+ first_region_with_p = [
+ region for region in region_list if region.param_size][0]
+ fwd_info = {"h2d_rid": first_region_with_p.r_id}
+ with mod_graph.inserting_after(last_inp_node):
+ upload_apply_node = mod_graph.create_node('call_function', convert_fwd_upload_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, {}))
+ replace_node_users(last_inp_node, upload_apply_node)
+ last_inp_node = upload_apply_node
+
+ for r_idx, region in enumerate(region_list):
+ # forward prefetch
+ fwd_info = {}
+ if region.param_size:
+ fwd_info['sync_rid'] = region.r_id
+ fwd_prefetch_region = region.fwd_prefetch_region
+ if fwd_prefetch_region and requires_upload_p_in_fwd(region_list[fwd_prefetch_region.shared_rid]):
+ fwd_info['h2d_rid'] = fwd_prefetch_region.r_id
+
+ # forward offload
+ if r_idx > 0 and region_list[r_idx-1].need_offload:
+ fwd_info['d2h_rid'] = r_idx - 1
+
+ bwd_info = {}
+ # backward prefetch
+ if r_idx > 0 and region_list[r_idx-1].need_offload:
+ bwd_info['sync_rid'] = r_idx - 1
+ if r_idx > 0 and region_list[r_idx-1].bwd_prefetch_region:
+ bwd_info['h2d_rid'] = region_list[r_idx-1].bwd_prefetch_region.r_id
+
+ if fwd_info or bwd_info:
+ with mod_graph.inserting_after(last_inp_node):
+ new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
+ args=(last_inp_node, fwd_info, bwd_info))
+ replace_node_users(last_inp_node, new_node)
+
+ last_inp_node = region.nodes[-1]
+
+ if region.bwd_prefetch_region:
+ bwd_info = {'h2d_rid': region.bwd_prefetch_region.r_id}
+ with mod_graph.inserting_after(last_inp_node):
+ new_node = mod_graph.create_node('call_function', convert_fwd_prefetch_bwd_offload_to_action,
+ args=(last_inp_node, {}, bwd_info))
+ replace_node_users(last_inp_node, new_node)
+ # gm.graph.print_tabular()
+ return gm
diff --git a/colossalai/auto_parallel/offload/solver.py b/colossalai/auto_parallel/offload/solver.py
new file mode 100644
index 000000000000..161f7ff86898
--- /dev/null
+++ b/colossalai/auto_parallel/offload/solver.py
@@ -0,0 +1,523 @@
+import time
+from typing import List, Dict, Type
+from abc import ABC, abstractmethod
+
+NOT_NVML = False
+try:
+ from pynvml import *
+except:
+ NOT_NVML = True
+
+import torch
+from torch.fx.node import Node
+from colossalai.utils.cuda import get_current_device
+
+from .training_simulator import TrainingSimulator, SynTrainingSimulator, AsynTrainingSimulator
+from .region import Region
+from .util import NodeInfo, NvDevicePower
+
+
+def benchmark_func(func, number=1, repeat=1, warmup=3):
+ """
+ benchmark data transfer cost.
+ """
+
+ for i in range(warmup):
+ func()
+
+ costs = []
+
+ for i in range(repeat):
+ torch.cuda.synchronize()
+ begin = time.time()
+ for i in range(number):
+ func()
+ torch.cuda.synchronize()
+ costs.append((time.time() - begin) / number)
+
+ return sum(costs) / len(costs)
+
+
+class Solver(ABC):
+ """
+ The parameter offload solver.
+
+ Args:
+ region_list (List[Region]): represents the linearized DNN computing graph.
+ memory_budget (float): the given memory budget.
+ error_factor (float): the error factor.
+ It is used to reduce the memory budget. Due to some errors in the estimation of peak memory and execution time.
+ """
+
+ def __init__(self,
+ region_list: List[Region],
+ memory_budget: float = -1.0,
+ error_factor: float = 0.95) -> None:
+
+ self.region_list = region_list
+
+ self.error_factor: float = error_factor
+ if memory_budget > 0:
+ self.memory_budget = memory_budget * self.error_factor
+ else:
+ self.memory_budget = torch.cuda.get_device_properties(
+ get_current_device()).total_memory * self.error_factor
+
+ self.link_to_bandwidth: Dict[str, Dict[float, float]] = self._profile_bandwidth()
+ self.comp_power: float = self._extract_computing_power()
+
+ @abstractmethod
+ def _call_solver(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _try_to_offload(self, *args):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _eval_one_choice(self, *args):
+ raise NotImplementedError
+
+ def _compute_offload_profit(self, total_mem_saving: float, peak_mem_saving: float, extra_cost: float):
+ """
+ Compute the profits of the offload strategies,
+ which packages the memory savings information for subsequent comparisons.
+
+ Args:
+ total_mem_saving (float): the total memory saving of the offload strategy.
+ peak_mem_saving (float): the peak memory saving of the offload strategy.
+ extra_cost (float): extra data transfer cost.
+
+ Returns:
+ tuple: profit information, the first term represents memory savings per unit of time.
+ """
+
+ if extra_cost == 0:
+ # means data transfer overhead can be completely overlapped
+ return (float('inf'), total_mem_saving, peak_mem_saving)
+ return (total_mem_saving / extra_cost, total_mem_saving, peak_mem_saving)
+
+ def _compare_profit(self, profit_a: tuple, profit_b: tuple) -> bool:
+ """
+ Compare the profits of the two offload strategies using the dictionary order algorithm.
+
+ Args:
+ profit_a (tuple): the profit of a offload strategy.
+ profit_b (tuple): the profit of another offload strategy.
+
+ Returns:
+ bool: whether profit_a is greater than profit_b.
+ """
+
+ for val1, val2 in zip(profit_a, profit_b):
+ if val1 != val2:
+ return val1 > val2
+ return False
+
+ def _update_state(self, best_ts: TrainingSimulator):
+ """
+ Update the solver state.
+ """
+
+ self.best_ts = best_ts
+ self._update_node_mem_info(best_ts.fwd_node_mem, best_ts.bwd_node_mem)
+
+ def _update_node_mem_info(self,
+ fwd_mem_info: Dict[Node, float],
+ bwd_mem_info: Dict[Node, float]):
+ """
+ Update the runtime memory information of the node.
+
+ Args:
+ fwd_mem_info (Dict[Node, float]): the runtime memory of each node in forward pass.
+ bwd_mem_info (Dict[Node, float]): the runtime memory of each node in backward pass.
+ """
+
+ for node, mem in fwd_mem_info.items():
+ assert hasattr(node, 'node_info') and isinstance(
+ node.node_info, NodeInfo)
+ node.node_info.runtime_fwd_mem = mem
+ for node, mem in bwd_mem_info.items():
+ assert hasattr(node, 'node_info') and isinstance(
+ node.node_info, NodeInfo)
+ node.node_info.runtime_bwd_mem = mem
+
+ def _extract_computing_power(self):
+ """
+ return the FP16 computing performance of the current NVIDIA GPU.
+
+ Raises:
+ TypeError: Unknown NVIDIA GPU device.
+ """
+
+ nvmlInit()
+ handle = nvmlDeviceGetHandleByIndex(0)
+ device_name = nvmlDeviceGetName(handle)
+ units = 1e12
+
+ if device_name.__contains__("RTX 3080"):
+ return NvDevicePower.RTX3080_FP16 * units
+ elif device_name.__contains__("RTX 3090"):
+ return NvDevicePower.RTX3090_FP16 * units
+ elif device_name.__contains__('V100'):
+ return NvDevicePower.V100_FP16 * units
+ elif device_name.__contains__("A100"):
+ return NvDevicePower.A100_FP16 * units
+ else:
+ raise TypeError(f'Unknown NVIDIA GPU device name {device_name}')
+
+ def _profile_bandwidth(self):
+ """
+ Profile the bidirectional communication bandwidth between CPU and GPU
+ using data volumes ranging from 1KB to 1GB.
+ """
+
+ print('profiling bandwidth ......')
+ link_to_bandwidth = {}
+ links = ['h2d', 'd2h']
+
+ for link in links:
+ t_size = 1024
+ size_to_bandwidth = {}
+
+ # from 1KB to 1GB
+ for i in range(21):
+ if link == 'h2d':
+ src_tensor = torch.ones(
+ int(t_size), dtype=torch.int8, pin_memory=True)
+ dst_tensor = torch.ones(
+ (int(t_size)), dtype=torch.int8, device='cuda')
+ elif link == 'd2h':
+ src_tensor = torch.ones(
+ int(t_size), dtype=torch.int8, device='cuda')
+ dst_tensor = torch.ones(
+ (int(t_size)), dtype=torch.int8, pin_memory=True)
+
+ def func():
+ dst_tensor.copy_(src_tensor)
+
+ size_to_bandwidth[t_size] = t_size / benchmark_func(func, number=5, repeat=3)
+ print(f'size: {t_size / 1024 ** 2:.3f} MB, '
+ f'{src_tensor.device.type}-to-{dst_tensor.device.type} '
+ f'bandwidth: {size_to_bandwidth[t_size] / 1024 ** 3:.3f} GB/s')
+
+ t_size *= 2
+
+ link_to_bandwidth[link] = size_to_bandwidth
+ return link_to_bandwidth
+
+
+class SynGreedySolver(Solver):
+
+ def __init__(self,
+ region_list: List[Region],
+ memory_budget: float = -1.0) -> None:
+ super().__init__(region_list, memory_budget)
+
+ self.best_ts: SynTrainingSimulator = None
+ self._init_state()
+
+ def _init_state(self):
+ """
+ Initialize the solver state when without offloading.
+ """
+
+ ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
+ ts.execute()
+ self._update_state(ts)
+
+ def _call_solver(self):
+ """
+ Call the solver to search an efficient parameter offloading strategy for the linearized graph.
+ The solver adopts greedy algorithm.
+
+ Raises:
+ NotImplementedError: Unable to find a solution for the given memory budget.
+ """
+
+ print("search offloading strategy ......")
+ while self.best_ts.peak_mem > self.memory_budget:
+ offload_region = None
+ best_ts = None
+ max_profit = (0,)
+
+ # search which region should be offloaded,
+ # the last region does not need to be offloaded.
+ for region in self.region_list[:-1]:
+ if region.param_size and not region.need_offload:
+ temp_ts, profit = self._try_to_offload(region)
+ if self._compare_profit(profit, max_profit):
+ offload_region = region
+ max_profit = profit
+ best_ts = temp_ts
+
+ if offload_region is not None and best_ts is not None:
+ offload_region.need_offload = True
+ offload_region.is_syn = True
+ self._update_state(best_ts)
+ else:
+ raise NotImplementedError(
+ f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+
+ def _call_solver_l2l(self):
+ """
+ The layer-wise offload strategy.
+ """
+
+ for region in self.region_list[:-1]:
+ region.need_offload = True
+ region.is_syn = True
+
+ def _try_to_offload(self, offload_region: Region):
+
+ # record previous information
+ orig_need_offload = offload_region.need_offload
+ assert not orig_need_offload
+ offload_region.need_offload = True
+
+ ts, profit = self._eval_one_choice(offload_region)
+
+ # restore previous information
+ offload_region.need_offload = orig_need_offload
+ return ts, profit
+
+ def _eval_one_choice(self, offload_region: Region):
+ """
+ Evaluate the profit of a strategy choice.
+
+ Args:
+ offload_region (Region): the offload region of current choice.
+
+ Returns:
+ SynTrainingSimulator: the training simulator corresponding to the current strategy.
+ tuple: contains memory saving and cost information of the current strategy.
+ """
+
+ ts = SynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
+ ts.execute()
+
+ extra_comm_cost = 2.0 * \
+ ts._get_communication_overhead('h2d', offload_region.param_size)
+ # the shared region needs to be moved twice
+ if offload_region.r_id < offload_region.shared_rid:
+ extra_comm_cost *= 2.0
+ profit = self._compute_offload_profit(
+ ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+
+ return ts, profit
+
+
+class AsynGreedySolver(Solver):
+
+ def __init__(self,
+ region_list: List[Region],
+ memory_budget: float = -1.0,
+ search_window_size: int = 3):
+ super().__init__(region_list, memory_budget)
+
+ self.search_window_size = search_window_size
+ # Records the prefetch execution location of the offloaded region
+ self.region_to_region_map = {}
+ self.best_ts: AsynTrainingSimulator = None
+
+ self._init_state()
+
+ def _init_state(self):
+ """
+ Initialize the solver state when without offloading.
+ """
+
+ ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
+ ts.execute()
+ self._update_state(ts)
+ print("init peak memory", self.best_ts.peak_mem / 1024 ** 2, "MB")
+
+ def _call_solver(self):
+ """
+ Call the solver to search an efficient parameter offloading strategy for the linearized graph.
+ The solver adopts greedy algorithm.
+
+ Raises:
+ NotImplementedError: Unable to find a solution for the given memory budget.
+ """
+
+ print("search for offloading strategy ......")
+ # Records the prefetch execution location of the offloaded region
+ region_to_region_map = {}
+ while self.best_ts.peak_mem > self.memory_budget:
+ region_to_offload = None
+ max_offload_profit = (0,)
+ best_offl_ts = None
+
+ # search which region should be offloaded,
+ # the last region does not need to be offloaded
+ for region in self.region_list[:-1]:
+ if region.param_size and not region.need_offload:
+ max_prefetch_profit = (0,)
+ best_pref_ts = None
+
+ # search when to prefetch the region offloaded
+ for host_region in self.region_list[region.r_id + 1:region.r_id + 1 + self.search_window_size]:
+ if host_region.bwd_prefetch_region is not None:
+ continue
+
+ temp_ts, profit = self._try_to_offload(
+ host_region, region)
+
+ if self._compare_profit(profit, max_prefetch_profit):
+ region_to_region_map[region.r_id] = host_region
+ max_prefetch_profit = profit
+ best_pref_ts = temp_ts
+ if profit[0] == float('inf'):
+ break
+
+ if self._compare_profit(max_prefetch_profit, max_offload_profit):
+ region_to_offload = region
+ max_offload_profit = max_prefetch_profit
+ best_offl_ts = best_pref_ts
+
+ if (region_to_offload is not None) and (best_offl_ts is not None):
+ region_to_offload.need_offload = True
+ if region_to_region_map[region_to_offload.r_id] == region_to_offload:
+ region_to_offload.is_syn = True
+ else:
+ region_to_region_map[region_to_offload.r_id].bwd_prefetch_region = region_to_offload
+ self.region_to_region_map[region_to_offload.r_id] = region_to_region_map[region_to_offload.r_id]
+
+ self._update_state(best_offl_ts)
+
+ elif self.region_to_region_map.__len__() > 0:
+ self._repair_strategy()
+ else:
+ raise NotImplementedError(
+ f"can't find the offload strategy met the memory budget {self.memory_budget / 1024 ** 2} MB, "
+ f"it needs {self.best_ts.peak_mem / 1024 ** 2:.3f} MB at least!")
+
+ region_to_region_map.clear()
+
+ def _try_to_offload(self, host_region: Region, offload_region: Region):
+ """
+ Attempts to offload the region and prefetch it in backward pass.
+ """
+
+ # record previous information
+ orig_prefetch = host_region.bwd_prefetch_region
+ orig_is_syn = offload_region.is_syn
+ orig_need_offload = offload_region.need_offload
+
+ if host_region == offload_region:
+ offload_region.is_syn = True
+ else:
+ host_region.bwd_prefetch_region = offload_region
+ offload_region.need_offload = True
+
+ ts, profit = self._eval_one_choice()
+
+ # restore previous information
+ host_region.bwd_prefetch_region = orig_prefetch
+ offload_region.is_syn = orig_is_syn
+ offload_region.need_offload = orig_need_offload
+
+ return ts, profit
+
+ def _try_convert_to_syn_upload(self, host_region: Region, offload_region: Region):
+ """
+ Attempts to convert asynchronous prefetch into synchronous upload operations.
+ """
+
+ # record previous information
+ orig_prefetch = host_region.bwd_prefetch_region
+ orig_is_syn = offload_region.is_syn
+ assert orig_prefetch is not None and not orig_is_syn
+
+ host_region.bwd_prefetch_region = None
+ offload_region.is_syn = True
+
+ ts, profit = self._eval_one_choice()
+
+ # restore previous information
+ host_region.bwd_prefetch_region = orig_prefetch
+ offload_region.is_syn = orig_is_syn
+
+ return ts, profit
+
+ def _repair_strategy(self):
+ """
+ Repair offload strategy.
+ It attempts to convert asynchronous prefetch into synchronous upload operations and selects the best one.
+ The repair process does not end until peak memory is reduced or there is no asynchronous prefetch operation.
+ """
+ print("repair strategy ......")
+
+ peak_mem_saving = 0
+ while len(self.region_to_region_map) and peak_mem_saving <= 0:
+
+ max_profit = (0,)
+ best_ts = None
+ undo_host_region = None
+ undo_offload_region = None
+
+ for offload_region_id, host_region in self.region_to_region_map.items():
+ offload_region = self.region_list[offload_region_id]
+ assert host_region.bwd_prefetch_region == offload_region
+ assert offload_region.need_offload
+ assert not offload_region.is_syn
+
+ ts, profit = self._try_convert_to_syn_upload(host_region,
+ offload_region)
+
+ if self._compare_profit(profit, max_profit):
+ undo_host_region = host_region
+ undo_offload_region = offload_region
+ max_profit = profit
+ best_ts = ts
+
+ if best_ts is None:
+ raise NotImplementedError('repair error!')
+
+ assert not undo_offload_region.is_syn
+ undo_offload_region.is_syn = True
+ undo_host_region.bwd_prefetch_region = None
+
+ peak_mem_saving = self.best_ts.peak_mem - best_ts.peak_mem
+
+ self._update_state(best_ts)
+ self.region_to_region_map.pop(undo_offload_region.r_id)
+
+ return best_ts
+
+ def _eval_one_choice(self):
+ """
+ Evaluate the profit of a strategy choice.
+
+ Returns:
+ AsynTrainingSimulator: the training simulator corresponding to the current strategy.
+ tuple: contains memory saving and cost information of the current strategy.
+ """
+
+ ts = AsynTrainingSimulator(self.region_list, self.comp_power, self.link_to_bandwidth)
+ ts.execute()
+
+ extra_comm_cost = max(ts.iter_end_time - self.best_ts.iter_end_time, 0)
+ profit = self._compute_offload_profit(
+ ts.total_mem_saving, self.best_ts.peak_mem - ts.peak_mem, extra_comm_cost)
+
+ return ts, profit
+
+
+class SolverFactory:
+ solvers: Dict[str, Type[Solver]] = {
+ 'syn': SynGreedySolver,
+ 'asyn': AsynGreedySolver
+ }
+
+ @staticmethod
+ def create(solver_name: str) -> Type[Solver]:
+ if solver_name not in SolverFactory.solvers:
+ raise TypeError(f"Unknown parameter offload policy {solver_name}")
+ return SolverFactory.solvers[solver_name]
+
+ @staticmethod
+ def get_solver_names():
+ return tuple(SolverFactory.solvers.keys())
diff --git a/colossalai/auto_parallel/offload/training_simulator.py b/colossalai/auto_parallel/offload/training_simulator.py
new file mode 100644
index 000000000000..f277c183a912
--- /dev/null
+++ b/colossalai/auto_parallel/offload/training_simulator.py
@@ -0,0 +1,458 @@
+import bisect
+from typing import List, Dict
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+
+from torch.fx.node import Node
+
+from .region import Region
+from .util import *
+
+
+@dataclass
+class ExecutionPeriod:
+ start_time: float = 0
+ end_time: float = 0
+
+
+class TrainingSimulator(ABC):
+ """
+ The Training Simulator is used to simulate the training process.
+ It records computation, communication, and runtime memory during forward and backward passes.
+
+ Args:
+ region_list (List[Region]): represents the linearized DNN computing graph.
+ comp_power (float): the NVIDIA GPU FP16 compuing power.
+ link_to_bw (Dict[str, Dict[float, float]]): communication links and the corresponding bandwidth.
+ """
+
+ def __init__(self,
+ region_list: List[Region],
+ comp_power: float,
+ link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ self.region_list = region_list
+ self.region_num = len(region_list)
+
+ self.runtime_mem: int = 0
+ self.peak_mem: int = 0
+ self.total_mem_saving: int = 0
+
+ self.fwd_node_mem: Dict[Node, float] = {}
+ self.bwd_node_mem: Dict[Node, float] = {}
+
+ # Node dependencies in backward pass
+ self.bwd_node_deps: Dict[Node, int] = {}
+
+ self.comp_power: float = comp_power
+ self.link_to_bandwidth: Dict[str, Dict[float, float]] = link_to_bw
+
+ @abstractmethod
+ def execute(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _eval_fwd_mem_per_region(self, region: Region):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _eval_bwd_mem_per_region(self, region: Region):
+ raise NotImplementedError
+
+ def _get_bandwidth(self, link: str, comm_volumn: float) -> float:
+ """
+ Get the data transfer bandwidth.
+
+ Args:
+ link (str): the data transfer link.
+ comm_volumn (float): the amount of data transferred.
+
+ Returns:
+ float: the data transfer bandwidth.
+ """
+
+ assert len(self.link_to_bandwidth)
+ if link not in self.link_to_bandwidth:
+ raise TypeError(f"Unknown data transfer link {link}")
+
+ # size_list = sorted(list(map(float, self.link_to_bandwidth[link].keys())))
+ size_list = sorted(self.link_to_bandwidth[link].keys())
+ d_idx = bisect.bisect_left(size_list, comm_volumn)
+ return self.link_to_bandwidth[link][size_list[d_idx]]
+
+ def _get_communication_overhead(self, link: str, comm_volumn: float) -> float:
+ return comm_volumn / self._get_bandwidth(link, comm_volumn)
+
+ def _get_computing_overhead(self, flop: float) -> float:
+ return flop / self.comp_power
+
+
+class SynTrainingSimulator(TrainingSimulator):
+
+ def __init__(self,
+ region_list: List[Region],
+ comp_power: float,
+ link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ super().__init__(region_list, comp_power, link_to_bw)
+
+ def execute(self):
+ """
+ Simulate synchronous training process.
+ """
+
+ for reg in self.region_list:
+ self._eval_fwd_mem_per_region(reg)
+
+ for reg in self.region_list.__reversed__():
+ self._eval_bwd_mem_per_region(reg)
+
+ def _eval_fwd_mem_per_region(self, region: Region):
+ """
+ Evaluate the runtime and peak memory when the forward execution reaches the current region.
+ """
+
+ # upload parameters of the current region
+ if requires_upload_p_in_fwd(self.region_list[region.shared_rid]):
+ self.runtime_mem += region.param_size
+
+ for node in region.nodes:
+ self.runtime_mem += calculate_fwd_tmp(node) + \
+ calculate_fwd_out(node)
+ self.fwd_node_mem[node] = self.runtime_mem
+ self.peak_mem = max(self.runtime_mem, self.peak_mem)
+ self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
+
+ if region.need_offload:
+ self.runtime_mem -= region.param_size
+
+ def _eval_bwd_mem_per_region(self, region: Region):
+ """
+ Evaluate the runtime and peak memory when the backward execution reaches the current region.
+ """
+
+ # upload parameters of the current region
+ if region.need_offload:
+ self.runtime_mem += region.param_size
+
+ # add the gradient of the parameter
+ if region.r_id < region.shared_rid:
+ # gradient accumulation is required for shared parameters
+ self.runtime_mem += 2.0 * region.param_size
+ else:
+ self.runtime_mem += region.param_size
+
+ for node in region.nodes.__reversed__():
+
+ self.runtime_mem -= calculate_fwd_out(node)
+ self.runtime_mem += node.meta['bwd_mem_tmp'] + \
+ node.meta['bwd_mem_out']
+ self.peak_mem = max(self.runtime_mem, self.peak_mem)
+
+ # The memory savings of a node may be negative due to parameter prefetch.
+ self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
+ self.bwd_node_mem[node] = self.runtime_mem
+
+ self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
+ calculate_fwd_tmp(node))
+
+ # free bwd_mem_out
+ self.bwd_node_deps[node] = len(node.all_input_nodes)
+ for user_node in node.users:
+ if user_node in self.bwd_node_deps:
+ self.bwd_node_deps[user_node] -= 1
+ if self.bwd_node_deps[user_node] <= 0:
+ self.runtime_mem -= user_node.meta['bwd_mem_out']
+
+ if self.runtime_mem < 0:
+ raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!")
+
+ # release parameter and offload gradient in region
+ if region.r_id == region.shared_rid:
+ self.runtime_mem -= 2.0 * region.param_size
+ elif region.r_id < region.shared_rid:
+ self.runtime_mem -= 3.0 * region.param_size
+ elif self.region_list[region.shared_rid].need_offload:
+ self.runtime_mem -= region.param_size
+
+
+class AsynTrainingSimulator(TrainingSimulator):
+
+ def __init__(self,
+ region_list: List[Region],
+ comp_power: float,
+ link_to_bw: Dict[str, Dict[float, float]]) -> None:
+ super().__init__(region_list, comp_power, link_to_bw)
+
+ self.iter_end_time: int = 0
+ # the last computation execution period
+ self.last_comp: ExecutionPeriod = ExecutionPeriod(
+ start_time=0, end_time=0)
+ # the last parameter prefetch execution period
+ self.last_h2d: ExecutionPeriod = ExecutionPeriod(
+ start_time=0, end_time=0)
+ # the last gradient offload execution period
+ self.last_d2h: ExecutionPeriod = ExecutionPeriod(
+ start_time=0, end_time=0)
+ # the forward computation execution period of the region
+ self.fwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ # the forward parameter prefetch execution period of the region
+ self.fwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ # the backward computation execution period of the region
+ self.bwd_reg_to_comp: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ # the backward parameter prefetch execution period of the region
+ self.bwd_reg_to_pref: OrderedDict[int, ExecutionPeriod] = OrderedDict()
+ # the gradient offload execution period of the region
+ # which is divided into those that are waiting and those that have been released
+ self.bwd_reg_to_offl_waiting: OrderedDict[int,
+ ExecutionPeriod] = OrderedDict()
+ self.bwd_reg_to_offl_freed: OrderedDict[int,
+ ExecutionPeriod] = OrderedDict()
+ # the region buffer, which records regions that are offloaded but not released
+ self.reg_buffer_to_free: List[int] = []
+
+ # node dependencies in backward pass
+ self.bwd_node_deps: Dict[Node, int] = {}
+
+ # the region execution flow,
+ # where fwd_reg_flow[i,j] denotes whether the parameters of j-th region are in the GPU
+ # when the execution reaches the i-th region.
+ self.fwd_reg_flow = torch.zeros(
+ (self.region_num, self.region_num)).bool()
+ self.bwd_reg_flow = torch.zeros(
+ (self.region_num, self.region_num)).bool()
+
+ def execute(self):
+ """
+ Simulate asynchronous training process.
+ In forward pass, parameter prefetching is advanced by one region.
+ In backward pass, parameter prefetching is executed at the specified location,
+ and gradient offloading is urgent.
+ """
+
+ for reg in self.region_list:
+ if reg.param_size and reg.r_id < self.region_num - 1:
+ for nr in self.region_list[reg.r_id + 1:]:
+ if nr.param_size and requires_upload_p_in_fwd(self.region_list[nr.shared_rid]):
+ reg.fwd_prefetch_region = nr
+ break
+ self._eval_fwd_cost_per_region(reg)
+ self._eval_fwd_mem_per_region(reg)
+
+ for reg in self.region_list.__reversed__():
+ self._eval_bwd_cost_per_region(reg)
+ self._eval_bwd_mem_per_region(reg)
+
+ # release remaining grads
+ for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
+ self.bwd_reg_to_offl_freed[reg_id] = offl_exec
+ self.runtime_mem -= self.region_list[reg_id].param_size
+ self.bwd_reg_to_offl_waiting.clear()
+
+ self.iter_end_time = max(
+ self.last_comp.end_time, self.last_d2h.end_time)
+
+ def _insert_h2d_exec(self, region: Region, is_fwd: bool = True):
+ """
+ Insert parameter prefetch execution period of the current region to the end of the h2d stream
+ """
+
+ pref_start_time = max(self.last_h2d.end_time, self.last_comp.end_time)
+ pref_end_time = pref_start_time + \
+ 2.0 * self._get_communication_overhead('h2d', region.param_size)
+ pref_ep = ExecutionPeriod(
+ start_time=pref_start_time, end_time=pref_end_time)
+ if is_fwd:
+ self.fwd_reg_to_pref[region.r_id] = pref_ep
+ else:
+ self.bwd_reg_to_pref[region.r_id] = pref_ep
+ self.last_h2d = pref_ep
+
+ def _insert_comp_exec(self, region: Region, is_fwd: bool = True):
+ """
+ Insert computation execution period of the current region to the end of the computing stream
+ """
+
+ if is_fwd:
+ reg_to_comp = self.fwd_reg_to_comp
+ reg_to_pref = self.fwd_reg_to_pref
+ flop_key = 'fwd_flop'
+ else:
+ reg_to_comp = self.bwd_reg_to_comp
+ reg_to_pref = self.bwd_reg_to_pref
+ flop_key = 'bwd_flop'
+ comp_start_time = max(self.last_comp.end_time, reg_to_pref.get(
+ region.r_id, ExecutionPeriod(0, 0)).end_time)
+ comp_end_time = comp_start_time + \
+ sum([self._get_computing_overhead(node.meta.get(flop_key, 0))
+ for node in region.nodes])
+ comp_ep = ExecutionPeriod(
+ start_time=comp_start_time, end_time=comp_end_time)
+ reg_to_comp[region.r_id] = comp_ep
+ self.last_comp = comp_ep
+
+ def _insert_d2h_exec(self, region: Region):
+ """
+ Insert gradient offload execution period of the current region to the end of the d2h stream
+ """
+
+ offl_start_time = max(self.last_d2h.end_time, self.last_comp.end_time)
+ offl_end_time = offl_start_time + \
+ self._get_communication_overhead('d2h', region.param_size)
+ offl_ep = ExecutionPeriod(
+ start_time=offl_start_time, end_time=offl_end_time)
+ self.bwd_reg_to_offl_waiting[region.r_id] = offl_ep
+ self.last_d2h = offl_ep
+
+ def _eval_fwd_cost_per_region(self, region: Region):
+ """
+ Evaluate computation and communication execution period of the region in forward pass.
+ """
+
+ # upload parameters of the first region
+ if region.r_id == 0:
+ self._insert_h2d_exec(region)
+
+ # prefetch parameters of the next region
+ fwd_prefetch_region = region.fwd_prefetch_region
+ if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
+ self._insert_h2d_exec(fwd_prefetch_region)
+
+ # execute computation
+ self._insert_comp_exec(region)
+
+ def _eval_fwd_mem_per_region(self, region: Region):
+ """
+ Evaluate the runtime and peak memory when the forward execution reaches the current region.
+ """
+
+ # upload parameters of the current region
+ if region.r_id <= 0:
+ self.runtime_mem += region.param_size
+ self.fwd_reg_flow[region.r_id, region.r_id] = True
+ else:
+ self.fwd_reg_flow[region.r_id] = self.fwd_reg_flow[region.r_id - 1]
+ self.fwd_reg_flow[region.r_id,
+ self.reg_buffer_to_free] = False
+ self.reg_buffer_to_free.clear()
+
+ # prefetch parameters of the next region
+ fwd_prefetch_region = region.fwd_prefetch_region
+ if fwd_prefetch_region and requires_upload_p_in_fwd(self.region_list[fwd_prefetch_region.shared_rid]):
+ self.runtime_mem += fwd_prefetch_region.param_size
+ self.fwd_reg_flow[region.r_id,
+ fwd_prefetch_region.r_id] = True
+
+ for node in region.nodes:
+ self.runtime_mem += calculate_fwd_tmp(node) + \
+ calculate_fwd_out(node)
+ self.peak_mem = max(self.runtime_mem, self.peak_mem)
+
+ self.total_mem_saving += node.node_info.runtime_fwd_mem - self.runtime_mem
+ self.fwd_node_mem[node] = self.runtime_mem
+
+ if region.need_offload:
+ self.runtime_mem -= region.param_size
+
+ assert len(
+ self.reg_buffer_to_free) <= 1, f'{len(self.reg_buffer_to_free)}'
+ self.reg_buffer_to_free.append(region.r_id)
+
+ def _eval_bwd_cost_per_region(self, region: Region):
+ """
+ Evaluate computation and communication execution period of the region in backward pass.
+ """
+
+ # upload parameters of the current region
+ if region.is_syn:
+ assert region.need_offload
+ self._insert_h2d_exec(region, is_fwd=False)
+
+ # prefetch parameters of the region choiced, which is parallel to computation
+ if region.bwd_prefetch_region is not None:
+ self._insert_h2d_exec(region.bwd_prefetch_region, is_fwd=False)
+
+ # execute computation
+ self._insert_comp_exec(region, is_fwd=False)
+
+ # offload gradient
+ if requires_offload_g_in_bwd(region):
+ self._insert_d2h_exec(region)
+
+ assert len(self.reg_buffer_to_free) == 0
+ for reg_id, offl_exec in self.bwd_reg_to_offl_waiting.items():
+ if offl_exec.end_time >= self.last_comp.start_time:
+ break
+ self.reg_buffer_to_free.append(reg_id)
+ self.bwd_reg_to_offl_freed[reg_id] = offl_exec
+
+ for reg_id in self.reg_buffer_to_free:
+ self.bwd_reg_to_offl_waiting.pop(reg_id)
+
+ def _eval_bwd_mem_per_region(self, region: Region):
+ """
+ Evaluate the runtime and peak memory when the backward execution reaches the current region.
+ """
+
+ if region.r_id + 1 < self.region_num:
+ self.bwd_reg_flow[region.r_id] = self.bwd_reg_flow[region.r_id + 1]
+ else:
+ self.bwd_reg_flow[region.r_id] = self.fwd_reg_flow[-1]
+ self.bwd_reg_flow[region.r_id,
+ self.reg_buffer_to_free] = False
+
+ # free gradients in the buffer
+ while len(self.reg_buffer_to_free):
+ reg_id = self.reg_buffer_to_free.pop(0)
+ self.runtime_mem -= self.region_list[reg_id].param_size
+
+ # upload parameters of the current region
+ if region.is_syn:
+ self.runtime_mem += region.param_size
+ self.bwd_reg_flow[region.r_id, region.r_id] = True
+
+ # prefetch parameters of the region choiced
+ bwd_prefetch_region = region.bwd_prefetch_region
+ if bwd_prefetch_region:
+ self.runtime_mem += bwd_prefetch_region.param_size
+ self.bwd_reg_flow[region.r_id,
+ bwd_prefetch_region.r_id] = True
+
+ # add the gradient of the parameter
+ if region.r_id < region.shared_rid:
+ # gradient accumulation is required for shared parameters
+ self.runtime_mem += 2.0 * region.param_size
+ else:
+ self.runtime_mem += region.param_size
+
+ for node in region.nodes.__reversed__():
+
+ self.runtime_mem -= calculate_fwd_out(node)
+ self.runtime_mem += node.meta['bwd_mem_tmp'] + \
+ node.meta['bwd_mem_out']
+ self.peak_mem = max(self.runtime_mem, self.peak_mem)
+
+ # The memory savings of a node may be negative due to parameter prefetch.
+ self.total_mem_saving += node.node_info.runtime_bwd_mem - self.runtime_mem
+
+ self.bwd_node_mem[node] = self.runtime_mem
+
+ self.runtime_mem -= (node.meta['bwd_mem_tmp'] +
+ calculate_fwd_tmp(node))
+
+ # free bwd_mem_out
+ self.bwd_node_deps[node] = len(node.all_input_nodes)
+ for user_node in node.users:
+ if user_node in self.bwd_node_deps:
+ self.bwd_node_deps[user_node] -= 1
+ if self.bwd_node_deps[user_node] <= 0:
+ self.runtime_mem -= user_node.meta['bwd_mem_out']
+
+ if self.runtime_mem < 0:
+ raise ValueError(f"region id: {region.r_id}, node name: {node.name}, "
+ f"runtime_mem: {self.runtime_mem / 1024 ** 2:.3f}MB ---"
+ f"runtime memory computed less than 0, which is miscalculated!")
+
+ # release parameters of the region
+ if requires_release_p_in_bwd(self.region_list[region.shared_rid]):
+ self.runtime_mem -= region.param_size
diff --git a/colossalai/auto_parallel/offload/util.py b/colossalai/auto_parallel/offload/util.py
new file mode 100644
index 000000000000..a99c4eb20225
--- /dev/null
+++ b/colossalai/auto_parallel/offload/util.py
@@ -0,0 +1,90 @@
+from dataclasses import dataclass
+from typing import List
+import torch
+from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp
+
+from .region import Region
+
+
+@dataclass
+class NodeInfo:
+ node_id: int = 0
+ runtime_fwd_mem: float = 0
+ runtime_bwd_mem: float = 0
+
+class NvDevicePower:
+ """
+ NVIDIA GPU computing performance (TFLOPs).
+ """
+
+ RTX3080_FP16 = 70
+ RTX3080_FP32 = 34.1
+
+ RTX3090_FP16 = 71
+ RTX3090_FP32 = 35.7
+
+ V100_FP16 = 31.4
+ V100_FP32 = 15.7
+
+ A100_FP16 = 78
+ A100_FP32 = 19.5
+
+
+class GlobalRuntimeInfo:
+ h2d_stream = torch.cuda.Stream()
+ d2h_stream = torch.cuda.Stream()
+ fwd_prefetch_event_map = {}
+ bwd_prefetch_event_map = {}
+ region_list = []
+
+
+def compute_act_peak_mem(region_list: List[Region]) -> float:
+ act_peak_mem = 0
+ runtime_mem = 0
+ # forward
+ for region in region_list:
+ for node in region.nodes:
+ runtime_mem = runtime_mem + \
+ calculate_fwd_tmp(node) + calculate_fwd_out(node)
+ act_peak_mem = max(runtime_mem, act_peak_mem)
+ # backward
+ bwd_deps = {}
+ for region in region_list.__reversed__():
+ for node in region.nodes.__reversed__():
+ runtime_mem -= calculate_fwd_out(node)
+ runtime_mem = runtime_mem + \
+ node.meta['bwd_mem_tmp'] + node.meta['bwd_mem_out']
+
+ act_peak_mem = max(runtime_mem, act_peak_mem)
+
+ runtime_mem = runtime_mem - \
+ node.meta['bwd_mem_tmp'] - calculate_fwd_tmp(node)
+
+ # free bwd_mem_out
+ bwd_deps[node] = len(node.all_input_nodes)
+ for user_node in node.users:
+ if user_node in bwd_deps:
+ bwd_deps[user_node] -= 1
+ if bwd_deps[user_node] <= 0:
+ runtime_mem -= user_node.meta['bwd_mem_out']
+
+ return act_peak_mem
+
+def compute_max_param_mem(region_list: List[Region]) -> float:
+ return max(region.param_size for region in region_list)
+
+def compute_total_param_mem(region_list: List[Region]) -> float:
+ return sum(region.param_size for region in region_list if region.r_id <= region.shared_rid)
+
+def requires_upload_p_in_fwd(shared_reg: Region):
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
+
+def requires_release_p_in_bwd(shared_reg: Region):
+ return (shared_reg.r_id >= shared_reg.shared_rid) or (
+ shared_reg.r_id < shared_reg.shared_rid and shared_reg.need_offload)
+
+def requires_offload_g_in_bwd(region: Region):
+ return region.param_size and (region.r_id <= region.shared_rid)
+
+
diff --git a/colossalai/auto_parallel/passes/constants.py b/colossalai/auto_parallel/passes/constants.py
index b86088474644..485a87492f4c 100644
--- a/colossalai/auto_parallel/passes/constants.py
+++ b/colossalai/auto_parallel/passes/constants.py
@@ -6,3 +6,8 @@
torch.nn.ReLU,
torch.nn.Softmax,
]
+
+# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
+# This list could be extended if any other method has the same
+# argument style as view and reshape.
+SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 7f2aac42b7f8..9d83f105748b 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
+ if 'activation_checkpoint' in user_node.meta:
+ shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
+
+ if 'activation_checkpoint' in node.meta:
+ comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
+ return gm
+
+
+def _act_annotataion_pass(gm: torch.fx.GraphModule):
+ """
+ This pass is used to add the act annotation to the new inserted nodes.
+ """
+ mod_graph = gm.graph
+ nodes = tuple(mod_graph.nodes)
+
+ for node in nodes:
+ if not hasattr(node.meta, 'activation_checkpoint'):
+ from .runtime_preparation_pass import size_processing
+
+ user_act_annotation = -1
+ input_act_annotation = -1
+ for user_node in node.users.keys():
+ if 'activation_checkpoint' in user_node.meta:
+ user_act_annotation = user_node.meta['activation_checkpoint']
+ break
+ for input_node in node._input_nodes.keys():
+ if 'activation_checkpoint' in input_node.meta:
+ input_act_annotation = input_node.meta['activation_checkpoint']
+ break
+ if user_act_annotation == input_act_annotation and user_act_annotation != -1:
+ node.meta['activation_checkpoint'] = user_act_annotation
+
return gm
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index f9b89026393d..e63bfdfe730c 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -19,6 +19,8 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
+from .constants import SHAPE_ARGUMENT_OPS
+
shape_consistency_manager = ShapeConsistencyManager()
@@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
return size
-def _solution_annotatation(gm: torch.fx.GraphModule,
- solution: List[int],
- strategies_constructor: StrategiesConstructor = None):
+def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+ strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
required in runtime into graph as placeholder nodes.
"""
mod_graph = gm.graph
- # TODO: In future PR, strategies_constructor should be a required argument,
- # instead of optional argument. This is because we don't need to consider nodes with
- # no strategy in runtime preparation pass.
- if strategies_constructor is not None:
- nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
- no_strategy_nodes = strategies_constructor.no_strategy_nodes
- else:
- nodes = tuple(mod_graph.nodes)
- no_strategy_nodes = []
+
+ nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+ no_strategy_nodes = strategies_constructor.no_strategy_nodes
# the dict to get origin sharding spec of node
origin_node_sharding_spec_dict = {}
@@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
target_sharding_specs.append(target_sharding_spec)
sharding_spec_convert_dict[index] = target_sharding_specs
setattr(node, 'target_sharding_specs', target_sharding_specs)
+
# the get_attr node strategy is kind of pending strategy, which means we will change it
# to the same strategy of the user node.
if node.op == 'get_attr':
@@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
-def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
In the auto parallel system, tensors may get shard on different devices, so the size of tensors
need to be converted to the size of original tensor and managed by the users, such as torch.view,
@@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
nodes = tuple(mod_graph.nodes)
node_pairs = {}
+ # DeviceMesh information instructs the scaling of the size value
+ device_mesh_info = {}
+ for dim, dim_size in enumerate(device_mesh.mesh_shape):
+ device_mesh_info[dim] = dim_size
+
+ def _extract_target_dim(node):
+ '''
+ A helper function to etract the target dimension from size node.
+ There are two usages of torch.Tensor.size:
+ 1. tensor.size()
+ 2. tensor.size(dim)
+
+ If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
+ Otherwise, the output will be in type of torch.Size and this function will return None.
+ '''
+ target_dim = None
+ if len(node.args) > 1:
+ target_dim = node.args[1]
+ if target_dim < 0:
+ target_dim += node.args[0]._meta_data.dim()
+ return target_dim
+
+ def _post_processing(node, size_processing_node):
+ '''
+ This function is used to process the dependency between the size node and its users after
+ inserting the size_process_node.
+ '''
+ # store original node and processing node pair in node_pairs dictioanry
+ # It will be used to replace the original node with processing node in slice object
+ node_pairs[node] = size_processing_node
+ size_processing_node._meta_data = node._meta_data
+ if 'activation_checkpoint' in node.meta:
+ size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
+
+ user_list = list(node.users.keys())
+ for user in user_list:
+ if user == size_processing_node:
+ continue
+ new_args = list(user.args)
+ new_kwargs = dict(user.kwargs)
+ # the origin node may be a positional argument or key word argument of user node
+ if node in new_args:
+ # substitute the origin node with size_processing_node
+ new_args[new_args.index(node)] = size_processing_node
+ user.args = tuple(new_args)
+ elif str(node) in new_kwargs:
+ # substitute the origin node with size_processing_node
+ new_kwargs[str(node)] = size_processing_node
+ user.kwargs = new_kwargs
+
+ def _update_slice_object_args(slice_object):
+ '''
+ This function is used to update the slice object argument list.
+ If the slice object contains the Node argument, then the size node will be replaced with
+ '''
+ if isinstance(slice_object, slice):
+ start = slice_object.start
+ stop = slice_object.stop
+ step = slice_object.step
+ if start in node_pairs:
+ start = node_pairs[start]
+ if stop in node_pairs:
+ stop = node_pairs[stop]
+ if step in node_pairs:
+ step = node_pairs[step]
+ return slice(start, stop, step)
+ elif isinstance(slice_object, int):
+ if slice_object in node_pairs:
+ return node_pairs[slice_object]
+ else:
+ return slice_object
+ else:
+ raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
+
for node in nodes:
if node.op == 'call_method' and node.target == 'size':
@@ -154,47 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
sharding_spec = node.args[0].sharding_spec
dim_partition_dict = sharding_spec.dim_partition_dict
- # there are two usages of torch.Tensor.size:
- # tensor.size()
- # tensor.size(dim)
- # if a target_dim is assigned, then the output will be
- # in type of int, instead of torch.Size
- target_dim = None
- if len(node.args) > 1:
- target_dim = node.args[1]
- if target_dim < 0:
- target_dim += node.args[0]._meta_data.dim()
-
- # DeviceMesh information instructs the scaling of the size value
- device_mesh_info = {}
- for dim, dim_size in enumerate(device_mesh.mesh_shape):
- device_mesh_info[dim] = dim_size
+ target_dim = _extract_target_dim(node)
+ # insert size_processing node
with mod_graph.inserting_after(node):
size_processing_node = mod_graph.create_node('call_function',
size_processing,
args=(node, dim_partition_dict, device_mesh_info,
target_dim, node.name))
- # store original node and processing node pair in node_pairs dictioanry
- # It will be used to replace the original node with processing node in slice object
- node_pairs[node] = size_processing_node
- size_processing_node._meta_data = node._meta_data
-
- user_list = list(node.users.keys())
- for user in user_list:
- if user == size_processing_node:
- continue
- new_args = list(user.args)
- new_kwargs = dict(user.kwargs)
- # the origin node may be a positional argument or key word argument of user node
- if node in new_args:
- # substitute the origin node with size_processing_node
- new_args[new_args.index(node)] = size_processing_node
- user.args = tuple(new_args)
- elif str(node) in new_kwargs:
- # substitute the origin node with size_processing_node
- new_kwargs[str(node)] = size_processing_node
- user.kwargs = new_kwargs
+ _post_processing(node, size_processing_node)
if node.op == 'call_function' and node.target == operator.getitem:
@@ -215,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# In this pass, we need process the last two cases because
# node arguments may potentially appear in these cases.
if isinstance(getitem_index, slice):
- new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
- if getitem_index.start in node_pairs:
- new_start = node_pairs[getitem_index.start]
- elif getitem_index.stop in node_pairs:
- new_stop = node_pairs[getitem_index.stop]
- elif getitem_index.step in node_pairs:
- new_step = node_pairs[getitem_index.step]
- new_slice_item = slice(new_start, new_stop, new_step)
+ new_slice_item = _update_slice_object_args(getitem_index)
new_args = (node.args[0], new_slice_item)
node.args = new_args
@@ -235,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if slice_item is None:
new_slice_items.append(None)
continue
-
- new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
-
- if slice_item.start in node_pairs:
- new_start = node_pairs[slice_item.start]
- elif slice_item.stop in node_pairs:
- new_stop = node_pairs[slice_item.stop]
- elif slice_item.step in node_pairs:
- new_step = node_pairs[slice_item.step]
- new_slice_item = slice(new_start, new_stop, new_step)
+ new_slice_item = _update_slice_object_args(slice_item)
new_slice_items.append(new_slice_item)
new_args = (node.args[0], tuple(new_slice_items))
@@ -253,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
return gm
-def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt the distributed tensor layout.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
- for node in nodes:
- # skip the placeholder node added in _solution_annotation pass
- if not hasattr(node, 'sharding_spec'):
- continue
-
- def _process_sharding_spec(sharding_spec):
- if isinstance(sharding_spec, ShardingSpec):
- dim_partition_dict = sharding_spec.dim_partition_dict
- device_mesh = sharding_spec.device_mesh
- return dim_partition_dict, device_mesh
- if sharding_spec is None:
- return None, None
- assert isinstance(sharding_spec,
- (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
-
- device_mesh = sharding_spec[0].device_mesh
- dim_partition_dict = []
- for element in sharding_spec:
- dim_partition_dict.append(_process_sharding_spec(element))
- return dim_partition_dict, sharding_spec
-
- output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
+ def _extract_info_from_sharding_spec(sharding_spec):
+ '''
+ This function is used to extract the dim_partition_dict and device_mesh from
+ sharding spec instance or a list of sharding spec.
+ '''
+ if isinstance(sharding_spec, ShardingSpec):
+ dim_partition_dict = sharding_spec.dim_partition_dict
+ device_mesh = sharding_spec.device_mesh
+ return dim_partition_dict, device_mesh
+ if sharding_spec is None:
+ return None, None
+ assert isinstance(sharding_spec,
+ (tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
+
+ device_mesh = sharding_spec[0].device_mesh
+ dim_partition_dict = []
+ for element in sharding_spec:
+ dim_partition_dict.append(_extract_info_from_sharding_spec(element))
+ return dim_partition_dict, sharding_spec
+
+ def _process_node_arguments(node):
new_args = []
+ for arg in node.args:
+ # There are two args style:
+ # 1. (input, *shape)
+ # 2. (input, shape)
+ # We will extract the elements from shape and add them into the new_args
+ # Finally, the args style of new_args will be unified to (input, *shape)
+ if isinstance(arg, Node):
+ if isinstance(arg._meta_data, (tuple, list)):
+ new_args.extend(arg._meta_data)
+ elif isinstance(arg._meta_data, int):
+ new_args.append(arg._meta_data)
+ else:
+ new_args.append(arg)
+ else:
+ assert isinstance(arg,
+ (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
+ if isinstance(arg, (tuple, list)):
+ new_args.extend(arg)
+ else:
+ new_args.append(arg)
+ return new_args
+
+ def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
+ new_args = _process_node_arguments(node)
+ if node.op == 'call_method':
+ args_to_process = list(new_args[1:])
+ else:
+ args_to_process = list(new_args)
+ for dim, shard_dims in dim_partition_dict.items():
+ total_shard_size = 1
+ for shard_dim in shard_dims:
+ total_shard_size *= device_mesh.shape[shard_dim]
+
+ # we will skip the dim with -1 value
+ if args_to_process[dim] == -1:
+ continue
+ else:
+ # TODO: add assertion here to make sure the dim size is divisible by total_shard_size
+ args_to_process[dim] //= total_shard_size
+
+ args_to_process = tuple(args_to_process)
if node.op == 'call_method':
- method = getattr(node.args[0]._meta_data.__class__, node.target)
- # process the node with (input, *shape) style args
- if method in (torch.Tensor.view, torch.Tensor.reshape):
-
- for arg in node.args:
- if isinstance(arg, Node):
- if isinstance(arg._meta_data, (int, tuple, list)):
- new_args.append(arg._meta_data)
- else:
- new_args.append(arg)
- else:
- assert isinstance(
- arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
- new_args.append(arg)
-
- for dim, shard_dims in output_dim_partition_dict.items():
- total_shard_size = 1
- for shard_dim in shard_dims:
- total_shard_size *= device_mesh.shape[shard_dim]
- # There are two ways to use torch.view:
- # 1. torch.view(input, *shape)
- # 2. torch.view(input, shape)
- if isinstance(new_args[1], int):
- # we will skip the dim with -1 value
- if new_args[dim + 1] == -1:
- continue
- else:
- new_args[dim + 1] //= total_shard_size
- else:
- new_args[1] = list(new_args[1])
- # we will skip the dim with -1 value
- if new_args[1][dim] == -1:
- continue
- else:
- new_args[1][dim] //= total_shard_size
- node.args = tuple(new_args)
+ new_args = (new_args[0],) + args_to_process
+ else:
+ new_args = args_to_process
+ node.args = new_args
+
+ def _filter_node_with_shape_args(node):
+ if node.op == 'call_method':
+ target = getattr(node.args[0]._meta_data.__class__, node.target)
elif node.op == 'call_function':
target = node.target
- # process the node with (input, torch.Size) style args
- if target in (torch.reshape,):
- for arg in node.args:
- if isinstance(arg, Node):
- if isinstance(arg._meta_data, (tuple, list)):
- new_args.append(list(arg._meta_data))
- else:
- new_args.append(arg)
- else:
- assert isinstance(
- arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
- new_args.append(list(arg))
-
- for dim, shard_dims in output_dim_partition_dict.items():
- # we will skip the dim with -1 value
- if new_args[1][dim] == -1:
- continue
- total_shard_size = 1
- for shard_dim in shard_dims:
- total_shard_size *= device_mesh.shape[shard_dim]
- new_args[1][dim] //= total_shard_size
- node.args = tuple(new_args)
+ else:
+ target = None
+
+ if target in SHAPE_ARGUMENT_OPS:
+ return True
+ return False
+
+ for node in nodes:
+ # skip the placeholder node added in _solution_annotation pass
+ if not hasattr(node, 'sharding_spec'):
+ continue
+
+ output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
+ if _filter_node_with_shape_args(node):
+ _scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
return gm
-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@@ -359,6 +386,50 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
nodes = tuple(mod_graph.nodes)
# This stream is created for overlaping the communication and computation.
reduction_stream = torch.cuda.Stream()
+
+ def _add_hook_for_grad_communication(node, param):
+
+ comm_actions = node.best_strategy.communication_actions
+
+ def _filter_param_to_hook(node, op_data, comm_action):
+ if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
+ return True
+ if node.op == 'get_attr' and isinstance(
+ node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
+ return True
+ return False
+
+ for operation_data, comm_action in comm_actions.items():
+ comm_spec_to_use = comm_action.comm_spec
+ # register hook to the parameters
+ if _filter_param_to_hook(node, operation_data, comm_action):
+
+ def wrapper(param, comm_spec, stream, overlap):
+
+ def hook_fn(grad):
+ if overlap:
+ with torch.cuda.stream(stream):
+ _all_reduce(grad, comm_spec, async_op=True)
+ else:
+ _all_reduce(grad, comm_spec, async_op=False)
+
+ param.register_hook(hook_fn)
+
+ wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
+
+ def _shard_param(param, target_sharding_spec):
+ # apply the sharding spec of parameters
+ if target_sharding_spec.dim_partition_dict != {}:
+ origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
+ setattr(param, 'sharding_spec', origin_sharding_spec)
+ # TODO: build a ColoParamter class to manager the distributed parameters
+ # we could use .data here, because all the operations just happen before the real training
+ # loop, so we don't need to track these operations in the autograd graph.
+ param = torch.nn.Parameter(
+ shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+ target_sharding_spec).detach().clone())
+ return param
+
for node in nodes:
if node.op == 'call_module':
target_module = node.graph.owning_module.get_submodule(node.target)
@@ -368,31 +439,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
setattr(target_module, 'processed', True)
for name, param in target_module.named_parameters():
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
- # apply the sharding spec of parameters
- if target_sharding_spec.dim_partition_dict != {}:
- origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
- setattr(param, 'sharding_spec', origin_sharding_spec)
- # TODO: build a ColoParamter class to manager the distributed parameters
- # we could use .data here, because all the operations just happen before the real training
- # loop, so we don't need to track these operations in the autograd graph.
- param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
- param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+ param = _shard_param(param, target_sharding_spec)
setattr(target_module, name, param)
- comm_actions = node.best_strategy.communication_actions
- for operation_data, comm_action in comm_actions.items():
- comm_spec_to_use = comm_action.comm_spec
- # register hook to the parameters
- if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
-
- def wrapper(param, comm_spec):
-
- def hook_fn(grad):
- _all_reduce(grad, comm_spec, async_op=False)
-
- param.register_hook(hook_fn)
-
- wrapper(param, comm_spec_to_use)
+ _add_hook_for_grad_communication(node, param)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
@@ -420,32 +470,12 @@ def hook_fn(grad):
target = getattr(target_module, atoms[-1])
target_sharding_spec = node.sharding_spec
- if target_sharding_spec.dim_partition_dict != {}:
- origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
- setattr(target, 'sharding_spec', origin_sharding_spec)
- # TODO: build a ColoParamter class to manager the distributed parameters
- # we could use .data here, because all the operations just happen before the real training
- # loop, so we don't need to track these operations in the autograd graph.
- target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
- target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+ target = _shard_param(target, target_sharding_spec)
assert hasattr(target_module, atoms[-1])
setattr(target_module, atoms[-1], target)
+ _add_hook_for_grad_communication(node, target)
- comm_actions = node.best_strategy.communication_actions
- for operation_data, comm_action in comm_actions.items():
- comm_spec_to_use = comm_action.comm_spec
- # register hook to the parameters
- if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
-
- def wrapper(param, comm_spec):
-
- def hook_fn(grad):
- _all_reduce(grad, comm_spec, async_op=False)
-
- param.register_hook(hook_fn)
-
- wrapper(target, comm_spec_to_use)
return gm
@@ -459,13 +489,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor = None):
- gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
+ strategies_constructor: StrategiesConstructor,
+ overlap=False):
+ gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, solution, strategies_constructor)
- gm = _size_value_converting(gm, device_mesh)
- gm = _node_args_converting(gm, device_mesh)
+ gm = size_value_converting_pass(gm, device_mesh)
+ gm = node_args_converting_pass(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
- gm = _module_params_sharding(gm, device_mesh)
+ gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py
deleted file mode 100644
index bd47f2adf3d6..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from .cost_graph import CostGraph
-from .graph_analysis import GraphAnalyser
-from .options import SolverOptions
-from .sharding_strategy import ShardingStrategy, StrategiesVector
-from .solver import Solver
-from .strategies_constructor import StrategiesConstructor
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py b/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py
deleted file mode 100644
index d6af7ad57154..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/_utils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import functools
-import operator
-import warnings
-from functools import reduce
-from typing import Dict, List, Optional, Union
-
-import torch
-from torch.fx.node import Node
-
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .constants import INFINITY_COST
-
-
-def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
- dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
- """
- Generate the sharding spec of the tensor based on the given dim_partition_dict.
-
-
- Args:
- input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
- device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
- dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
- """
-
- if isinstance(input_, Node):
- assert hasattr(input_, '_meta_data'), f'The given node has no attribte _meta_data'
- meta_tensor = input_._meta_data
- assert meta_tensor is not None, "The given node's _meta_data attribute is None"
- shape = meta_tensor.shape
- elif isinstance(input_, torch.Tensor):
- shape = input_.shape
- else:
- raise TypeError(
- f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
- )
- for dim_index, sharding_index_list in dim_partition_dict.items():
- sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
- sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
-
- sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
- return sharding_spec
-
-
-def generate_resharding_costs(nodes: List[Node],
- sharding_specs: List[ShardingSpec],
- count_backward: Optional[bool] = True,
- dtype: Optional[torch.dtype] = None,
- index=None):
- '''
- Compute the resharding costs with this specific strategy.
-
- Argument:
- nodes (List[Node]): a list of nodes
- sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
- count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
- dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
- '''
- # The resharding_cost of weight is counted due to sharing weight cases.
- resharding_costs = {}
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # shape consistency manager is a singleton class
- shape_consistency_manager = ShapeConsistencyManager()
-
- for input_node, input_spec in zip(nodes, sharding_specs):
- resharding_costs[input_node] = []
- for strategy in input_node.strategies_vector:
- input_sharding_spec = strategy.output_sharding_spec
- if not isinstance(input_sharding_spec, ShardingSpec):
- assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
- input_sharding_spec = input_sharding_spec[index]
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- try:
- # compute the resharding cost
- _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
-
- # we need multiply the size of elem dtype to get correct communication cost
- resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
- except AssertionError as e:
- warnings.warn(f'{e}')
- resharding_cost = INFINITY_COST
- resharding_costs[input_node].append(resharding_cost)
- return resharding_costs
-
-
-def ignore_sharding_exception(func):
- """
- A function wrapper which executes the function with a specified seed.
- """
-
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- try:
- rst = func(*args, **kwargs)
- return rst
- except AssertionError as e:
- warnings.warn(f'{e}')
-
- return wrapper
-
-
-def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
- dim_partition_list = []
- # enumerate all the 2D sharding cases
- for i in range(dim_size):
- for j in range(i + 1, dim_size):
- dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
- dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
- dim_partition_list.append(dim_partition_dict_0)
- dim_partition_list.append(dim_partition_dict_1)
- for i in range(dim_size):
- dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
- dim_partition_list.append(dim_partition_dict_flatten)
-
- return dim_partition_list
-
-
-def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
- dim_partition_list = []
- # enumerate all the 1D sharding cases
- for i in range(dim_size):
- dim_partition_dict_0 = {i: [mesh_dim_0]}
- dim_partition_list.append(dim_partition_dict_0)
-
- return dim_partition_list
-
-
-def generate_sharding_size(dim_partition_dict, device_mesh):
- total_sharding_size = 1
- for mesh_dim_list in dim_partition_dict.values():
- mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
- sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
- total_sharding_size *= sharding_size
-
- return total_sharding_size
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py b/colossalai/auto_parallel/tensor_shard/deprecated/constants.py
deleted file mode 100644
index 91c20d343487..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/constants.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import torch
-import operator
-
-__all__ = [
- 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
- 'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
- 'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
-]
-
-ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
-ELEMENTWISE_FUNC_OP = [
- torch.abs,
- torch.cos,
- torch.exp,
- operator.neg,
- torch.multiply,
- torch.nn.functional.relu,
- torch.nn.functional.dropout,
- # softmax should not be here
- torch.nn.functional.softmax
-]
-ELEMENTWISE_METHOD_OP = [
- torch.Tensor.to,
- torch.Tensor.type,
- # TODO: contiguous maybe need some extra processes.
- torch.Tensor.contiguous
-]
-RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
-RESHAPE_METHOD_OP = [
- torch.Tensor.view,
- torch.Tensor.unsqueeze,
- torch.Tensor.split,
- torch.Tensor.permute,
- torch.Tensor.transpose,
-]
-BCAST_FUNC_OP = [
- torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
- operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
-]
-CONV_MODULE_OP = [
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
- torch.nn.ConvTranspose3d
-]
-CONV_FUNC_OP = [
- torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
-]
-EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
-LINEAR_MODULE_OP = [torch.nn.Linear]
-LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
-BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
-LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
-POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
-NON_PARAM_FUNC_OP = [
- torch.flatten,
- torch.reshape,
- torch.abs,
- torch.cos,
- torch.exp,
- operator.neg,
- torch.multiply,
- torch.nn.functional.relu,
- torch.nn.functional.dropout,
- torch.flatten,
- torch.where,
- operator.pow,
- torch.pow,
- torch.tanh,
- torch.add,
- torch.sub,
- torch.mul,
- torch.div,
- torch.floor_divide,
- torch.true_divide,
- operator.add,
- operator.sub,
- operator.mul,
- operator.floordiv,
- operator.truediv,
- # softmax should not be here
- torch.nn.functional.softmax
-]
-
-INFINITY_COST = 1e13
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py b/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py
deleted file mode 100644
index 239d02115d0e..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from typing import List
-import math
-from torch.fx.node import Node
-from .constants import INFINITY_COST
-
-
-class CostGraph:
- '''
- A graph data structure to simplify the edge cost graph. It has two main functions:
- 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
- CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
- 2. To reduce the searching space, we merge computationally-trivial operators, such as
- element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
- be given by the StrategiesVector depending on the type of target node and following nodes.
-
- Argument:
- leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
- simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
- '''
-
- def __init__(self, leaf_strategies, simplify=True):
- self.leaf_strategies = leaf_strategies
- self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
- # stores number of strategies in each node
- self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
- # extra_node_costs will store the extra costs introduced by merging nodes
- self.extra_node_costs = {}
- self.following_dict = {}
- self.simplify = simplify
- self._build_cost_graph()
-
- def _remove_invalid_node(self, node, attr_name):
- remove_list = []
- target_node_list = getattr(node, attr_name, [])
- for target_node in target_node_list:
- if target_node not in self.nodes:
- remove_list.append(target_node)
- for element in remove_list:
- target_node_list.remove(element)
-
- def _build_cost_graph(self):
- '''
- This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
- set to node.
- '''
- self.edge_costs = {}
- if self.simplify:
- self.merge_pair = []
- for strategies_vector in self.leaf_strategies:
- # build edge_cost
- dst_node = strategies_vector.node
- for src_node in strategies_vector.predecessor_nodes:
- if src_node not in self.nodes:
- continue
- node_pair = (src_node, dst_node)
- # src_index = strategies_vector.predecessor_nodes.index(src_node)
- edge_cost = {}
- for i in range(len(strategies_vector)):
- for j in range(len(src_node.strategies_vector)):
- edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
- self.edge_costs[node_pair] = edge_cost
- # add parents and children attribute to node
- setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
- setattr(dst_node, 'children', strategies_vector.successor_nodes)
- self._remove_invalid_node(dst_node, 'parents')
- self._remove_invalid_node(dst_node, 'children')
-
- if self.simplify and strategies_vector.check_merge():
- for followed_node in strategies_vector.predecessor_nodes:
- self.merge_pair.append((followed_node, dst_node))
-
- def get_edge_cost(self, src_node, dst_node):
- return self.edge_costs[(src_node, dst_node)]
-
- def merge_node(self, src_node, dst_node):
- '''
- To merge dst_node into src_node, we need to do it in following steps:
-
- 1. For each strategy in dst_node, we need to pick an appropriate strategy
- of src_node to merge, it is important because the logical resharding costs
- between the parents node of src_node and merged node depend on the src_node
- strategies dispatching. For example, for the graph 0->1->2, after merging node 1
- into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
- x represents the picking strategy of node 1 merged into node 2 strategy 0.
-
- 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
- contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
- another is the origin extra costs in src_node strategy.
-
- 3. Build connections between new node pairs, and remove the src_node after all consumer nodes
- detached from it.
-
- Argument:
- src_node(Node): The node will be merged into dst_node.
- dst_node(Node): The node to integrate src_node.
- '''
- src_node_index = dst_node.parents.index(src_node)
- # build merge_map
- merge_map = {}
- for src_index, strategy in enumerate(src_node.strategies_vector):
- min_cost = INFINITY_COST
- lowest_cost_index = -1
- for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
- resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
- if resharding_cost <= min_cost:
- min_cost = resharding_cost
- lowest_cost_index = dst_index
- merge_map[src_index] = lowest_cost_index
-
- # extra_node_cost for src node
- self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
- for src_index, strategy in enumerate(src_node.strategies_vector):
- target_strate_index = merge_map[src_index]
- target_strategy = dst_node.strategies_vector[target_strate_index]
- self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
- if dst_node in self.extra_node_costs:
- self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
-
- # add new node pair to cost graph
- for child_node in dst_node.children:
- new_node_pair = (src_node, child_node)
- old_node_pair = (dst_node, child_node)
- if new_node_pair in self.edge_costs:
- continue
- edge_cost = {}
- for i in range(self.node_lens[src_node]):
- for j in range(self.node_lens[child_node]):
- dst_strate_index = merge_map[i]
- # dst_strategy = dst_node.strategies_vector[dst_strate_index]
- edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
- if new_node_pair not in self.edge_costs:
- self.edge_costs[new_node_pair] = edge_cost
- else:
- # we should accumulate the resharding costs if args of child node contain
- # both src node and dst node.
- for index_pair, resharding_cost in self.edge_costs[new_node_pair]:
- self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
-
- # connect src node and children of dst node
- dst_node.parents.remove(src_node)
- src_node.children.remove(dst_node)
- self.edge_costs.pop((src_node, dst_node))
- for child_node in dst_node.children:
- if child_node not in src_node.children:
- src_node.children.append(child_node)
- if src_node not in child_node.parents:
- child_node.parents.append(src_node)
- # remove dst node from cost graph when dst node has no producer.
- if len(dst_node.parents) == 0:
- child_node.parents.remove(dst_node)
- node_pair = (dst_node, child_node)
- self.edge_costs.pop(node_pair)
- if len(dst_node.parents) == 0:
- self.following_dict[dst_node] = src_node
- dst_node.children = []
-
- def _reindexing_src(self, src):
- if src not in self.following_dict:
- return src
- return self._reindexing_src(self.following_dict[src])
-
- def simplify_graph(self):
- if not self.simplify:
- return
- self.merge_pair.reverse()
- for (src_node, dst_node) in self.merge_pair:
- self.merge_node(src_node, dst_node)
- self.merge_pair.reverse()
- reindexing_following_dict = {}
- for dst, src in self.following_dict.items():
- reindexing_following_dict[dst] = self._reindexing_src(src)
- self.following_dict = reindexing_following_dict
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py b/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py
deleted file mode 100644
index 831e7eadd179..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/graph_analysis.py
+++ /dev/null
@@ -1,163 +0,0 @@
-from dataclasses import dataclass
-from torch.fx.node import Node
-from torch.fx.graph import Graph
-from torch.fx.graph_module import GraphModule
-from collections import OrderedDict as ODict
-from typing import List, OrderedDict, Union, Any
-from colossalai.fx.passes.utils import get_node_module
-
-__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
-
-
-@dataclass
-class LiveVariable:
- """
- LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
- """
- name: str
- node: Node
- is_inplace: bool
-
-
-class LiveVariableVector(list):
- """
- LiveVariableVector is a data structure to store the list of LiveVariable objects.
- """
-
- def exists(self, name) -> bool:
- """
- Check if a variable has already existed in the current list by name.
- """
- for var in self:
- if name == var.name:
- return True
- return False
-
- def get(self, name) -> LiveVariable:
- for var in self:
- if name == var.name:
- return var
- raise KeyError(f"Variable {name} is not found")
-
- def copy(self) -> "LiveVariableVector":
- """
- Create a copy of this vector
- """
- vector = LiveVariableVector()
- for var in self:
- vector.append(var)
- return vector
-
-
-@dataclass
-class LiveStage:
- """
- LiveStage is a data structure to record the living variables at this current node.
- """
- name: str
- node: Node
- all_live_vars: LiveVariableVector
- unique_live_vars: LiveVariableVector
-
-
-class GraphAnalyser:
-
- def __init__(self, gm: GraphModule):
- self._gm = gm
- self._graph = gm.graph
-
- @property
- def gm(self) -> GraphModule:
- """
- Return the GraphModule object associated with this analyser.
- """
- return self._gm
-
- @property
- def graph(self) -> Graph:
- """
- Return the Graph object associated with this analyser.
- """
- return self._graph
-
- def liveness_analysis(self) -> List[LiveStage]:
- """
- Analyse the graph to obtain the variable liveness information. This function returns
- an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
- """
- compute_nodes = self.graph.nodes
- liveness_list = []
-
- # checked: record all variables created since the first stage
- # all: record the live variables only exist until the current stage.
- # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
- # unique: record the unique live variables only exist until the current stage.
- # this is different from `all list` as some variables are duplicated.
- checked_variables = LiveVariableVector()
- all_live_variables = LiveVariableVector()
- unique_live_vars = LiveVariableVector()
-
- for idx, node in enumerate(compute_nodes):
- #############################
- # find new living variables #
- #############################
- # detect whether the current op is an in-place op
- # if it is an in-place op, we would deem it as a duplciate var
- is_inplace = False
- if node.op == 'call_function':
- # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
- if node.kwargs.get('inplace', False):
- is_inplace = True
- elif node.op == 'call_module':
- # to check if this is an inplace op such as torch.nn.Relu(inplace=True)
- module = get_node_module(node)
- if getattr(module, 'inplace', False):
- is_inplace = True
-
- # add the output var
- meta = getattr(node, '_meta_data', None)
- live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
- if not is_inplace:
- unique_live_vars.append(live_var)
- checked_variables.append(live_var)
- all_live_variables.append(live_var)
-
- # check if any input is not checked yet
- for arg in node.args:
- if not isinstance(arg, Node):
- continue
- arg_name = arg.name
- if not checked_variables.exists(arg_name):
- live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
- all_live_variables.append(live_var_from_arg)
- checked_variables.append(live_var_from_arg)
- unique_live_vars.append(live_var_from_arg)
-
- # TODO: add the logic to remove live variables
- # this should be completed if we are able to trace the backward compute graph
-
- # add this stage to liveness dict
- stage = LiveStage(name=node.name,
- node=node,
- all_live_vars=all_live_variables.copy(),
- unique_live_vars=unique_live_vars.copy())
- # if a LiveStage is covered by another LiveStage, we just keep the larger one.
- replace = False
- for index, prev_stage in enumerate(liveness_list):
- all_covered = True
- for ele in prev_stage.unique_live_vars:
- if ele not in stage.unique_live_vars:
- all_covered = False
- break
- if all_covered:
- replace = True
- break
- if replace:
- liveness_list[index] = stage
- else:
- liveness_list.append(stage)
-
- return liveness_list
-
- def get_alias_set(self):
- pass
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
deleted file mode 100644
index 723e1bcf95ed..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from .batch_norm_handler import BatchNormHandler
-from .bcast_op_handler import BcastOpHandler
-from .conv_handler import ConvHandler
-from .dot_handler import DotHandler
-from .embedding_handler import EmbeddingHandler
-from .layer_norm_handler import LayerNormHandler
-from .operator_handler import OperatorHandler
-from .reshape_handler import ReshapeHandler
-from .unary_elementwise_handler import UnaryElementwiseHandler
-from .where_handler import WhereHandler
-
-__all__ = [
- 'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
- 'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler', 'LayerNormHandler'
-]
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py
deleted file mode 100644
index 519436270828..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/batch_norm_handler.py
+++ /dev/null
@@ -1,492 +0,0 @@
-import operator
-from functools import reduce
-
-import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
- ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['BatchNormHandler']
-
-
-class BatchNormHandler(OperatorHandler):
- """
- A OperatorHandler which deals with the sharding strategies of normalization.
-
- To keep the math consistency, there are two way to do BatchNorm if the input
- shards on batch dimension:
- 1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
- 2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
- us to keep the computing correctness.
- In this handler, both methods will be considered.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.weight = self.module_named_parameters['weight']
- self.bias = self.module_named_parameters['bias']
- self.output_data = self.node._meta_data
- self._sanity_check()
-
- def _sanity_check(self):
- '''
- In sanity check, we need make sure the input data having correct dimension size.
- For BatchNorm1d, the dim of input data should be 3([N, C, L]).
- For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
- For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
- '''
- assert self.input_data.dim() in (3, 4,
- 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
-
- def _generate_compute_cost(self, bs, channel_in):
- '''
- Compute the computation cost per device with this specific strategy.
-
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
-
- Argument:
- bs(int): Batch size of the input data.
- channel_in(int): The channel dimension of input data.
-
- Return:
- compute_cost(float): Computation cost per device with this specific strategy
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- # TODO: a constant coefficient need to be added.
- # 1D: (L) * N * Cin
- # 2D: (H * W) * N * Cin
- # 3D: (H * W * D) * N * Cin
-
- input_size = self.input_data.shape[2:]
- input_size_product = reduce(operator.mul, input_size, 1)
- forward_compute_cost = input_size_product * bs * channel_in
- backward_activation_compute_cost = input_size_product * bs * channel_in
- backward_weight_compute_cost = input_size_product * bs * channel_in
- backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
- compute_cost = forward_compute_cost + backward_compute_cost
- return compute_cost
-
- def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
- '''
- Compute the memory cost per device with this specific strategy.
-
- Argument:
- sharding_size_forward(int): The forward activation will be divided
- into sharding_size_forward number partions.
- sharding_size_backward_activation(int): The backward activation will
- be divided into sharding_size_backward_activation number partions.
- sharding_size_weight(int): The backward weight will be divided
- into sharding_size_weight number partions.
-
- Return:
- memory_cost(Tuple[float]): Memory cost per device with this
- specific strategy, the first element of this tuple is forward
- memory cost, and the second element of this tuple is backward
- memory cost.
- memory_cost_forward(float): Memory cost of forward activation per
- device with this specific strategy.
- memory_cost_backward_activation(float): Memory cost of backward activation
- per device with this specific strategy.
- '''
- # compute the memory cost of this strategy
- dtype = self.input_data.dtype
- numel_output = self.output_data.numel()
- numel_input = numel_output
- numel_weight = self.weight.numel()
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # forward memory_cost
- memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
- memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
-
- # backward memory_cost
- memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
- memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
-
- # memory_cost pair
- memory_cost = (memory_cost_forward, memory_cost_backward)
-
- return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
-
- @ignore_sharding_exception
- def split_input_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
- memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation
- communication_cost = 0
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- # shard the output batch dimension to get all possible sharding strategy from this basic strategy
- new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
-
- dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
- new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
- # the computation cost is all the same
- new_compute_cost = compute_cost
-
- # the memory cost need to be recomputed
- # compute the memroy cost of new strategy
- new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
- new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
- new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # the communication cost need to count the sharding cost into this strategy
- # compute the communication cost of new strategy
- origin_communication_cost = communication_cost
- tiny_shard_cost = 10
- new_forward_communication_cost = tiny_shard_cost
- # we need to all gather the batch dimension for the basic strategy
- new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
- new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
-
- sharding_strategies = ShardingStrategy(new_name,
- output_sharding_spec=new_sharding_spec_for_output,
- compute_cost=new_compute_cost,
- communication_cost=new_communication_cost,
- memory_cost=new_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
- self.device_mesh.shape[mesh_dim_1])
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation
- communication_cost = 0
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def non_split(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RR x R'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = 1
- sharding_size_backward_activation = 1
- sharding_size_weight = 1
- memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation
- communication_cost = 0
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
- dim_partition_dict_for_output = {0: mesh_dim_list}
- new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # the computation cost is all the same
- new_compute_cost = compute_cost
-
- # the memory cost need to be recomputed
- new_sharding_size_input = 1
- for mesh_dim in mesh_dim_list:
- new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
- new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
- new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
-
- # the communication cost need to count the sharding cost into this strategy
- origin_communication_cost = communication_cost
- tiny_shard_cost = 10
- new_forward_communication_cost = tiny_shard_cost
- if len(mesh_dim_list) == 1:
- new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
- mesh_dim_list[0])
- else:
- new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
- memory_cost_backward_activation, 0)
- new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
-
- new_sharding_strategy = ShardingStrategy(new_name,
- output_sharding_spec=new_sharding_spec_for_output,
- compute_cost=new_compute_cost,
- communication_cost=new_communication_cost,
- memory_cost=new_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input,
- sharding_spec_for_weight))
-
- return new_sharding_strategy
-
- # shard the output batch dimension to get all possible sharding strategy from this basic strategy
- # shard on mesh_dim_0
- new_name = f'S{mesh_dim_0}R = RR x R'
- mesh_dim_list = [mesh_dim_0]
- new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
- self.strategies_vector.append(new_sharding_strategy)
-
- # shard on mesh_dim_1
- new_name = f'S{mesh_dim_1}R = RR x R'
- mesh_dim_list = [mesh_dim_1]
- new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
- self.strategies_vector.append(new_sharding_strategy)
-
- # shard on mesh_dim_0, mesh_dim_1
- new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
- mesh_dim_list = [mesh_dim_0, mesh_dim_1]
- new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
- self.strategies_vector.append(new_sharding_strategy)
-
- @ignore_sharding_exception
- def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
- channel_in = self.input_data.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = 1
- memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- # the all reduce communication will happen during the sync bn computing.
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
- channel_in = self.input_data.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_weight = 1
- memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- # the all reduce communication will happen during the sync bn computing.
- communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
- channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(bs, channel_in)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
- memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- # the all reduce communication will happen during the sync bn computing.
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- '''
- Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
-
- Example:
- norm_handler = BatchNormHandler(node, strategies_vector,
- self.shape_consistency_manager)
- norm_handler.register_strategy()
- for strategy in norm_handler.strategies_vector:
- print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
-
- Output:
- RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
- RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
- RR = RR x R, computation_cost: 262144, memory_cost: 1048576
- RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
- '''
-
- # RS = RS x S and strategies based on it, such as
- # SS = RS x S
- self.split_input_channel(0, 1)
- self.split_input_channel(1, 0)
-
- # RR = RR x R and strategies based on it, such as
- # SR = SR x R
- self.non_split(0, 1)
-
- # RS01 = RS01 x S01
- self.split_input_channel_1d(0, 1)
-
- # SR = SR x R WITH SYNC_BN
- self.split_input_batch(0)
- self.split_input_batch(1)
-
- # SS = SS x S WITH SYNC_BN
- self.split_input_both_dim(0, 1)
- self.split_input_both_dim(1, 0)
-
- # S01R = S01R x R WITH SYNC_BN
- self.split_input_batch_1d(0, 1)
-
- return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py
deleted file mode 100644
index 6ac6dce76675..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/bcast_op_handler.py
+++ /dev/null
@@ -1,552 +0,0 @@
-import operator
-import warnings
-from copy import deepcopy
-from functools import reduce
-from typing import Dict, List
-
-import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
- enumerate_all_possible_2d_sharding,
- ignore_sharding_exception)
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['BcastOpHandler']
-
-
-class BcastOpHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of broadcast operators(such as operator.add).
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- assert len(self.predecessor_node) == 2
- self.lhs_data = self.predecessor_node[0]._meta_data
- self.rhs_data = self.predecessor_node[1]._meta_data
- self.lhs = self.predecessor_node[0]
- self.rhs = self.predecessor_node[1]
- self.output_data = self.node._meta_data
-
- def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
- shape = list(input_.shape)
-
- # padding the shape to the same length as output_data
- while len(shape) < self.output_data.dim():
- shape.insert(0, 1)
- shape = torch.Size(shape)
-
- # if the sharding happens on a size one dimension, we should record it as R.
- processed_dim_partition_dict = deepcopy(dim_partition_dict)
- for dim_index, _ in dim_partition_dict.items():
- if shape[dim_index] == 1:
- processed_dim_partition_dict.pop(dim_index)
- for dim_index, sharding_index_list in processed_dim_partition_dict.items():
- sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
- sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
- sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=shape,
- dim_partition_dict=processed_dim_partition_dict)
-
- return sharding_spec
-
- def _generate_compute_cost(self, total_sharding_size):
- lhs_matrix_shape = self.lhs_data.shape[-2:]
- rhs_matrix_shape = self.rhs_data.shape[-2:]
- batch_dimensions_shape = self.output_data.shape[:-2]
- batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
- compute_cost = reduce(
- operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
- return compute_cost
-
- def _generate_resharding_costs(self, sharding_specs):
- # The resharding_cost of weight is counted due to sharing weight cases.
- dtype = self.node._meta_data.dtype
- nodes = self.predecessor_node
- resharding_costs = {}
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # shape consistency manager is a singleton class
- shape_consistency_manager = ShapeConsistencyManager()
-
- for input_node, input_spec in zip(nodes, sharding_specs):
- resharding_costs[input_node] = []
- for strategy in input_node.strategies_vector:
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- # if the input shape is smaller than the target input, we will fill the input to the same length as target.
- # Then, use the padded input sharding spec to compute the resharding cost.
- if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
- new_entire_shape = list(input_sharding_spec.entire_shape)
- while len(new_entire_shape) < len(input_spec.entire_shape):
- new_entire_shape.insert(0, 1)
- new_entire_shape = torch.Size(new_entire_shape)
- new_device_mesh = input_sharding_spec.device_mesh
- new_dim_partition_dict = input_sharding_spec.dim_partition_dict
- input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
- entire_shape=new_entire_shape,
- dim_partition_dict=new_dim_partition_dict)
-
- # compute the resharding cost
- _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
-
- # we need multiply the size of elem dtype to get correct communication cost
- resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
- resharding_costs[input_node].append(resharding_cost)
-
- return resharding_costs
-
- def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
-
- sharding_spec_list = []
- check_duplicated_list = []
- for output_dim_partition_dict in dim_partition_list:
- try:
- output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
- except AssertionError as e:
- warnings.warn(f'{e}')
- break
- sharding_seq = output_sharding_spec.sharding_sequence
- if sharding_seq not in check_duplicated_list:
- check_duplicated_list.append(sharding_seq)
- sharding_spec_list.append(output_sharding_spec)
-
- return sharding_spec_list
-
- def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
- # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
-
- output_dim_partition_list = []
- dim_size = self.output_data.dim()
- # enumerate all the 2D sharding cases
- sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
- output_dim_partition_list.extend(sharding_list_2d)
-
- # enumerate all the 1D sharding cases
- sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
- output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
- sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
- output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
-
- # add empty dict for fully replicated case
- output_dim_partition_list.append({})
- output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
-
- return output_sharding_spec_list
-
- @ignore_sharding_exception
- def _register_strategy(self, output_sharding_spec):
- dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
-
- name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
- dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the computation cost of this strategy
- sharding_dims = []
- for mesh_dims in dim_partition_dict_for_output.values():
- for mesh_dim in mesh_dims:
- sharding_dims.append(self.device_mesh.shape[mesh_dim])
- sharding_size = reduce(operator.mul, sharding_dims, 1)
- memory_cost = self.output_data.numel() / sharding_size
- compute_cost = memory_cost
- communication_cost = 0
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=output_sharding_spec,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- ##############################################
- #used to generate strategies for torch.matmul#
- ##############################################
- @ignore_sharding_exception
- def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
- # this dim partition dict only describes the batch dimensions, but in this scenario,
- # matrix dimensions are fully replicated, so it do not need extra process.
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- batch_sharding_dims = []
- for mesh_dims in dim_partition_dict_for_batch_dim.values():
- for mesh_dim in mesh_dims:
- batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
- batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
- # in this case, total_sharding_size is equal to the batch sharding size
- memory_cost = self.output_data.numel() / batch_sharding_size
-
- # compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(batch_sharding_size)
-
- # in this case, no communication takes place.
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- communication_cost = 0
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
- # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
- # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
- # In this scenario, matrix dimensions will be sharded on 'i' dimension.
-
- # in this case, the matrix dimensions of lhs is sharded on 'i' dimension.
- dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})
-
- # in this case, the matrix dimensions of rhs is fully replicated.
- dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
-
- # in this case, the matrix dimensions of output is sharded on 'i' dimension.
-
- dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})
-
- # generate sharding specs
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_dims = []
-
- # append batch sharding dims
- for mesh_dims in dim_partition_dict_for_batch_dim.values():
- for mesh_dim in mesh_dims:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
-
- # append the sharding dims on matrix dimension
- for mesh_dim in mesh_dim_on_matrix:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
- total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
-
- # in this case, output_data uses all the sharding dims.
- memory_cost = self.output_data.numel() / total_sharding_size
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- communication_cost = 0
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
- # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
- # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
- # In this scenario, matrix dimensions will be sharded on 'k' dimension.
-
- # in this case, the matrix dimensions of lhs is sharded on 'k' dimension.
- dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})
-
- # in this case, the matrix dimensions of rhs is sharded on 'k' dimension.
- dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})
-
- # in this case, the matrix dimensions of output is fully replicated.
- dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
-
- # generate sharding specs
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_dims = []
- batch_sharding_dims = []
- # append batch sharding dims
- for mesh_dims in dim_partition_dict_for_batch_dim.values():
- for mesh_dim in mesh_dims:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
- batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
-
- # append the sharding dims on matrix dimension
- for mesh_dim in mesh_dim_on_matrix:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
- batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
- total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
-
- # in this case, output_data is fully replicated on matrix dimensions.
- memory_cost = self.output_data.numel() / batch_sharding_size
-
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- # The communication takes place during forward activation computation.
- if len(mesh_dim_on_matrix) == 1:
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
- else:
- communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
- # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
- # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
- # In this scenario, matrix dimensions will be is sharded on 'j' dimension.
-
- # in this case, the matrix dimensions of lhs is fully replicated.
- dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
-
- # in this case, the matrix dimensions of rhs is sharded on 'j' dimension.
- dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})
-
- # in this case, the matrix dimensions of output is sharded on 'j' dimension.
- dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
- dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})
-
- # generate sharding specs
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_dims = []
-
- # append batch sharding dims
- for mesh_dims in dim_partition_dict_for_batch_dim.values():
- for mesh_dim in mesh_dims:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
-
- # append the sharding dims on matrix dimension
- for mesh_dim in mesh_dim_on_matrix:
- total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
- total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
-
- # in this case, output_data uses all the sharding dims.
- memory_cost = self.output_data.numel() / total_sharding_size
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- # The communication takes place during backward activation computation.
- if len(mesh_dim_on_matrix) == 1:
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
- else:
- communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
- self._split_dim_i(dim_partition_dict, mesh_dim_list)
- self._split_dim_k(dim_partition_dict, mesh_dim_list)
- self._split_dim_j(dim_partition_dict, mesh_dim_list)
-
- @ignore_sharding_exception
- def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
-
- dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
-
- dim_partition_dict_for_output = {-2: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
- output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
- # in this case, output_data uses all the sharding dims.
- memory_cost = self.output_data.numel() / output_sharding_size
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- # The communication takes place during forward activation computation.
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
-
- dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
-
- dim_partition_dict_for_output = {-1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
- output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
- # in this case, output_data uses all the sharding dims.
- memory_cost = self.output_data.numel() / output_sharding_size
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- # The communication takes place during forward and backward activation computation.
- communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
- communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
- communication_cost = communication_cost_backward_activation + communication_cost_forward_activation
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
- dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
- sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
-
- dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
- sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
-
- dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
-
- # compute the memory cost of this strategy
- total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
- output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
- # in this case, output_data uses all the sharding dims.
- memory_cost = self.output_data.numel() / output_sharding_size
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # TODO: add all-reduce cost if lhs or rhs is type of Parameters.
- # The communication takes place during backward activation computation.
- communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
-
- self.strategies_vector.append(sharding_strategies)
-
- def _registry_2d_strategies_for_matmul(self):
- self._split_lhs_space_both_contract(0, 1)
- self._split_lhs_space_both_contract(1, 0)
- self._split_rhs_space_both_contract(0, 1)
- self._split_rhs_space_both_contract(1, 0)
- self._split_lhs_space_rhs_space(0, 1)
- self._split_lhs_space_rhs_space(1, 0)
-
- def register_strategy(self) -> StrategiesVector:
- MESH_DIM_LIST = [0, 1]
- if self.node.target != torch.matmul:
- output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
- for output_sharding_spec in output_sharding_specs:
- self._register_strategy(output_sharding_spec)
- else:
- # we only care about the non-computing dimensions,
- # therefore, we omit the last two dimensions.
- dim_size = self.output_data.dim() - 2
-
- # Both device mesh axises are uesd on batch dimensions
- dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
- for dim_partition_dict in dim_partition_dicts_2d:
- self._registry_no_split_strategies_for_matmul(dim_partition_dict)
-
- # Only one device mesh axis is uesd on batch dimensions
- for mesh_dim_index in [0, 1]:
- dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
- for dim_partition_dict in dim_partition_dicts_1d:
- self._registry_no_split_strategies_for_matmul(dim_partition_dict)
- self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
-
- # No device mesh axis is uesd on batch dimensions
- dim_partition_dict_on_batch_dim = {}
- self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
- self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
- self._registry_2d_strategies_for_matmul()
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py
deleted file mode 100644
index d8952040dffe..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/conv_handler.py
+++ /dev/null
@@ -1,609 +0,0 @@
-import operator
-import warnings
-from functools import reduce
-
-import torch
-
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['ConvHandler']
-
-
-class ConvHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of Convolution.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.weight = self.module_named_parameters['weight']
- self.output_data = self.node._meta_data
- self._sanity_check()
-
- def _sanity_check(self):
- '''
- In sanity check, we need make sure the input data having correct dimension size.
- For Conv1d, the dim of input data should be 3([N, C, L]).
- For Conv2d, the dim of input data should be 4([N, C, H, W]).
- For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
- '''
- assert self.input_data.dim() in (3, 4,
- 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
-
- def _generate_compute_cost(self, bs, channel_in, channel_out):
- '''
- Compute the computation cost per device with this specific strategy.
-
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
-
- Argument:
- bs(int): Batch size of the input data.
- channel_in(int): The channel dimension of input data.
- channel_out(int): The out channel of the conv weight.
-
- Return:
- compute_cost(float): Computation cost per device with this specific strategy
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- # 1D: (L) * N * Cout * Cin * kernel
- # 2D: (H * W) * N * Cout * Cin * kernel
- # 3D: (H * W * D) * N * Cout * Cin * kernel
- output_size = self.output_data.shape[2:]
- output_size_product = reduce(operator.mul, output_size, 1)
- input_size = self.input_data.shape[2:]
- input_size_product = reduce(operator.mul, input_size, 1)
- kernel_size = self.weight.shape[2:]
- kernel_size_product = reduce(operator.mul, kernel_size, 1)
- forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
- backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
- backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
- compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
- return compute_cost
-
- def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
- '''
- Compute the memory cost per device with this specific strategy.
-
- Argument:
- sharding_size_forward(int): The forward activation will be divided
- into sharding_size_forward number partions.
- sharding_size_backward_activation(int): The backward activation will
- be divided into sharding_size_backward_activation number partions.
- sharding_size_weight(int): The backward weight will be divided
- into sharding_size_weight number partions.
-
- Return:
- memory_cost(Tuple[float]): Memory cost per device with this
- specific strategy, the first element of this tuple is forward
- memory cost, and the second element of this tuple is backward
- memory cost.
- memory_cost_forward(float): Memory cost of forward activation per
- device with this specific strategy.
- memory_cost_backward_activation(float): Memory cost of backward activation
- per device with this specific strategy.
- '''
- # compute the memory cost of this strategy
- dtype = self.input_data.dtype
- numel_output = self.output_data.numel()
- numel_input = self.input_data.numel()
- numel_weight = self.weight.numel()
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # forward memory_cost
- memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
- memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
-
- # backward memory_cost
- memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
- memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
-
- # memory_cost pair
- memory_cost = (memory_cost_forward, memory_cost_backward)
-
- return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
-
- @ignore_sharding_exception
- def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {1: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
- channel_in = self.input_data.shape[1]
- channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
- memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation during forward
- communication_cost_forward = 0
- # compute the backward communication cost to all reduce the input activation grad
- communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
- mesh_dim_1)
- # compute the backward communication cost to all reduce the weight due to data parallel
- communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
- # total communication cost
- communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_batch(self, mesh_dim_0):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
- channel_in = self.input_data.shape[1]
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = 1
- memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation in forward phase.
- communication_cost_forward = 0
- # compute the backward communication cost to all reduce the weight due to data parallel
- communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
- # compute the total cost
- communication_cost = communication_cost_forward + communication_cost_backward_weight
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
- channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
- memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # compute the communication cost of this strategy during forward phase
- communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
- # This strategy do not need to do all_reduce operation to compute the input activation grad
- communication_cost_backward_activation = 0
- # compute the backward communication cost to all reduce the weight due to data parallel
- communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
- # compute total cost
- communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
- channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # compute the communication cost of this strategy during forward phase
- communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
- # compute the communication cost of this strategy during backward phase
- communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
- communication_cost = communication_cost_forward + communication_cost_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
- name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = 1
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
- memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # compute the communication cost of this strategy during forward phase
- communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
- # This strategy do NOT need all_reduce during forward phase
- communication_cost_backward = 0
- communication_cost = communication_cost_forward + communication_cost_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_weight_out_channel(self, mesh_dim_0):
- name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {1: [mesh_dim_0]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1]
- channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
- sharding_size_backward_activation = 1
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
- memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # This strategy do not need to do all_reduce during forward phase
- communication_cost_forward = 0
- # compute the communication cost of this strategy during backward phase
- communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
- communication_cost = communication_cost_forward + communication_cost_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def non_split(self):
- name = f'RR = RR x RR'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1]
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = 1
- sharding_size_backward_activation = 1
- sharding_size_weight = 1
- memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce in both forward and backward phase
- communication_cost = 0
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
- channel_in = self.input_data.shape[1]
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
- mesh_dim_1]
- sharding_size_weight = 1
- memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce in forward phase
- communication_cost_forward = 0
- # compute the backward communication cost to all reduce the weight due to data parallel
- communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- memory_cost_backward_weight, 0)
- # compute the total communication cost
- communication_cost = communication_cost_backward_weight + communication_cost_forward
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- bs = self.input_data.shape[0]
- channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
- self.device_mesh.shape[mesh_dim_1])
- channel_out = self.weight.shape[1]
- compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
-
- # compute the memory cost of this strategy
- sharding_size_forward = 1
- sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
- mesh_dim_1]
- sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
- memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # compute communication cost during forward phase
- communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- memory_cost_forward_activation, 0)
- # This strategy do NOT need do all_reduce during backward phase
- communication_cost_backward = 0
- communication_cost = communication_cost_forward + communication_cost_backward
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- '''
- Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
-
- Example:
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- shape_consistency_manager = ShapeConsistencyManager()
-
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
- # return conv
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- # [x, mul, conv, output]
- nodes = [node for node in gm.graph.nodes]
-
- # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
- strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
- setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
-
- strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
- conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
- device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
- conv_handler.register_strategy_into_strategies_vector()
- for strategy in conv_handler.strategies_vector:
- print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
-
- Output:
- S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
- S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
- S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
- S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
- S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
- S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
- RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
- RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
- RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
- RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
- RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
- RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
- RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
- S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
- RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
- '''
- # SS = SR x RS
- self.split_input_batch_weight_out_channel(0, 1)
- self.split_input_batch_weight_out_channel(1, 0)
-
- # SR = SR x RR
- self.split_input_batch(0)
- self.split_input_batch(1)
-
- # SR = SS x SR
- self.split_input_both_dim_weight_in_channel(0, 1)
- self.split_input_both_dim_weight_in_channel(1, 0)
-
- # RS = RS x SS
- self.split_input_in_channel_weight_both_channel(0, 1)
- self.split_input_in_channel_weight_both_channel(1, 0)
-
- # RR = RS x SR
- self.split_input_in_channel_weight_in_channel(0)
- self.split_input_in_channel_weight_in_channel(1)
-
- # RS = RR x RS
- self.split_weight_out_channel(0)
- self.split_weight_out_channel(1)
-
- # RR= RR x RR
- self.non_split()
-
- # S01R = S01R x RR
- self.split_1d_parallel_on_input_batch(0, 1)
-
- # RR = RS01 x S01R
- self.split_1d_parallel_on_in_channel(0, 1)
-
- return self.strategies_vector
-
-
-CONV_STRATEGIES_LIST = [
- 'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
- 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
- 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
-]
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py
deleted file mode 100644
index 1f2281cc4172..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/dot_handler.py
+++ /dev/null
@@ -1,756 +0,0 @@
-import operator
-from enum import Enum
-from functools import reduce
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-
-from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
-from .operator_handler import OperatorHandler
-from .strategy_generator import IntermediateStrategy, StrategyGenerator
-
-__all__ = ['DotHandler']
-
-
-class DotProductStrategyGenerator(StrategyGenerator):
- """
- DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
- This is created for torch.matmul where two tensors are 1D tensors. As torch.matmul does not include a bias argument, so we
- do not consider bias here.
- """
-
- def validate(self, input, other):
- assert input.dim() == 1 and other.dim() == 1
-
- def no_split(self):
- name = f'R = R dot R'
- dim_partition_dict = {"input": {}, "other": {}, "output": {}}
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_one_dim(self, mesh_dim):
- name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
- dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
-
- def generate(self) -> List[IntermediateStrategy]:
- strategy_list = []
-
- # do not split dimensions for dot product
- # R = R dot R
- strategy_list.append(self.no_split())
-
- # split two tensors in the same dimensions
- # S = S dot S
- strategy_list.append(self.split_one_dim(0))
- strategy_list.append(self.split_one_dim(1))
-
- return strategy_list
-
-
-class MatVecStrategyGenerator(StrategyGenerator):
-
- def validate(self, input, other) -> bool:
- assert input.dim() > 1 and other.dim() == 1
-
- def no_split(self):
- name = "R = R x R"
- dim_partition_dict = {"input": {}, "other": {}, "output": {}}
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_input_batch(self, mesh_dim):
- name = f'S{mesh_dim}R = S{mesh_dim}R x R'
- dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def generate(self) -> List[IntermediateStrategy]:
- strategy_list = []
-
- # no split
- strategy_list.append(self.no_split())
-
- # split the batch dim for the first tensor only
- strategy_list.append(self.split_input_batch(0))
- strategy_list.append(self.split_input_batch(1))
-
- return strategy_list
-
-
-class MatMulStrategyGenerator(StrategyGenerator):
- """
- MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
- a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
-
- A matmul can be formulated as [n, p] x [p, q] = [n, q]
-
- Args:
- is_linear (bool): whether this generator is used for nn.Linear and F.linear.
- This will incur extra transformation of the dim partitioning as the weight is transposed.
- """
-
- def __init__(self, is_linear: bool, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_linear = is_linear
-
- # as the weight for the linear module is transposed, we can compute
- # the correponding dimension indexfor convenience
- if is_linear:
- self.dim_q = 0
- self.dim_p = 1
- else:
- self.dim_q = 1
- self.dim_p = 0
-
- def validate(self, input, other, bias) -> bool:
- # make sure the second tensor is a 2D tensor
- assert input.dim() > 0 and other.dim() == 2
-
- # make sure bias is of the same dimension
- if self.is_linear:
- assert bias is None or bias.shape[-1] == other.shape[0]
- else:
- assert bias is None or bias.shape[-1] == other.shape[1]
-
- def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
- # handle case SS = SR x RS
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
-
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- self.dim_q: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- # handle the case SR = SS x SR
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "other": {
- self.dim_p: [mesh_dim_1]
- },
- "bias": {},
- "output": {
- 0: [mesh_dim_0]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
-
- def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
- dim_partition_dict = {
- "input": {
- -1: [mesh_dim_0]
- },
- "other": {
- self.dim_p: [mesh_dim_0],
- self.dim_q: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_1]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def recompute_split_both_contract(self, mesh_dim):
- name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
- dim_partition_dict = {
- "input": {
- -1: [mesh_dim]
- },
- "other": {
- self.dim_p: [mesh_dim]
- },
- "bias": {},
- "output": {},
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
-
- def split_rhs_space_only(self, mesh_dim):
- name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
- dim_partition_dict = {
- "input": {},
- "other": {
- self.dim_q: [mesh_dim]
- },
- "bias": {
- -1: [mesh_dim]
- },
- "output": {
- -1: [mesh_dim]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
-
- def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "other": {},
- "bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
- dim_partition_dict = {
- "input": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- self.dim_p: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {},
- "output": {},
- }
- return IntermediateStrategy(name=name,
- dim_partition_dict=dim_partition_dict,
- all_reduce_axis=[mesh_dim_0, mesh_dim_1])
-
- def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
-
- dim_partition_dict = {
- "input": {},
- "other": {
- self.dim_q: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- "output": {
- -1: [mesh_dim_0, mesh_dim_1]
- },
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
-
-class BatchedMatMulStrategyGenerator(StrategyGenerator):
- """
- Generate sharding strategies for the batched matrix multiplication.
-
- A batched matrix multiplication can be viewed as
- [b, i, k] x [b, k, j] -> [b, i, j]
- """
-
- def __init__(self, is_torch_bmm: bool, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.is_torch_bmm = is_torch_bmm
-
- def validate(self, input, other, bias) -> bool:
- if self.is_torch_bmm:
- assert input.shape == other.shape
- assert input.dim() > 2
- assert other.shape[-1] == bias.shape[0]
- else:
- # TODO: validate these inputs are broadcastable
- pass
-
- def split_one_batch_dim(self):
- if 1 in self.device_mesh.mesh_shape:
- mesh_dim = self.device_mesh.mesh_shape.index(1)
- name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim]
- },
- "other": {
- 0: [mesh_dim]
- },
- "bias": {},
- "output": {
- 0: [mesh_dim]
- }
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
- else:
- return None
-
- def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0, mesh_dim_1]
- },
- "bias": {},
- "output": {
- 0: [mesh_dim_0, mesh_dim_1]
- }
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_one_batch_dim(self, mesh_dim):
- name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
- dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- -2: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0]
- },
- "bias": {},
- "output": {
- 0: mesh_dim_0,
- -2: [mesh_dim_1]
- }
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0]
- },
- "other": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "bias": {
- -1: [mesh_dim_1]
- },
- "output": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- }
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
-
- def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
- dim_partition_dict = {
- "input": {
- 0: [mesh_dim_0],
- -1: [mesh_dim_1]
- },
- "other": {
- 0: [mesh_dim_0],
- -2: [mesh_dim_1]
- },
- "bias": {},
- "output": {
- 0: [mesh_dim_0],
- -2: [mesh_dim_1]
- }
- }
- return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
-
- def generate(self) -> List[IntermediateStrategy]:
- strategy_list = []
-
- # split only the batch dimension
- # Sb = Sb x Sb
- # can be None as it is only for 1D device mesh
- strategy = self.split_one_batch_dim()
- if strategy:
- strategy_list.append(strategy)
-
- # split batch dim of two inputs and the i dim of the first tensor
- # SbSi = SbSi x Sb
- strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
- strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
-
- # split batch dim of two inputs and the j of the second tensor
- # SbSj = Sb x SbSj
- strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
- strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
-
- # split batch dim of two inputs and the k dim of two inputs
- # Sb = SbSk x SbSk, need to all-reduce by k dim
- strategy_list.append(self.split_batch_dim_both_contract(0, 1))
- strategy_list.append(self.split_batch_dim_both_contract(1, 0))
-
- # split two batch dim
- strategy_list.append(self.split_two_batch_dim(0, 1))
- strategy_list.append(self.split_two_batch_dim(1, 0))
-
- return strategy_list
-
-
-class DotHandler(OperatorHandler):
- """
- A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.weight = self.module_named_parameters['weight']
- self.output_data = self.node._meta_data
-
- def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
- # TODO: consider bias addition
- compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
- return compute_cost
-
- @ignore_sharding_exception
- def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
- # handle case SS = SR x RS
- name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- # linear layer weight is transposed during init
- dim_partition_dict_for_weight = {0: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute computation cost
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost
- communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
- communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
- communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
-
- # create and register strategy
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- # handle the case SR = SS x SR
- name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- # since weight of the linear layer is transposed
- # the actual dim to be sharded is 1
- dim_partition_dict_for_weight = {1: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
- communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
- communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
- communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
- communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def recompute_split_both_contract(self, mesh_dim):
- name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
-
- dim_partition_dict_for_input = {1: [mesh_dim]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {1: [mesh_dim]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_rhs_space_only(self, mesh_dim):
- name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
- communication_cost = communication_cost_activation_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
- communication_cost = communication_cost_weight_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
-
- dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
-
- # compute the communication cost of this strategy
- communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- activation_memory_cost, 0)
- communication_cost = communication_cost_forward_activation
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
- name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
-
- # compute the memory cost of this strategy
- toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
- dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
- # compute the communication cost of this strategy
- communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- input_grad_memory_cost, 0)
- communication_cost = communication_cost_activation_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=toatl_memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- '''
- Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
-
- Output:
-
- '''
- # SS = SR x RS
- self.split_lhs_space_rhs_space(0, 1)
- self.split_lhs_space_rhs_space(1, 0)
-
- # SR = SS x SR
- self.split_lhs_space_both_contract(0, 1)
- self.split_lhs_space_both_contract(1, 0)
-
- # RS = RS x SS
- self.split_rhs_space_both_contract(0, 1)
- self.split_rhs_space_both_contract(1, 0)
-
- # RR= RS x SR
- self.recompute_split_both_contract(0)
- self.recompute_split_both_contract(1)
-
- # RS = RR x RS
- self.split_rhs_space_only(0)
- self.split_rhs_space_only(1)
-
- # S01R = S01R x RR
- self.split_lhs_1st_dim_1d(0, 1)
-
- # RR = RS01 x S01R
- self.split_lhs_2nd_dim_1d(0, 1)
-
- # RS01 = RR x RS01
- self.split_rhs_2nd_dim_1d(0, 1)
-
- return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py
deleted file mode 100644
index d01a487ad673..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/embedding_handler.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import operator
-import warnings
-from copy import deepcopy
-from functools import reduce
-from typing import Dict, List
-
-import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
- ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['EmbeddingHandler']
-
-
-class EmbeddingHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding).
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.weight = self.module_named_parameters['weight']
- self.output_data = self.node._meta_data
-
- def _generate_compute_cost(self, total_sharding_size):
- input_shape = self.input_data.shape
- weight_shape = self.weight.shape
- input_shape_product = reduce(operator.mul, input_shape, 1)
- weight_shape_product = reduce(operator.mul, weight_shape, 1)
- compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
- return compute_cost
-
- def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
- '''
- Compute the memory cost per device with this specific strategy.
-
- Argument:
- sharding_size_forward(int): The forward activation will be divided
- into sharding_size_forward number partions.
- sharding_size_backward_activation(int): The backward activation will
- be divided into sharding_size_backward_activation number partions.
- sharding_size_weight(int): The backward weight will be divided
- into sharding_size_weight number partions.
-
- Return:
- memory_cost(Tuple[float]): Memory cost per device with this
- specific strategy, the first element of this tuple is forward
- memory cost, and the second element of this tuple is backward
- memory cost.
- memory_cost_forward(float): Memory cost of forward activation per
- device with this specific strategy.
- memory_cost_backward_activation(float): Memory cost of backward activation
- per device with this specific strategy.
- '''
- # compute the memory cost of this strategy
- dtype = self.input_data.dtype
- numel_output = self.output_data.numel()
- numel_input = self.input_data.numel()
- numel_weight = self.weight.numel()
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # forward memory_cost
- memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
- memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
-
- # backward memory_cost
- memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
- memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
-
- # memory_cost pair
- memory_cost = (memory_cost_forward, memory_cost_backward)
-
- return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
-
- @ignore_sharding_exception
- def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {2: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = 1
- sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # compute the communication cost of this strategy during forward phase
- communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
- # compute the communication cost of this strategy during backward phase
- communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
- communication_cost = communication_cost_forward + communication_cost_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
- name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
-
- dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
-
- # compute the computation cost of this strategy
- total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # compute the memory cost of this strategy
- sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
- sharding_size_weight = 1
- memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
- sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
-
- # This strategy do not need to do all_reduce during forward phase
- communication_cost_forward = 0
- # compute the communication cost of this strategy during backward phase
- communication_cost_backward_activation = 0
- communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- memory_cost_backward_weight, 0)
- communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
- communication_cost = communication_cost_forward + communication_cost_backward
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- '''
- Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
- '''
- # RRS = RR x SS
- self.split_weight_both_dim(0, 1)
- self.split_weight_both_dim(1, 0)
-
- # SSR = SS x RR
- self.split_input_both_dim(0, 1)
- self.split_input_both_dim(1, 0)
-
- return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py
deleted file mode 100644
index 8062d0f4babf..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/layer_norm_handler.py
+++ /dev/null
@@ -1,241 +0,0 @@
-import operator
-from functools import reduce
-
-import torch
-
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
- enumerate_all_possible_1d_sharding,
- enumerate_all_possible_2d_sharding,
- generate_sharding_size,
- ignore_sharding_exception,
-)
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['LayerNormHandler']
-
-
-class LayerNormHandler(OperatorHandler):
- """
- A OperatorHandler which deals with the sharding strategies of normalization.
-
- Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.weight = self.module_named_parameters['weight']
- self.bias = self.module_named_parameters['bias']
- self.output_data = self.node._meta_data
-
- def _generate_compute_cost(self, total_sharding_size):
- '''
- Compute the computation cost per device with this specific strategy.
-
- Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
-
- Argument:
- bs(int): Batch size of the input data.
- channel_in(int): The channel dimension of input data.
-
- Return:
- compute_cost(float): Computation cost per device with this specific strategy
- '''
- # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
- # TODO: a constant coefficient need to be added.
-
- norm_kernel_size = self.weight.shape
- # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
- input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
- input_batch_product = reduce(operator.mul, input_batch_shape, 1)
- norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
- forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
- backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
- # To compute gradient of on norm kernel element requires input_batch_product times computation, so
- # the total cost is input_batch_product * norm_kernel_product
- backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
- backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
- compute_cost = forward_compute_cost + backward_compute_cost
- return compute_cost
-
- def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
- '''
- Compute the memory cost per device with this specific strategy.
-
- Argument:
- sharding_size_forward(int): The forward activation will be divided
- into sharding_size_forward number partions.
- sharding_size_backward_activation(int): The backward activation will
- be divided into sharding_size_backward_activation number partions.
- sharding_size_weight(int): The backward weight will be divided
- into sharding_size_weight number partions.
-
- Return:
- memory_cost(Tuple[float]): Memory cost per device with this
- specific strategy, the first element of this tuple is forward
- memory cost, and the second element of this tuple is backward
- memory cost.
- memory_cost_forward(float): Memory cost of forward activation per
- device with this specific strategy.
- memory_cost_backward_activation(float): Memory cost of backward activation
- per device with this specific strategy.
- '''
- # compute the memory cost of this strategy
- dtype = self.input_data.dtype
- numel_output = self.output_data.numel()
- # this operation will not change the shape of input
- numel_input = numel_output
- numel_weight = self.weight.numel()
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # forward memory_cost
- memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
- memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
-
- # backward memory_cost
- memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
- memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
- memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
-
- # memory_cost pair
- memory_cost = (memory_cost_forward, memory_cost_backward)
-
- return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
-
- def _generate_strategy_with_dim_partition(self, dim_partition):
- dim_partition_dict_for_input = dim_partition
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = dim_partition
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
- # compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # compute the memory cost of this strategy
- sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
- sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
- sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
- memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
- sharding_size_backward_activation,
- sharding_size_weight)
-
- total_mesh_dim_list = []
- for mesh_dim_list in dim_partition.values():
- total_mesh_dim_list.extend(mesh_dim_list)
-
- # This strategy do not need to do all_reduce operation for activation
- communication_cost_forward_activation = 0
- communication_cost_backward_activation = 0
- if len(total_mesh_dim_list) == 1:
- communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
- total_mesh_dim_list[0])
- else:
- assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.'
- communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
- memory_cost_backward_weight, 0)
- communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- @ignore_sharding_exception
- def split_input_batch_single_mesh_dim(self, mesh_dim_0):
- batch_dimension_length = self.input_data.dim() - self.weight.dim()
- dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
- for dim_partition in dim_partition_list:
- self._generate_strategy_with_dim_partition(dim_partition)
-
- @ignore_sharding_exception
- def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
- batch_dimension_length = self.input_data.dim() - self.weight.dim()
- dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
- for dim_partition in dim_partition_list:
- self._generate_strategy_with_dim_partition(dim_partition)
-
- @ignore_sharding_exception
- def non_split(self):
- name = f'RR = RR x R'
-
- dim_partition_dict_for_input = {}
- sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
-
- dim_partition_dict_for_weight = {}
- sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
-
- dim_partition_dict_for_output = {}
- sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
-
- total_sharding_size = 1
- # compute the computation cost of this strategy
- compute_cost = self._generate_compute_cost(total_sharding_size)
-
- # compute the memory cost of this strategy
- sharding_size_forward = 1
- sharding_size_backward_activation = 1
- sharding_size_weight = 1
- memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
- sharding_size_weight)
-
- # This strategy do not need to do all_reduce operation
- communication_cost = 0
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=sharding_spec_for_output,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- '''
- Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
-
- Example:
- norm_handler = BatchNormHandler(node, strategies_vector,
- self.shape_consistency_manager)
- norm_handler.register_strategy()
- for strategy in norm_handler.strategies_vector:
- print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
-
- Output:
- RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
- RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
- RR = RR x R, computation_cost: 262144, memory_cost: 1048576
- RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
- '''
-
- # SR = SR x R with single mesh dim on batch dimensions
- self.split_input_batch_single_mesh_dim(0)
- self.split_input_batch_single_mesh_dim(1)
-
- # SR = SR x R with both mesh dims on batch dimensions
- self.split_input_batch_both_mesh_dim(0, 1)
-
- # RR = RR x R
- self.non_split()
-
- return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py
deleted file mode 100644
index b120cc16b04b..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/operator_handler.py
+++ /dev/null
@@ -1,149 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List
-from webbrowser import Opera
-
-import torch
-import torch.nn as nn
-from torch.fx.node import Node
-
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .._utils import generate_resharding_costs, generate_sharding_spec
-from ..sharding_strategy import StrategiesVector
-
-__all__ = ['OperatorHandler']
-
-
-class OperatorHandler(ABC):
- '''
- The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
-
- Args:
- node (Node): the input node in node argument list.
- device_mesh (DeviceMesh): A logical view of a physical mesh.
- strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
- handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
- '''
-
- def __init__(self,
- node: Node,
- device_mesh: DeviceMesh,
- strategies_vector: StrategiesVector,
- handle_backward: bool = True):
- self.node = node
- self.predecessor_node = list(node._input_nodes.keys())
- self.successor_node = list(node.users.keys())
- self.device_mesh = device_mesh
- self.strategies_vector = strategies_vector
- self.handle_backward = handle_backward
-
- # find the module and its parameters associated with this node
- # this can be used to compute the compute/communication/sharding cost
- if self.node.op == 'call_module':
- module = node.graph.owning_module.get_submodule(node.target)
- named_parameters = list(module.named_parameters(recurse=False))
- # convert named parameters from list to dict
- named_parameters = {k: v for k, v in named_parameters}
- elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
- module = None
- parameters = list(self.node.args)[1]
- if isinstance(parameters, Node):
- named_parameters = {'weight': parameters._meta_data}
- else:
- named_parameters = {}
- else:
- module = None
- named_parameters = None
- self.module = module
- self.module_named_parameters = named_parameters
-
- @abstractmethod
- def register_strategy(self) -> StrategiesVector:
- """
- Register
- """
- pass
-
- def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
- sharding_spec_for_input):
- '''
- Compute the memory cost per device with this specific strategy.
-
- Argument:
- dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded,
- and the value of the key decribe which logical axis will be sharded in that dimension.
- dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded,
- and the value of the key decribe which logical axis will be sharded in that dimension.
- Return:
- total_memory_cost(float): total memory cost per device with this specific strategy
- activation_cost(float): the memory cost of activation per device with this specific strategy
- weight_memory_cost(float): the memory cost of weight per device with this specific strategy
- '''
- # compute the size of one element with specific dtype
- dtype = self.input_data.dtype
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # compute the memory cost of activation
- activation_numel = self.output_data.numel()
- output_mesh_dims = []
- for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
- output_mesh_dims.extend(mesh_dims)
- activation_sharding_size = 1
- for mesh_dim in output_mesh_dims:
- activation_sharding_size *= self.device_mesh.shape[mesh_dim]
- activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
-
- # compute the memory cost of weight
- weight_numel = self.weight.numel()
- weight_sharding_size = 1
- weight_mesh_dims = []
- for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
- weight_mesh_dims.extend(mesh_dims)
- for mesh_dim in weight_mesh_dims:
- weight_sharding_size *= self.device_mesh.shape[mesh_dim]
- weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
-
- # compute the memory cost of input grad
- input_grad_numel = self.input_data.numel()
- input_grad_sharding_size = 1
- input_grad_mesh_dims = []
- for sharding_dim, mesh_dims in sharding_spec_for_input.items():
- input_grad_mesh_dims.extend(mesh_dims)
- for mesh_dim in input_grad_mesh_dims:
- input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
- input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
-
- memory_cost_forward = activation_memory_cost + weight_memory_cost
- memory_cost_backward = input_grad_memory_cost + weight_memory_cost
-
- return (memory_cost_forward,
- memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
-
- def _generate_resharding_costs(self, sharding_specs):
- # The resharding_cost of weight is counted due to sharing weight cases.
- if hasattr(self.node._meta_data, 'dtype'):
- dtype = self.node._meta_data.dtype
- else:
- assert isinstance(self.node._meta_data,
- tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected'
- dtype = self.node._meta_data[0].dtype
-
- nodes = self.predecessor_node
- return generate_resharding_costs(nodes=nodes,
- sharding_specs=sharding_specs,
- count_backward=self.handle_backward,
- dtype=dtype)
-
- def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
- return generate_sharding_spec(input_=input_,
- device_mesh=self.device_mesh,
- dim_partition_dict=dim_partition_dict)
-
- @abstractmethod
- def _generate_compute_cost(self, *args, **kwargs):
- """
- Compute the flops involved in the node.
- """
- pass
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py
deleted file mode 100644
index d4ccc8a9c323..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/reshape_handler.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import colorsys
-import math
-import warnings
-from copy import deepcopy
-
-import torch
-
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from ..constants import INFINITY_COST
-from .operator_handler import OperatorHandler
-
-
-class ReshapeHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of Reshape Operator, such as torch.reshape, torch.flatten, etc.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.input_data = self.predecessor_node[0]._meta_data
- self.output_data = self.node._meta_data
-
- def _generate_compute_cost(self, *args, **kwargs):
- return super()._generate_compute_cost(*args, **kwargs)
-
- @ignore_sharding_exception
- def register_strategy(self):
- # TODO: add strategies with more output sharding specs other than only fully replicated.
- input_node = self.strategies_vector.predecessor_nodes[0]
- # For reshape function, to keep the computing correctness we keep the sharding
- # spec of input is fully replicated. In addition, we will keep the output in
- # replica status and let the successor node choose the way to resharding the
- # output node. Therefore, the different strategies of input node with same
- # output sharding spec will generate same strategy for reshape function.
- sharding_spec_checklist = []
- for strategy in input_node.strategies_vector:
- # It looks a little bit confusing, the input of the processing node
- # is the output of the input_node.
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- if input_sharding_spec in sharding_spec_checklist:
- continue
- sharding_spec_checklist.append(input_sharding_spec)
- dim_partition_dict_for_output = {}
- if isinstance(self.output_data, tuple):
- dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
- try:
- if isinstance(self.output_data, tuple):
- output_sharding_spec = []
- for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
- output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
- else:
- output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
- except AssertionError as e:
- warnings.warn(f'{e}')
- continue
- name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
- # TODO: use meta_info_prop to profile memory cost and compute cost
- compute_cost = 0
- # consider node._meta_data is in type of tuple
- memory_cost = 0
-
- # compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
- dim_partition_dict_for_replicate_input = {}
- replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
- dim_partition_dict_for_replicate_input)
- # shape consistency manager is a singleton class
- shape_consistency_manager = ShapeConsistencyManager()
- _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
- replicate_input_sharding_spec)
- communication_cost = communication_cost["total"]
-
- # generate resharding cost
- resharding_costs = self._generate_resharding_costs([input_sharding_spec])
-
- # to prevent the resharding happening, set their resharding cost to inf.
- resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
- sharding_strategy = ShardingStrategy(name,
- output_sharding_spec,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[input_sharding_spec])
- self.strategies_vector.append(sharding_strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py
deleted file mode 100644
index 4e39fcd8e82d..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/strategy_generator.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from dataclasses import dataclass
-from abc import ABC, abstractmethod
-from typing import List, Dict
-from colossalai.device.device_mesh import DeviceMesh
-
-__all__ = ['IntermediateStrategy', 'StrategyGenerator']
-
-
-@dataclass
-class IntermediateStrategy:
- """
- IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
- to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
-
- Args:
- name (str): name of the sharding strategy.
- dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
- all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
- """
- name: str
- dim_partition_dict: Dict[str, Dict[int, List[int]]]
- all_reduce_axis: List[int] = None
-
-
-class StrategyGenerator(ABC):
- """
- StrategyGenerator is used to generate the same group of sharding strategies.
- """
-
- def __init__(self, device_mesh: DeviceMesh):
- self.device_mesh = device_mesh
-
- @abstractmethod
- def generate(self) -> List[IntermediateStrategy]:
- """
- """
- pass
-
- @abstractmethod
- def validate(self, *args, **kwargs) -> bool:
- """
- Validate if the operands are of desired shape.
- If True, means this generator can be used for the current operation.
- """
- pass
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py
deleted file mode 100644
index c929d2fade98..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/unary_elementwise_handler.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import math
-import operator
-import warnings
-from copy import deepcopy
-from functools import reduce
-from typing import Dict, List
-
-import torch
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
- ignore_sharding_exception
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
- INFINITY_COST
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['UnaryElementwiseHandler']
-
-
-class UnaryElementwiseHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- if self.node.op == 'call_module':
- target = self.node.target
- submod = self.node.graph.owning_module.get_submodule(target)
- submod_type = type(submod)
- if submod_type == torch.nn.Dropout:
- print(f'predecessor nodes of dropout node are {self.predecessor_node}')
- input_nodes_len = 0
- for check_node in self.predecessor_node:
- if isinstance(check_node._meta_data, torch.Tensor):
- input_nodes_len += 1
- assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.'
- self.input_data = self.predecessor_node[0]._meta_data
- self.input_node = self.predecessor_node[0]
- self.output_data = self.node._meta_data
-
- def _generate_compute_cost(self, *args, **kwargs):
- return super()._generate_compute_cost(*args, **kwargs)
-
- @ignore_sharding_exception
- def register_strategy(self):
- # TODO: integrate element-wise func and module together
- # create sharding strategy for element-wise function
-
- # For element-wise function, we keep the sharding spec of output node same as
- # the input. Therefore, the different strategies of input node with same
- # output sharding spec will generate same strategy for element-wise function.
-
- for index, strategy in enumerate(self.input_node.strategies_vector):
- # It looks a little bit confusing, the input of the processing node
- # is the output of the input_node.
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-
- dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
- try:
- output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
- except AssertionError as e:
- warnings.warn(f'{e}')
- continue
- # add index into name to pass the duplicated check
- # we keep same strategies with different name for node merging, and it will not increase the searching space,
- # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
- name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
- # TODO: use meta_info_prop to profile memory cost and compute cost
- compute_cost = self.output_data.numel()
- memory_cost = 0
-
- resharding_costs = self._generate_resharding_costs([input_sharding_spec])
-
- # to prevent the resharding happening, set their resharding cost to inf.
- resharding_costs[self.input_node] = [
- 0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
- ]
- sharding_strategy = ShardingStrategy(name,
- output_sharding_spec,
- compute_cost=compute_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[input_sharding_spec])
- self.strategies_vector.append(sharding_strategy)
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py b/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py
deleted file mode 100644
index 6991e913d463..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py
+++ /dev/null
@@ -1,186 +0,0 @@
-import operator
-import warnings
-from copy import deepcopy
-from functools import reduce
-from typing import Dict, List
-
-import torch
-
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
- enumerate_all_possible_2d_sharding,
- ignore_sharding_exception)
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from .operator_handler import OperatorHandler
-
-__all__ = ['WhereHandler']
-
-
-class WhereHandler(OperatorHandler):
- """
- An OperatorHandler which deals with the sharding strategies of torch.where.
- """
-
- def __init__(self, *args, **kwargs):
- # TODO: x or y could be scalar
- super().__init__(*args, **kwargs)
- assert len(self.predecessor_node) == 3
- self.condition_data = self.predecessor_node[0]._meta_data
- self.x_data = self.predecessor_node[1]._meta_data
- self.y_data = self.predecessor_node[2]._meta_data
- self.condition = self.predecessor_node[0]
- self.x = self.predecessor_node[1]
- self.y = self.predecessor_node[2]
- self.output_data = self.node._meta_data
-
- def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
- shape = list(input_.shape)
-
- # padding the shape to the same length as output_data
- while len(shape) < self.output_data.dim():
- shape.insert(0, 1)
- shape = torch.Size(shape)
-
- # if the sharding happens on a size one dimension, we should record it as R.
- processed_dim_partition_dict = deepcopy(dim_partition_dict)
- for dim_index, _ in dim_partition_dict.items():
- if shape[dim_index] == 1:
- processed_dim_partition_dict.pop(dim_index)
- for dim_index, sharding_index_list in processed_dim_partition_dict.items():
- sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
- sharding_size = reduce(operator.mul, sharding_list, 1)
- assert shape[
- dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
- sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=shape,
- dim_partition_dict=processed_dim_partition_dict)
-
- return sharding_spec
-
- def _generate_compute_cost(self, total_sharding_size):
- lhs_matrix_shape = self.lhs_data.shape[-2:]
- rhs_matrix_shape = self.rhs_data.shape[-2:]
- batch_dimensions_shape = self.output_data.shape[:-2]
- batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
- compute_cost = reduce(
- operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
- return compute_cost
-
- def _generate_resharding_costs(self, sharding_specs):
- # The resharding_cost of weight is counted due to sharing weight cases.
- dtype = self.node._meta_data.dtype
- nodes = self.predecessor_node
- resharding_costs = {}
- size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-
- # shape consistency manager is a singleton class
- shape_consistency_manager = ShapeConsistencyManager()
-
- for input_node, input_spec in zip(nodes, sharding_specs):
- resharding_costs[input_node] = []
- for strategy in input_node.strategies_vector:
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- # if the input shape is smaller than the target input, we will fill the input to the same length as target.
- # Then, use the padded input sharding spec to compute the resharding cost.
- if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
- new_entire_shape = list(input_sharding_spec.entire_shape)
- while len(new_entire_shape) < len(input_spec.entire_shape):
- new_entire_shape.insert(0, 1)
- new_entire_shape = torch.Size(new_entire_shape)
- new_device_mesh = input_sharding_spec.device_mesh
- new_dim_partition_dict = input_sharding_spec.dim_partition_dict
- input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
- entire_shape=new_entire_shape,
- dim_partition_dict=new_dim_partition_dict)
-
- # compute the resharding cost
- _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
- input_sharding_spec, input_spec)
- total_resharding_cost = total_resharding_cost['total']
- # we need multiply the size of elem dtype to get correct communication cost
- resharding_cost = total_resharding_cost * size_per_elem_bytes
- resharding_costs[input_node].append(resharding_cost)
-
- return resharding_costs
-
- def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
-
- sharding_spec_list = []
- check_duplicated_list = []
- for output_dim_partition_dict in dim_partition_list:
- try:
- output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
- except AssertionError as e:
- warnings.warn(f'{e}')
- break
- sharding_seq = output_sharding_spec.sharding_sequence
- if sharding_seq not in check_duplicated_list:
- check_duplicated_list.append(sharding_seq)
- sharding_spec_list.append(output_sharding_spec)
-
- return sharding_spec_list
-
- def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
- # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
-
- output_dim_partition_list = []
- dim_size = self.output_data.dim()
- # enumerate all the 2D sharding cases
- sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
- output_dim_partition_list.extend(sharding_list_2d)
-
- # enumerate all the 1D sharding cases
- sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
- output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
- sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
- output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
-
- # add empty dict for fully replicated case
- output_dim_partition_list.append({})
- output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
-
- return output_sharding_spec_list
-
- @ignore_sharding_exception
- def _register_strategy(self, output_sharding_spec):
- dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
- sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
- sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
- sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)
-
- name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
- dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
-
- # generate resharding cost for this strategy
- resharding_costs = self._generate_resharding_costs(
- [sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])
-
- # compute the computation cost of this strategy
- sharding_dims = []
- for mesh_dims in dim_partition_dict_for_output.values():
- for mesh_dim in mesh_dims:
- sharding_dims.append(self.device_mesh.shape[mesh_dim])
- sharding_size = reduce(operator.mul, sharding_dims, 1)
- memory_cost = self.output_data.numel() / sharding_size
- compute_cost = memory_cost
- communication_cost = 0
-
- sharding_strategies = ShardingStrategy(name,
- output_sharding_spec=output_sharding_spec,
- compute_cost=compute_cost,
- communication_cost=communication_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
- sharding_spec_for_y))
-
- self.strategies_vector.append(sharding_strategies)
-
- def register_strategy(self) -> StrategiesVector:
- MESH_DIM_LIST = [0, 1]
- output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
- for output_sharding_spec in output_sharding_specs:
- self._register_strategy(output_sharding_spec)
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/options.py b/colossalai/auto_parallel/tensor_shard/deprecated/options.py
deleted file mode 100644
index 2d34f5c6447e..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/options.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from dataclasses import dataclass
-
-__all__ = ['SolverOptions']
-
-
-@dataclass
-class SolverOptions:
- """
- SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
- """
- fast: bool = False
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py
deleted file mode 100644
index d468c858e9a9..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/sharding_strategy.py
+++ /dev/null
@@ -1,91 +0,0 @@
-from copy import deepcopy
-from dataclasses import dataclass
-from abc import ABC, abstractmethod
-from enum import Enum
-import operator
-import torch
-from functools import reduce
-
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
-from typing import Dict, List, Union, Tuple, Any
-from torch.fx.node import Node
-from .constants import *
-
-__all__ = ['ShardingStrategy', 'StrategiesVector']
-
-
-@dataclass
-class ShardingStrategy:
- '''
- ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
- and costs information using in solver.
-
- Argument:
- name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
- output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
- compute_cost(float): Computation cost to complete this strategy.(default to 0)
- communication_cost(float): Communication cost to complete this strategy.(default to 0)
- memory_cost(float): Memory cost of the output node using this strategy.(default to 0)
- resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
- with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
- strategy.(default to None)
- input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
- '''
-
- name: str
- # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
- output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
- compute_cost: float = 0.
- communication_cost: float = 0.
- memory_cost: float = 0.
- resharding_costs: Dict[Node, List[float]] = None
- # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
- # Therefore, we could process them at the specific op(operator.getitem)
- input_shardings: List[ShardingSpec] = None
-
-
-class StrategiesVector(list):
- '''
- Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
- strategies of the node.
-
- Argument:
- node (Node): node for which the list of sharding strategies are generated.
- '''
-
- def __init__(self, node: Node):
- super().__init__()
- self.node = node
- # fetch its input and output nodes
- # TODO: placeholder input nodes
- self.predecessor_nodes = list(node._input_nodes.keys())
- if self.node.op == 'output':
- self.predecessor_nodes = list(node._input_nodes.keys())[:1]
- self.successor_nodes = list(node.users.keys())
-
- def check_merge(self):
- merge_label = False
- if self.node.op == 'call_module':
- target = self.node.target
- root_module = self.node.graph.owning_module
- submod = root_module.get_submodule(target)
- submod_type = type(submod)
- # merge elementwise module node into source nodes
- # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
- if submod_type in ELEMENTWISE_MODULE_OP:
- merge_label = True
-
- if self.node.op == 'call_function':
- # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
- if self.node.target in ELEMENTWISE_FUNC_OP:
- merge_label = True
- # we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
- if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
- merge_label = True
- # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
- if self.node.target in RESHAPE_FUNC_OP:
- merge_label = True
-
- return merge_label
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py b/colossalai/auto_parallel/tensor_shard/deprecated/solver.py
deleted file mode 100644
index 4c1d2f3bed5a..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/solver.py
+++ /dev/null
@@ -1,469 +0,0 @@
-import multiprocessing
-import time
-import warnings
-from typing import Dict
-
-import numpy as np
-from torch.fx.graph import Graph
-from torch.fx.node import Node
-
-from .constants import INFINITY_COST
-from .cost_graph import CostGraph
-from .graph_analysis import GraphAnalyser
-from .strategies_constructor import StrategiesConstructor
-
-try:
- import pulp
- from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
-except:
- warnings.warn(f'please install the pulp')
-
-__all___ = ['Solver']
-
-
-class Solver:
-
- def __init__(self,
- graph: Graph,
- strategies_constructor: StrategiesConstructor,
- cost_graph: CostGraph,
- graph_analyser: GraphAnalyser,
- memory_budget: float = -1.0,
- solution_numbers: int = 1,
- memory_increasing_coefficient: float = 1.3):
- '''
- Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
-
- Argument:
- graph: The computing graph to be optimized.
- strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
- cost_graph: A graph data structure to simplify the edge cost graph.
- graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
- memory_budget: Memory constraint for the solution.
- solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
- memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
- '''
- self.graph = graph
- self.strategies_constructor = strategies_constructor
- self.cost_graph = cost_graph
- self.graph_analyser = graph_analyser
- self.leaf_strategies = self.strategies_constructor.leaf_strategies
- self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
- self.strategy_map = self.strategies_constructor.strategy_map
- self.memory_budget = memory_budget
- self.solution_numbers = solution_numbers
- if self.solution_numbers > 1:
- self.memory_increasing_coefficient = memory_increasing_coefficient
- else:
- self.memory_increasing_coefficient = 1
- self.liveness_list = self.graph_analyser.liveness_analysis()
- self.node_index_dict = self._generate_node_index_dict()
- # The last solution vector of auto sharding.
- self.last_s_val = None
- # The last objective value of the best ILP solution.
- self.last_objective = None
-
- def _recover_merged_node_strategy(self):
- '''
- During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
- Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
- node.
- '''
- for node_index, node in enumerate(self.nodes):
- if node.strategies_vector.check_merge():
- # the merged node has only one input, and its strategies follow the input sharding strategy
- input_strategies_vector = node.args[0].strategies_vector
- input_best_strategy_index = self.last_s_val[node_index - 1]
- input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
- for strategy_index, strategy in enumerate(node.strategies_vector):
- if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
- self.last_s_val[node_index] = strategy_index
- break
-
- def _generate_node_index_dict(self) -> Dict[Node, int]:
- node_index_dict = {}
- for index, strategies_vector in enumerate(self.leaf_strategies):
- node_index_dict[strategies_vector.node] = index
- return node_index_dict
-
- def _prepare_data_for_solver(self):
- '''
- Extract information from components for solver.
- '''
- node_nums = len(self.leaf_strategies)
- memory_budget = self.memory_budget
-
- # prepare strategies_len
- strategies_len = []
- for node in self.nodes:
- strategies_len.append(self.cost_graph.node_lens[node])
- strategies_len = np.array(strategies_len)
-
- # prepare following_nodes
- following_nodes = self.cost_graph.following_dict
- index_following_nodes = {}
- for src, target in following_nodes.items():
- src_index = self.node_index_dict[src]
- target_index = self.node_index_dict[target]
- index_following_nodes[src_index] = target_index
- following_nodes = index_following_nodes
- for index in range(node_nums):
- if index not in following_nodes:
- following_nodes[index] = -1
-
- # prepare edge_pairs and resharding costs
- edge_pairs = []
- resharding_costs = []
- for pairs, edge_cost in self.cost_graph.edge_costs.items():
- src_node = pairs[0]
- dst_node = pairs[1]
- src_node_index = self.node_index_dict[src_node]
- dst_node_index = self.node_index_dict[dst_node]
- edge_pairs.append(src_node_index)
- edge_pairs.append(dst_node_index)
-
- for i in range(strategies_len[src_node_index]):
- for j in range(strategies_len[dst_node_index]):
- resharding_costs.append(edge_cost[(i, j)])
- edge_pairs = np.array(edge_pairs)
- resharding_costs = np.array(resharding_costs)
-
- # prepare liveness_set
- liveness_set = self.liveness_list
-
- # omit alias_set now
- alias_set = None
- alias_convert_costs = None
-
- # prepare compute_costs, communication_costs and memory_costs
- compute_costs = []
- communication_costs = []
- memory_costs = []
- extra_node_costs = self.cost_graph.extra_node_costs
- for strategies_vector in self.leaf_strategies:
- node = strategies_vector.node
- for index, strategy in enumerate(strategies_vector):
- compute_costs.append(strategy.compute_cost)
- # node in extra_node_costs means it has some extra communication
- # cost from node merging, so we need to add those extra communication
- # cost into
- if node in extra_node_costs:
- origin_communication_cost = strategy.communication_cost
- extra_node_cost = extra_node_costs[node][index]
- communication_cost = origin_communication_cost + extra_node_cost
- communication_costs.append(communication_cost)
- else:
- communication_costs.append(strategy.communication_cost)
- # temporarily we just consider the forward memory cost
- memory_cost = strategy.memory_cost
- if isinstance(memory_cost, tuple):
- memory_costs.append(memory_cost[0])
- else:
- memory_costs.append(memory_cost)
- compute_costs = np.array(compute_costs)
- communication_costs = np.array(communication_costs)
- memory_costs = np.array(memory_costs)
-
- # omit initial value for nodes
- s_init_np = None
-
- return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
-
- def _call_solver_serialized_args(self,
- node_nums,
- memory_budget,
- strategies_len,
- following_nodes,
- edge_pairs,
- alias_set,
- liveness_set,
- compute_costs,
- communication_costs,
- memory_costs,
- resharding_costs,
- alias_convert_costs,
- s_init_np=None):
- """
- Call the solver with serialized arguments.
- """
-
- tic = time.time()
-
- for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
- assert isinstance(x, np.ndarray)
- assert len(strategies_len) == node_nums, "strategies_len"
-
- def get_non_zero_index(binary_vector):
- """
- Get the index of non-zero item in a vector.
- """
- ct = 0
- ret = None
- for i, elem in enumerate(binary_vector):
- if pulp.value(elem):
- ret = i
- ct += 1
-
- assert ct == 1
- return ret
-
- # 0. Unpack flatten numpy arrays
- s_follow = following_nodes
-
- E = edge_pairs.reshape((-1, 2)) # noqa
- r = []
- pt = 0
- edge_set = set()
- for (i, j) in E:
- prod_length = strategies_len[i] * strategies_len[j]
-
- if (i, j) in edge_set:
- raise ValueError(f"Duplicated edges: {(i, j)}")
-
- edge_set.add((i, j))
- r.append(resharding_costs[pt:pt + prod_length])
- pt += prod_length
- assert pt == len(resharding_costs)
-
- ######################
- # omit alias set now #
- ######################
-
- # A = alias_set.reshape((-1, 2)) # noqa
- # for (i, j) in A:
- # prod_length = strategies_len[i] * strategies_len[j]
- # v.append(alias_convert_costs[pt:pt + prod_length])
- # pt += prod_length
- # assert pt == len(alias_convert_costs)
-
- # L = [] # noqa
- # pt = node_nums
- # for i in range(node_nums):
- # length = liveness_set[i]
- # L.append(liveness_set[pt:pt + length])
- # pt += length
- # assert pt == len(liveness_set)
- v = []
- pt = 0
-
- c = []
- d = []
- m = []
- pt = 0
- for i in range(node_nums):
- length = strategies_len[i]
- c.append(compute_costs[pt:pt + length])
- d.append(communication_costs[pt:pt + length])
- m.append(memory_costs[pt:pt + length])
- pt += length
- assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
- assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
- assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
-
- # 1. Create variables
-
- #############################
- # create variables for node #
- #############################
- s = []
- num_nodes = 0
- reverse_follow_backpatch = []
- for i in range(node_nums):
- if s_follow[i] < 0:
- if strategies_len[i] == 1:
- s.append([1])
- else:
- num_nodes += 1
- s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
- else:
- if s_follow[i] < len(s):
- s.append(s[s_follow[i]])
- else:
- s.append(None)
- reverse_follow_backpatch.append(i)
-
- for i in reverse_follow_backpatch:
- s[i] = s[s_follow[i]]
-
- #############################
- # create variables for edge #
- #############################
- e = []
- num_edges = 0
- for (idx, (i, j)) in enumerate(E):
- if len(s[i]) == 1:
- e.append(s[j])
- elif len(s[j]) == 1:
- e.append(s[i])
- else:
- num_edges += 1
- e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
- assert len(e[idx]) == len(r[idx])
- for element in s:
- assert len(element) > 0
- # 2. Set initial value
- ######################################
- # set a initial value for warm start #
- ######################################
- if s_init_np is not None:
- s_init = s_init_np.reshape((-1, 3))
- for (idx, value, fix) in s_init:
- for i in range(len(s[idx])):
- s[idx][i].setInitialValue(i == value)
- if fix:
- s[idx][i].fixValue()
-
- # 3. Objective
- prob = LpProblem("myProblem", LpMinimize)
- ###################################################################
- # computing the node cost(computing cost and communication cost) #
- ###################################################################
- obj = 0
- for i in range(node_nums):
- assert len(s[i]) == len(c[i])
- assert len(s[i]) == len(d[i])
-
- obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
-
- #############################################
- # computing the edge cost(resharding cost) #
- #############################################
- for i in range(len(E)):
- assert len(e[i]) == len(r[i])
- obj += lpDot(e[i], r[i])
-
- prob += obj
-
- # 4. Constraints
- # (a). specified by `cat="Binary"`
-
- # (b)
- #################################################
- # make sure each node only choose one strategy #
- #################################################
- for i in range(node_nums):
- if s_follow[i] < 0:
- prob += lpSum(s[i]) == 1
-
- # (c)
- #################################################
- # compute memory consumption with liveness set #
- #################################################
- if memory_budget > 0:
- for liveness_stage in liveness_set:
- mem = 0
- for live_variable in liveness_stage.unique_live_vars:
- node_index = self.node_index_dict[live_variable.node]
- mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
- prob += mem <= memory_budget
-
- # (d). specified by `cat="Binary"`
-
- for (idx, (i, j)) in enumerate(E):
- if strategies_len[i] == 1 or strategies_len[j] == 1:
- continue
-
- # (e)
- prob += lpSum(e[idx]) == 1
-
- # (f)
- for row in range(len(s[i])):
- C = len(s[j]) # noqa
- prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
-
- # (g)
- for col in range(len(s[j])):
- R = len(s[i]) # noqa
- C = len(s[j]) # noqa
- prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
-
- # (h)
- ######################
- # omit alias set now #
- ######################
-
- # alias_set = set()
- # for (idx, (i, j)) in enumerate(A):
- # R = len(s[i]) # noqa
- # C = len(s[j]) # noqa
- # if (i, j) in alias_set:
- # raise ValueError(f"Duplicated edges: {(i, j)}")
-
- # alias_set.add((i, j))
- # alias_set.add((j, i))
-
- # for row in range(len(s[i])):
- # for col in range(len(s[j])):
- # if v[idx][row * C + col] > 0.5:
- # prob += s[i][row] + s[j][col] <= 1
-
- verbose = True
-
- msg = verbose
- time_limit = 600
- assert "COIN_CMD" in pulp.listSolvers(
- onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
-
- solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
- # solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
- prob.solve(solver)
-
- status = prob.status
- objective = pulp.value(prob.objective)
- objective = float(objective) if objective is not None else -1.0
- if verbose:
- print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
- f"Time: {time.time() - tic}")
- print(f"#nodes: {num_nodes}, #edges: {num_edges}")
-
- if prob.status in [pulp.LpStatusInfeasible]:
- raise RuntimeError("Cannot run the function under the given memory budget. "
- "Please increase the memory budget.")
-
- # Get and check results
- s_val = np.full((node_nums,), -1, dtype=np.int32)
- for i in range(node_nums):
- s_val[i] = get_non_zero_index(s[i])
-
- e_val = np.full((len(E),), -1, dtype=np.int32)
- for (idx, (i, j)) in enumerate(E):
- e_val[idx] = get_non_zero_index(e[idx])
- i_spec_index = e_val[idx] // len(s[j])
- j_spec_index = e_val[idx] % len(s[j])
- assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
- assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
- if verbose and r[idx][e_val[idx]] > 0:
- print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
-
- self.last_s_val = list(s_val)
- self._recover_merged_node_strategy()
- self.last_objective = objective
-
- if objective > INFINITY_COST:
- warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
-
- return self.last_s_val, e_val, self.last_objective, status
-
- def call_solver_serialized_args(self):
- """
- Call the solver with serialized arguments and handle python errors. Additionally,
- we could give a serious of solutions with different memory budget.
- """
- if self.solution_numbers == 1:
- args = self._prepare_data_for_solver()
- ret = self._call_solver_serialized_args(*args)
-
- return ret
-
- origin_memory_budget = self.memory_budget
- memory_budget_list = [
- origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
- ]
- ret_list = []
- for memory_budget in memory_budget_list:
- self.memory_budget = memory_budget
- args = self._prepare_data_for_solver()
- ret = self._call_solver_serialized_args(*args)
- ret_list.append(ret)
-
- return ret_list
diff --git a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py
deleted file mode 100644
index 7bebde9d65a0..000000000000
--- a/colossalai/auto_parallel/tensor_shard/deprecated/strategies_constructor.py
+++ /dev/null
@@ -1,426 +0,0 @@
-import builtins
-import math
-import operator
-from copy import deepcopy
-from typing import Dict, List
-
-import torch
-from torch.fx import Graph, Node
-
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-from ._utils import generate_resharding_costs, generate_sharding_spec
-from .constants import *
-from .op_handler import *
-from .options import SolverOptions
-from .sharding_strategy import ShardingStrategy, StrategiesVector
-
-__all__ = ['StrategiesConstructor']
-
-
-class StrategiesConstructor:
- """
- StrategiesConstructor is used to construct the parallelization plan for the model execution.
-
- Args:
- graph (Graph): a Graph object used for analysis and strategy generation.
- device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
- solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
- """
-
- def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
- self.graph = graph
- assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
- self.root_module = self.graph.owning_module
- self.nodes = list(graph.nodes)
- self.device_mesh = device_mesh
- self.leaf_strategies = []
- self.strategy_map = {}
- self.solver_options = solver_options
-
- def remove_duplicated_strategy(self, strategies_vector):
- '''
- In build_strategies_and_cost method, we may produce some duplicated strategies.
- In this method, we will remove the duplicated strategies depending on the strategies name.
- '''
- name_checklist = []
- remove_list = []
- for strategy in strategies_vector:
- if strategy.name not in name_checklist:
- name_checklist.append(strategy.name)
- else:
- remove_list.append(strategy)
-
- for strategy in remove_list:
- strategies_vector.remove(strategy)
-
- def _is_bcast_matmul(self, node):
- is_bcast_matmul = False
- if node.target is torch.matmul and len(node.args) == 2:
- lhs_data = node.args[0]._meta_data
- rhs_data = node.args[1]._meta_data
- if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
- is_bcast_matmul = True
- return is_bcast_matmul
-
- def build_strategies_and_cost(self):
- for node in self.nodes:
- strategies_vector = StrategiesVector(node)
- input_nodes_len = 0
- for check_node in strategies_vector.predecessor_nodes:
- if isinstance(check_node._meta_data, torch.Tensor):
- input_nodes_len += 1
- # input_nodes_len = len(strategies_vector.predecessor_nodes)
- # placeholder node
- if node.op == 'placeholder':
- # For placeholder nodes, if solver_options.fast is True, we just let them in
- # fully replicate status, then strategies of following node will be treated equally due
- # to replicate status has no resharding cost to other status. At the same time, the searching
- # space is smaller than enumerating all the possible sharding spec for the placeholder node.
- # Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
-
- if self.solver_options.fast:
- # create sharding strategy for placeholder
- name = 'Replica Placeholder'
- dim_partition_dict = {}
- output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
- # TODO: use meta_info_prop to profile memory cost
- memory_cost = 0
- sharding_strategy_placeholder = ShardingStrategy(name,
- output_sharding_spec,
- memory_cost=memory_cost)
- strategies_vector.append(sharding_strategy_placeholder)
-
- # get_attr node
- if node.op == 'get_attr':
- # Same as placeholder nodes, if solver_options.fast is True, we just let them in
- # fully replicate status, then strategies of following node will be treated equally due
- # to replicate status has no resharding cost to other status. At the same time, the searching
- # space is smaller than enumerating all the possible sharding spec for the get_attr node.
- # Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
- if self.solver_options.fast:
- # create sharding strategy for get_attr
- name = 'Replica Attribute'
- dim_partition_dict = {}
- output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
- # TODO: use meta_info_prop to profile memory cost
- memory_cost = 0
- sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
- strategies_vector.append(sharding_strategy_attribute)
-
- # call_module node
- if node.op == 'call_module':
-
- target = node.target
- submod = self.root_module.get_submodule(target)
- submod_type = type(submod)
-
- # conv module
- if submod_type in CONV_MODULE_OP:
- # use ConvHandler to create sharding strategies for conv module node
- conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
- conv_handler.register_strategy()
-
- # linear module
- elif submod_type in LINEAR_MODULE_OP:
- # use DotHandler to create sharding strategies for linear module node
- dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
- dot_handler.register_strategy()
-
- # element-wise module
- elif submod_type in ELEMENTWISE_MODULE_OP:
- unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
- unary_elementwise_handler.register_strategy()
-
- # BatchNormNd module
- elif submod_type in BATCHNORM_MODULE_OP:
- # create sharding strategy for element-wise module
- norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
- norm_handler.register_strategy()
- # for strategy in norm_handler.strategies_vector:
- # print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
- # assert False
-
- # MaxPool module
- elif submod_type in POOL_MODULE_OP:
- # TODO: add sharding constraints on image dimension
- # e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
-
- # create sharding strategy for element-wise module
- assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
- input_node = strategies_vector.predecessor_nodes[0]
- # For element-wise module, we keep the sharding spec of output node same as
- # the input. Therefore, the different strategies of input node with same
- # output sharding spec will generate same strategy for element-wise module.
- sharding_spec_checklist = []
- for strategy in input_node.strategies_vector:
- # It looks a little bit confusing, the input of the processing node
- # is the output of the input_node.
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec,
- ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- if input_sharding_spec in sharding_spec_checklist:
- continue
-
- sharding_spec_checklist.append(input_sharding_spec)
- dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
- output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
-
- name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
-
- # TODO: use meta_info_prop to profile memory cost and compute cost
- compute_cost = node._meta_data.numel()
- memory_cost = 0
- resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
- [input_sharding_spec])
-
- sharding_strategy = ShardingStrategy(name,
- output_sharding_spec,
- compute_cost=compute_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[input_sharding_spec])
- strategies_vector.append(sharding_strategy)
-
- # embedding module
- elif submod_type in EMBEDDING_MODULE_OP:
- embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
- embedding_handler.register_strategy()
-
- # layernorm module
- elif submod_type in LAYERNORM_MODULE_OP:
- layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
- layernorm_handler.register_strategy()
- # other module
- else:
- raise RuntimeError(f'{submod_type} module is NOT supported now.')
-
- # call_function node
- if node.op == 'call_function':
- target = node.target
- # conv function
- if target in CONV_FUNC_OP:
- # use ConvHandler to create sharding strategies for conv node
- # TODO: the operator_handler does NOT support function node processing now.
- conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
- conv_handler.register_strategy()
-
- # linear function
- elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
- # use DotHandler to create sharding strategies for linear node
- # TODO: the operator_handler does NOT support function node processing now.
- linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
- linear_handler.register_strategy()
-
- # where function
- elif target == torch.where:
- if input_nodes_len == 1:
- # both of x and y are scalar
- pass
-
- elif input_nodes_len == 2:
- # one of x or y is type of scalar
- pass
-
- else:
- # general case
- where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
- where_handler.register_strategy()
-
- # reshape function
- elif target in RESHAPE_FUNC_OP:
- # use ReshapeHandler to create sharding strategies for rehsape node
- reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
- reshape_handler.register_strategy()
-
- # element-wise function
- elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
- unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
- unary_elementwise_handler.register_strategy()
-
- # bcast op
- elif target in BCAST_FUNC_OP:
- if isinstance(node._meta_data, torch.Tensor):
- bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
- bcast_op_handler.register_strategy()
-
- # torch.var_mean
- elif target == torch.var_mean:
- dim = node.kwargs['dim']
- input_tensor_node = strategies_vector.predecessor_nodes[0]
- for strategy in input_tensor_node.strategies_vector:
- input_sharding_spec = strategy.output_sharding_spec
- assert isinstance(input_sharding_spec,
- ShardingSpec), f'The input node should NOT be a tuple of tensor.'
- entire_shape_input = input_sharding_spec.entire_shape
- dim_partition_dict_input = input_sharding_spec.dim_partition_dict
- name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
- if dim in dim_partition_dict_input:
- # We need to make the action dimension in replicate status
- dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
- dim_partition_dict_for_input.pop(dim)
- new_input_sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape_input,
- dim_partition_dict=dim_partition_dict_for_input)
- entire_shape_output = deepcopy(entire_shape_input)
- entire_shape_output.pop(dim)
- dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
- output_sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape_output,
- dim_partition_dict=dim_partition_dict_for_input)
- # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
- compute_cost = 0
- memory_cost = 0
- resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
- [new_input_sharding_spec])
- sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
- compute_cost=compute_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[new_input_sharding_spec])
-
- else:
- entire_shape_output = deepcopy(entire_shape_input)
- entire_shape_output.pop(dim)
- dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
- output_sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape_output,
- dim_partion_dict=dim_partition_dict_input)
- # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
- compute_cost = 0
- memory_cost = 0
- resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
- [input_sharding_spec])
- sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
- compute_cost=compute_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[input_sharding_spec])
-
- strategies_vector.append(sharding_strategy)
-
- # operator.getitem
- elif target == operator.getitem:
- index = node.args[1]
- input_tensor_node = strategies_vector.predecessor_nodes[0]
- for strategy in input_tensor_node.strategies_vector:
- if isinstance(strategy.output_sharding_spec, ShardingSpec):
- input_sharding_spec = strategy.output_sharding_spec
- else:
- input_sharding_spec = strategy.output_sharding_spec[index]
- assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
- dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
- entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
- output_sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape_output,
- dim_partition_dict=dim_partition_dict_for_output)
- # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
- compute_cost = 0
- memory_cost = 0
- resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
- [input_sharding_spec],
- index=index)
- # to prevent the resharding happening, set their resharding cost to inf.
- resharding_costs[input_tensor_node] = [
- cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
- ]
- sharding_strategy = ShardingStrategy(name,
- output_sharding_spec,
- compute_cost=compute_cost,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=[strategy.output_sharding_spec])
- strategies_vector.append(sharding_strategy)
-
- # torch.arange function
- elif target == torch.arange:
- name = f'FULLY REPLICATED ARANGE'
- entire_shape_output = node._meta_data.shape
- dim_partition_dict_for_output = {}
- output_sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape_output,
- dim_partition_dict=dim_partition_dict_for_output)
- memory_cost = node._meta_data.numel()
- sharding_strategy = ShardingStrategy(name,
- output_sharding_spec,
- compute_cost=0,
- memory_cost=memory_cost)
- strategies_vector.append(sharding_strategy)
-
- # op list to be processed to support gpt2
- elif target in (builtins.getattr, operator.le, torch.addmm):
- pass
- # other function
- else:
- raise RuntimeError(f'{target} function is NOT supported now.')
-
- # call_method node
- if node.op == 'call_method':
- method = getattr(node.args[0]._meta_data.__class__, node.target)
- if method in (torch.Tensor.size,):
- pass
- elif method in ELEMENTWISE_METHOD_OP:
- unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
- unary_elementwise_handler.register_strategy()
-
- elif method in RESHAPE_METHOD_OP:
- reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
- reshape_handler.register_strategy()
- # print(strategies_vector)
- # if len(strategies_vector) == 0:
- # print(node)
- # assert False
- else:
- raise RuntimeError(f'{method} function is NOT supported now.')
-
- # output node
- if node.op == 'output':
- if self.solver_options.fast:
- # create sharding strategy for output
- name = 'Replica Output'
- input_nodes = strategies_vector.predecessor_nodes
- input_sharding_specs = []
- for input_node in input_nodes:
- dim_partition_dict_for_input = {}
- entire_shape = input_node._meta_data.shape
- sharding_spec = ShardingSpec(self.device_mesh,
- entire_shape,
- dim_partition_dict=dim_partition_dict_for_input)
- input_sharding_specs.append(sharding_spec)
-
- dim_partition_dict = {}
- output_sharding_spec = input_sharding_specs
- # TODO: use meta_info_prop to profile memory cost
- memory_cost = 0
- resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
- input_sharding_specs)
-
- # clear the resharding cost for the output node
- # TODO: we may remove this in final version
- for prev_node, resharding_cost_list in resharding_costs.items():
- resharding_costs[prev_node] = [0] * len(resharding_cost_list)
-
- sharding_strategy_attribute = ShardingStrategy(name,
- output_sharding_spec,
- memory_cost=memory_cost,
- resharding_costs=resharding_costs,
- input_shardings=tuple(input_sharding_specs))
- strategies_vector.append(sharding_strategy_attribute)
-
- self.remove_duplicated_strategy(strategies_vector)
- setattr(node, 'strategies_vector', strategies_vector)
- self.leaf_strategies.append(strategies_vector)
- self.strategy_map[node] = strategies_vector
-
- # remove no strategy nodes
- remove_list = []
- for strategies_vector in self.leaf_strategies:
- if len(strategies_vector) == 0:
- remove_list.append(strategies_vector.node)
- for node in remove_list:
- if node.strategies_vector in self.leaf_strategies:
- self.leaf_strategies.remove(node.strategies_vector)
- if node in self.strategy_map:
- self.strategy_map.pop(node)
diff --git a/colossalai/auto_parallel/tensor_shard/initialize.py b/colossalai/auto_parallel/tensor_shard/initialize.py
index 0dce2564c519..60472eee52ca 100644
--- a/colossalai/auto_parallel/tensor_shard/initialize.py
+++ b/colossalai/auto_parallel/tensor_shard/initialize.py
@@ -8,16 +8,12 @@
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec
@@ -28,7 +24,7 @@ class ModuleWrapper(nn.Module):
into the forward function.
'''
- def __init__(self, module: GraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
+ def __init__(self, module: ColoGraphModule, sharding_spec_dict: Dict[int, List[ShardingSpec]],
origin_spec_dict: Dict[int, ShardingSpec], comm_actions_dict: Dict[int, Dict[str, CommAction]]):
'''
Args:
@@ -59,18 +55,6 @@ def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader,
pass
-def search_best_logical_mesh_shape(world_size: int, alpha_beta_dict: Dict[Tuple[int], Tuple[float]]):
- '''
- This method is used to search the best logical mesh shape for the given world size
- based on the alpha_beta_dict.
-
- For example:
- if the world_size is 8, and the possible logical shape will be (1, 8), (2, 4), (4, 2), (8, 1).
- '''
- # TODO: implement this function
- return (world_size, 1)
-
-
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
'''
This method is used to extract the mesh_alpha and mesh_beta for the given logical_mesh_shape
@@ -80,45 +64,83 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
pass
-def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh):
+def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
+ shard_option: str):
'''
This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler.
'''
- solver_options = SolverOptions()
+ if solver_preference == 'standard':
+ solver_preference = SolverPerference.STANDARD
+ elif solver_preference == 'tp':
+ solver_preference = SolverPerference.TP
+ elif solver_preference == 'dp':
+ solver_preference = SolverPerference.DP
+ else:
+ raise ValueError(f'Invalid solver_preference: {solver_preference}')
+
+ if dataloader_option == 'replicated':
+ dataloader_option = DataloaderOption.REPLICATED
+ elif dataloader_option == 'distributed':
+ dataloader_option = DataloaderOption.DISTRIBUTED
+ else:
+ raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
+
+ if shard_option == 'standard':
+ shard_option = ShardOption.STANDARD
+ elif shard_option == 'shard':
+ shard_option = ShardOption.SHARD
+ elif shard_option == 'shard_last_axis':
+ shard_option = ShardOption.SHARD_LAST_AXIS
+ elif shard_option == 'full_shard':
+ shard_option = ShardOption.FULL_SHARD
+ else:
+ raise ValueError(f'Invalid shard_option: {shard_option}')
+
+ solver_options = SolverOptions(solver_perference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
return strategies_constructor
-def solve_solution(gm: GraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
+def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
'''
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
+ # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+ # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+ # graph_analyser = GraphAnalyser(gm)
+ # liveness_list = graph_analyser.liveness_analysis()
cost_graph = CostGraph(strategy_constructor.leaf_strategies)
cost_graph.simplify_graph()
- solver = Solver(gm.graph, strategy_constructor, cost_graph, graph_analyser, memory_budget=memory_budget)
+ solver = Solver(gm.graph, strategy_constructor, cost_graph, memory_budget=memory_budget)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
return solution
-def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh: DeviceMesh,
- strategies_constructor: StrategiesConstructor):
+def transform_to_sharded_model(gm: ColoGraphModule,
+ solution: List[int],
+ device_mesh: DeviceMesh,
+ strategies_constructor: StrategiesConstructor,
+ overlap: bool = False):
'''
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
- gm, solution, device_mesh, strategies_constructor)
+ gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
+ solution,
+ device_mesh,
+ strategies_constructor,
+ overlap=overlap)
gm = runtime_apply_pass(gm)
gm.recompile()
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
@@ -127,39 +149,56 @@ def transform_to_sharded_model(gm: GraphModule, solution: List[int], device_mesh
def initialize_device_mesh(world_size: int = -1,
+ physical_devices: List[int] = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
- logical_mesh_shape: Tuple[int] = None):
+ logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None):
'''
This method is used to initialize the device mesh.
Args:
- world_size(optional): the size of device mesh. If the world_size is -1,
+ world_size: the size of device mesh. If the world_size is -1,
the world size will be set to the number of GPUs in the current machine.
+ physical_devices: the physical devices used to initialize the device mesh.
alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
for each devices. if the alpha_beta_dict is None, the alpha_beta_dict will be
generated by profile_alpha_beta function.
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
- mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
- generated by search_best_logical_mesh_shape function.
+ mesh shape.
+ logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
'''
# if world_size is not set, use the world size from torch.distributed
if world_size == -1:
world_size = dist.get_world_size()
- device1d = [i for i in range(world_size)]
+
+ if physical_devices is None:
+ physical_devices = [i for i in range(world_size)]
+ physical_mesh = torch.tensor(physical_devices)
if alpha_beta_dict is None:
# if alpha_beta_dict is not given, use a series of executions to profile alpha and beta values for each device
- alpha_beta_dict = profile_alpha_beta(device1d)
+ ab_profiler = AlphaBetaProfiler(physical_devices)
+ alpha_beta_dict = ab_profiler.alpha_beta_dict
+ else:
+ ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)
- if logical_mesh_shape is None:
+ if logical_mesh_shape is None and logical_mesh_id is None:
# search for the best logical mesh shape
- logical_mesh_shape = search_best_logical_mesh_shape(world_size, alpha_beta_dict)
+ logical_mesh_id = ab_profiler.search_best_logical_mesh()
+ logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
+ logical_mesh_shape = logical_mesh_id.shape
+
+ # extract alpha and beta values for the chosen logical mesh shape
+ mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
+
+ elif logical_mesh_shape is not None and logical_mesh_id is None:
+ logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
+
+ # extract alpha and beta values for the chosen logical mesh shape
+ mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)
- # extract alpha and beta values for the chosen logical mesh shape
- mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_shape)
- physical_mesh = torch.tensor(device1d)
device_mesh = DeviceMesh(physical_mesh_id=physical_mesh,
- mesh_shape=logical_mesh_shape,
+ logical_mesh_id=logical_mesh_id,
mesh_alpha=mesh_alpha,
mesh_beta=mesh_beta,
init_process_group=True)
@@ -170,6 +209,10 @@ def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
+ overlap: bool = False,
+ solver_preference: str = 'standard',
+ dataloader_option: str = 'replicated',
+ shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
@@ -183,6 +226,14 @@ def initialize_model(model: nn.Module,
device_mesh: the device mesh to execute the model.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
+ overlap(optional): the overlap is used to specify whether to overlap gradient communication and
+ backward computing.
+ solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+ has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+ dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+ be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+ shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+ model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -192,12 +243,17 @@ def initialize_model(model: nn.Module,
solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
return a series of integers, but return the best strategies.
'''
- tracer = ColoTracer()
+ tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(root=model, meta_args=meta_args)
- gm = GraphModule(model, graph, model.__class__.__name__)
+ gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile()
- strategies_constructor = build_strategy_constructor(graph, device_mesh)
+
+ strategies_constructor = build_strategy_constructor(graph,
+ device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option)
if load_solver_solution:
solution = torch.load(solution_path)
else:
@@ -205,7 +261,7 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
- gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+ gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
if return_solution:
@@ -224,6 +280,10 @@ def autoparallelize(model: nn.Module,
data_process_func: callable = None,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None,
+ logical_mesh_id: torch.Tensor = None,
+ solver_preference: str = 'standard',
+ dataloader_option: str = 'replicated',
+ shard_option: str = 'standard',
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solver_solution_path: str = None,
@@ -245,6 +305,13 @@ def autoparallelize(model: nn.Module,
logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function.
+ logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
+ solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
+ has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
+ dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
+ be used. The valid dataloader_option could be 'replicated' or 'distributed'.
+ shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
+ model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
@@ -254,16 +321,21 @@ def autoparallelize(model: nn.Module,
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
'''
- device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict, logical_mesh_shape=logical_mesh_shape)
+ device_mesh = initialize_device_mesh(alpha_beta_dict=alpha_beta_dict,
+ logical_mesh_shape=logical_mesh_shape,
+ logical_mesh_id=logical_mesh_id)
if meta_args is None:
meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)
rst_to_unpack = initialize_model(model,
meta_args,
device_mesh,
+ solver_preference=solver_preference,
+ dataloader_option=dataloader_option,
+ shard_option=shard_option,
save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution,
- solver_solution_path=solver_solution_path,
+ solution_path=solver_solution_path,
return_solution=return_solution,
memory_budget=memory_budget)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
index a5e3f649a345..9903ca54e52c 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/__init__.py
@@ -3,8 +3,8 @@
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
from .conv_handler import ConvFunctionHandler, ConvModuleHandler
+from .default_reshape_handler import DefaultReshapeHandler
from .embedding_handler import EmbeddingFunctionHandler, EmbeddingModuleHandler
-from .experimental import PermuteHandler, ViewHandler
from .getattr_handler import GetattrHandler
from .getitem_handler import GetItemHandler
from .layer_norm_handler import LayerNormModuleHandler
@@ -12,20 +12,24 @@
from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler
from .output_handler import OutputHandler
+from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler
from .registry import operator_registry
-from .reshape_handler import ReshapeHandler
from .softmax_handler import SoftmaxHandler
+from .split_handler import SplitHandler
from .sum_handler import SumHandler
from .tensor_constructor_handler import TensorConstructorHandler
+from .transpose_handler import TransposeHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
+from .view_handler import ViewHandler
from .where_handler import WhereHandler
__all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
- 'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
+ 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
- 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
+ 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
+ 'SplitHandler'
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
index f510f74776b6..db8f0b54ddee 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
@@ -32,20 +32,32 @@ def _get_op_data_type(tensor):
return OperationDataType.ARG
def _get_arg_value(idx):
+ non_tensor = False
if isinstance(self.node.args[idx], Node):
meta_data = self.node.args[idx]._meta_data
+ # The meta_data of node type argument could also possibly be a non-tensor object.
+ if not isinstance(meta_data, torch.Tensor):
+ assert isinstance(meta_data, (int, float))
+ meta_data = torch.Tensor([meta_data]).to('meta')
+ non_tensor = True
+
else:
# this is in fact a real data like int 1
# but we can deem it as meta data
# as it won't affect the strategy generation
assert isinstance(self.node.args[idx], (int, float))
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
- return meta_data
+ non_tensor = True
- input_meta_data = _get_arg_value(0)
- other_meta_data = _get_arg_value(1)
- output_meta_data = self.node._meta_data
+ return meta_data, non_tensor
+ input_meta_data, non_tensor_input = _get_arg_value(0)
+ other_meta_data, non_tensor_other = _get_arg_value(1)
+ output_meta_data = self.node._meta_data
+ # we need record op_data with non-tensor data in this list,
+ # and filter the non-tensor op_data in post_process.
+ self.non_tensor_list = []
+ # assert False
input_op_data = OperationData(name=str(self.node.args[0]),
type=_get_op_data_type(input_meta_data),
data=input_meta_data,
@@ -58,6 +70,10 @@ def _get_arg_value(idx):
type=OperationDataType.OUTPUT,
data=output_meta_data,
logical_shape=bcast_shape)
+ if non_tensor_input:
+ self.non_tensor_list.append(input_op_data)
+ if non_tensor_other:
+ self.non_tensor_list.append(other_op_data)
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
return mapping
@@ -73,9 +89,10 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
op_data_mapping = self.get_operation_data_mapping()
for op_name, op_data in op_data_mapping.items():
- if not isinstance(op_data.data, torch.Tensor):
+ if op_data in self.non_tensor_list:
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
strategy.sharding_specs.pop(op_data)
+
else:
# convert the logical sharding spec to physical sharding spec if broadcast
# e.g. torch.rand(4, 4) + torch.rand(4)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
similarity index 87%
rename from colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
index 7763b1884025..0c5b9f39e1fb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/default_reshape_handler.py
@@ -5,23 +5,23 @@
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
-from .strategy import ReshapeGenerator, StrategyGenerator
+from .strategy import DefaultReshapeGenerator, StrategyGenerator
-__all__ = ['ReshapeHandler']
+__all__ = ['DefaultReshapeHandler']
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.unsqueeze)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
-class ReshapeHandler(MetaInfoNodeHandler):
+class DefaultReshapeHandler(MetaInfoNodeHandler):
"""
- A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
+ A DefaultReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
- generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
+ generators.append(DefaultReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def infer_logical_shape(self, data):
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
deleted file mode 100644
index 15f66104b156..000000000000
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
+++ /dev/null
@@ -1,10 +0,0 @@
-from .permute_handler import PermuteHandler
-from .reshape_generator import PermuteGenerator, SplitGenerator, TransposeGenerator, ViewGenerator
-from .split_handler import SplitHandler
-from .transpose_handler import TransposeHandler
-from .view_handler import ViewHandler
-
-__all__ = [
- 'ViewGenerator', 'ViewHandler', 'PermuteGenerator', 'PermuteHandler', 'TransposeGenerator', 'TransposeGenerator',
- 'SplitHandler', 'SplitGenerator'
-]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py
deleted file mode 100644
index b7248d011950..000000000000
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/reshape_generator.py
+++ /dev/null
@@ -1,299 +0,0 @@
-import copy
-from typing import List
-
-from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
- CommAction,
- CommType,
- MemoryCost,
- ShardingStrategy,
- TrainCycleItem,
-)
-from colossalai.auto_parallel.tensor_shard.utils import (
- check_keep_sharding_status,
- detect_reshape_mapping,
- infer_output_dim_partition_dict,
-)
-from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from colossalai.tensor.sharding_spec import ShardingSpec
-
-__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
-
-
-class ReshapeGenerator(FollowingStrategyGenerator):
- """
- ReshapeGenerator is the base class for all the reshape operation.
- """
-
- def validate(self) -> bool:
- return super().validate()
-
- def update_compute_cost(self, strategy: ShardingStrategy):
- compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
- strategy.compute_cost = compute_cost
-
- def update_memory_cost(self, strategy: ShardingStrategy):
- '''
- Compute the memory cost per device with this specific strategy.
- '''
- forward_size_mapping = {
- 'input': self._compute_size_in_bytes(strategy, "input"),
- 'output': self._compute_size_in_bytes(strategy, "output")
- }
-
- backward_size_mapping = copy.deepcopy(forward_size_mapping)
- backward_size_mapping.pop("output")
- # compute fwd cost incurred
- # fwd_cost = input + output
- fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
- fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
- fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
-
- # compute bwd cost incurred
- # bwd_cost = input_grad
- bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
- bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
- bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
-
- # compute total cost
- total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
- parameter=fwd_parameter_cost + bwd_parameter_cost)
- memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
- strategy.memory_cost = memory_cost
-
- def collate_strategies(self) -> List[ShardingStrategy]:
- return super().collate_strategies()
-
-
-class ViewGenerator(ReshapeGenerator):
- """
- ViewGenerator deals with the sharding strategies of view op.
- """
-
- def collate_strategies(self) -> List[ShardingStrategy]:
- strategy_list = []
- for index, strategy in enumerate(self.predecessor_node.strategies_vector):
- dim_partition_dict_mapping = {}
- communication_action_mapping = {}
- input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
-
- origin_shape = self.op_data['input'].data.shape
- tgt_shape = self.op_data['tgt_shape'].data
-
- reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
-
- dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
- keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
-
- if keep_sharding_status:
- dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
- reshape_mapping_dict)
- else:
- dim_partition_dict_for_output = {}
-
- dim_partition_dict_mapping = {
- "input": dim_partition_dict_for_input,
- "output": dim_partition_dict_for_output,
- }
- sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-
- # add index into name to pass the duplicated check
- # we keep same strategies with different name for node merging, and it will not increase the searching space,
- # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
- if keep_sharding_status:
- name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
- else:
- name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
-
- # add comm action for converting input to fully replicated
- total_mesh_dim_list = []
- for mesh_dim_list in dim_partition_dict_for_input.values():
- total_mesh_dim_list.extend(mesh_dim_list)
- # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
- if len(total_mesh_dim_list) == 1:
- total_mesh_dim_list = total_mesh_dim_list[0]
- # the total mesh dim list only has one element, so the shard dim has only one element as well.
- shard_dim = list(dim_partition_dict_for_input.keys())[0]
- input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping["input"],
- communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
- logical_process_axis=total_mesh_dim_list,
- comm_type=CommType.BEFORE,
- arg_index=0)
- # it will gather the input through gather_dim during forward phase.
- input_comm_action.comm_spec.gather_dim = shard_dim
- # it will split the input activation grad through shard_dim during backward phase.
- input_comm_action.comm_spec.shard_dim = shard_dim
-
- elif len(total_mesh_dim_list) >= 2:
- source_spec = sharding_spec_mapping["input"]
- target_spec = ShardingSpec(device_mesh=self.device_mesh,
- entire_shape=source_spec.entire_shape,
- dim_partition_dict={})
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
- input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
-
- else:
- input_comm_action = None
-
- if input_comm_action is not None:
- communication_action_mapping["input"] = input_comm_action
-
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
- strategy_list.append(strategy)
-
- return strategy_list
-
-
-class PermuteGenerator(ReshapeGenerator):
- """
- PermuteGenerator deals with the sharding strategies of permute op.
- """
-
- def collate_strategies(self) -> List[ShardingStrategy]:
- strategy_list = []
- for index, strategy in enumerate(self.predecessor_node.strategies_vector):
- dim_partition_dict_mapping = {}
- communication_action_mapping = {}
- input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
-
- permute_dims = self.op_data['permute_dims'].data
- dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
- dim_partition_dict_for_output = {}
- for dim_index, permute_dim in enumerate(permute_dims):
- if permute_dim in dim_partition_dict_for_input:
- dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
-
- dim_partition_dict_mapping = {
- "input": dim_partition_dict_for_input,
- "output": dim_partition_dict_for_output,
- }
- sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-
- # add index into name to pass the duplicated check
- # we keep same strategies with different name for node merging, and it will not increase the searching space,
- # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
- name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
-
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
- strategy_list.append(strategy)
-
- return strategy_list
-
-
-class TransposeGenerator(ReshapeGenerator):
- """
- TransposeGenerator deals with the sharding strategies of permute op.
- """
-
- def collate_strategies(self) -> List[ShardingStrategy]:
- strategy_list = []
- for index, strategy in enumerate(self.predecessor_node.strategies_vector):
- dim_partition_dict_mapping = {}
- communication_action_mapping = {}
- input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
- dim_partition_dict_for_output = {}
-
- transpose_dims = self.op_data['transpose_dims'].data
- dim_0 = transpose_dims[0]
- dim_1 = transpose_dims[1]
- for dim, sharded_dims in dim_partition_dict_for_input.items():
- if dim == dim_0:
- dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
- elif dim == dim_1:
- dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
- else:
- dim_partition_dict_for_output[dim] = sharded_dims
-
- dim_partition_dict_mapping = {
- "input": dim_partition_dict_for_input,
- "output": dim_partition_dict_for_output,
- }
- sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-
- # add index into name to pass the duplicated check
- # we keep same strategies with different name for node merging, and it will not increase the searching space,
- # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
- name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
-
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
- strategy_list.append(strategy)
-
- return strategy_list
-
-
-class SplitGenerator(ReshapeGenerator):
- """
- SplitGenerator deals with the sharding strategies of split op.
- """
-
- def collate_strategies(self) -> List[ShardingStrategy]:
- strategy_list = []
- for index, strategy in enumerate(self.predecessor_node.strategies_vector):
- recover_dims = None
- dim_partition_dict_mapping = {}
- communication_action_mapping = {}
- input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
- dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
- split_size, split_dim = self.op_data['split_info'].data
-
- if split_dim in dim_partition_dict_for_input:
- recover_dims = dim_partition_dict_for_input.pop(split_dim)
-
- dim_partition_dict_for_output = [
- copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
- ]
- assert len(dim_partition_dict_for_output) >= 2
- dim_partition_dict_mapping = {
- "input": dim_partition_dict_for_input,
- "output": dim_partition_dict_for_output,
- }
- sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
- # add index into name to pass the duplicated check
- # we keep same strategies with different name for node merging, and it will not increase the searching space,
- # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
- name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
-
- # add comm action if the input need to be recovered to replica in the split dimension.
- if recover_dims:
- # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
- if len(recover_dims) == 1:
- recover_dims = recover_dims[0]
- input_comm_action = self.get_communication_action(
- sharding_spec=sharding_spec_mapping["input"],
- communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
- logical_process_axis=recover_dims,
- comm_type=CommType.BEFORE,
- arg_index=0)
- # it will gather the input through gather_dim during forward phase.
- input_comm_action.comm_spec.gather_dim = split_dim
- # it will split the input activation grad through split_dim during backward phase.
- input_comm_action.comm_spec.shard_dim = split_dim
-
- elif len(recover_dims) >= 2:
- # original sharding spec
- source_spec = input_sharding_spec
- # target sharding spec
- target_spec = sharding_spec_mapping["input"]
- comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
- input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
-
- else:
- input_comm_action = None
-
- if input_comm_action is not None:
- communication_action_mapping["input"] = input_comm_action
-
- strategy = self.get_sharding_strategy(name=name,
- sharding_spec_mapping=sharding_spec_mapping,
- communication_action_mapping=communication_action_mapping)
- strategy_list.append(strategy)
-
- return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
index 132ac30daed8..452381169b74 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/layer_norm_handler.py
@@ -3,7 +3,7 @@
import torch
from ..sharding_strategy import OperationData, OperationDataType
-from .node_handler import ModuleHandler
+from .node_handler import MetaInfoModuleHandler, ModuleHandler
from .registry import operator_registry
from .strategy import LayerNormGenerator, StrategyGenerator
@@ -11,7 +11,7 @@
@operator_registry.register(torch.nn.LayerNorm)
-class LayerNormModuleHandler(ModuleHandler):
+class LayerNormModuleHandler(MetaInfoModuleHandler):
"""
A LayerNormModuleHandler which deals with the sharding strategies for nn.LayerNorm module.
"""
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
index 37ff3c3ab572..59091dab519f 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py
@@ -152,7 +152,10 @@ def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
- LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+ LinearProjectionStrategyGenerator(op_data_mapping,
+ self.device_mesh,
+ linear_projection_type='linear',
+ solver_perference=self.solver_perference))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
index d3f9fd01d891..f3c9d0cbf826 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/matmul_handler.py
@@ -16,7 +16,7 @@
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
from ..utils import recover_sharding_spec_for_broadcast_shape
-from .node_handler import NodeHandler
+from .node_handler import MetaInfoNodeHandler, NodeHandler
from .registry import operator_registry
from .strategy import (
BatchedMatMulStrategyGenerator,
@@ -326,7 +326,7 @@ def _get_bmm_logical_shape(input_shape, other_shape, transforms):
@operator_registry.register(torch.matmul)
@operator_registry.register(torch.Tensor.matmul)
-class MatMulHandler(NodeHandler):
+class MatMulHandler(MetaInfoNodeHandler):
"""
The MatMulHandler is a node handler which handles the sharding strategy generation for the matmul operation.
According to https://pytorch.org/docs/stable/generated/torch.matmul.html, the operations will vary depending on
@@ -483,4 +483,6 @@ def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, Li
raise TypeError(
f"Found unexpected output type {type(output)} from the recover method of BmmTransform")
strategies = recovered_stragies
+ for index, strategies in enumerate(strategies):
+ strategies.name = f"{strategies.name}_{index}"
return strategies
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
index 78dc58c905ec..136e57c5e0f5 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py
@@ -5,6 +5,7 @@
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
@@ -15,6 +16,7 @@
)
from colossalai.auto_parallel.tensor_shard.utils import check_sharding_spec_validity
from colossalai.device.device_mesh import DeviceMesh
+from colossalai.logging import get_dist_logger
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from .strategy import StrategyGenerator
@@ -30,17 +32,19 @@ class NodeHandler(ABC):
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
'''
- def __init__(
- self,
- node: Node,
- device_mesh: DeviceMesh,
- strategies_vector: StrategiesVector,
- ) -> None:
+ def __init__(self,
+ node: Node,
+ device_mesh: DeviceMesh,
+ strategies_vector: StrategiesVector,
+ shard_option: ShardOption = ShardOption.STANDARD,
+ solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
+ self.shard_option = shard_option
+ self.solver_perference = solver_perference
def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
"""
@@ -181,6 +185,30 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
check_sharding_spec_validity(sharding_spec, op_data.data)
+ remove_strategy_list = []
+ for strategy in self.strategies_vector:
+ shard_axis_list = []
+ last_axis = len(self.device_mesh.mesh_shape) - 1
+ for op_data, sharding_spec in strategy.sharding_specs.items():
+ if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+ for dim, shard_axes in sharding_spec.dim_partition_dict.items():
+ for shard_axis in shard_axes:
+ if shard_axis not in shard_axis_list:
+ shard_axis_list.append(shard_axis)
+
+ shard_level = len(shard_axis_list)
+ using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
+ if self.shard_option == ShardOption.SHARD and shard_level == 0:
+ remove_strategy_list.append(strategy)
+ if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
+ remove_strategy_list.append(strategy)
+ if self.shard_option == ShardOption.SHARD_LAST_AXIS:
+ if shard_level != 1 or using_last_axis == False:
+ remove_strategy_list.append(strategy)
+
+ for strategy in remove_strategy_list:
+ self.strategies_vector.remove(strategy)
+
return self.strategies_vector
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
@@ -248,6 +276,10 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
+ else:
+ logger = get_dist_logger()
+ logger.warning(f'The target function {target} is not patched yet, ')
+
return self.strategies_vector
@@ -299,4 +331,8 @@ def register_strategy(self, compute_resharding_cost: bool = True) -> StrategiesV
# attach metainfos to the handler
setattr(self, "metainfo_vector", metainfo_vector)
+ else:
+ logger = get_dist_logger()
+ logger.warning(f'The target function {target} is not patched yet')
+
return self.strategies_vector
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
similarity index 92%
rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
index 6d625e153f61..91e4a5105a08 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/permute_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/permute_handler.py
@@ -2,11 +2,10 @@
import torch
-from ...sharding_strategy import OperationData, OperationDataType
-from ..node_handler import NodeHandler
-from ..registry import operator_registry
-from ..strategy import StrategyGenerator
-from .reshape_generator import PermuteGenerator
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import PermuteGenerator, StrategyGenerator
__all__ = ['PermuteHandler']
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
similarity index 89%
rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
index 38c5eed7d00e..653d158b7c36 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/split_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
@@ -2,11 +2,10 @@
import torch
-from ...sharding_strategy import OperationData, OperationDataType
-from ..node_handler import NodeHandler
-from ..registry import operator_registry
-from ..strategy import StrategyGenerator
-from .reshape_generator import SplitGenerator
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import SplitGenerator, StrategyGenerator
__all__ = ['SplitHandler']
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
index 8d25475f9c57..db1f31521c86 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
@@ -14,7 +14,13 @@
from .normal_pooling_generator import NormalPoolStrategyGenerator
from .output_generator import OutputGenerator
from .placeholder_generator import PlaceholderGenerator
-from .reshape_generator import ReshapeGenerator
+from .reshape_generator import (
+ DefaultReshapeGenerator,
+ PermuteGenerator,
+ SplitGenerator,
+ TransposeGenerator,
+ ViewGenerator,
+)
from .softmax_generator import SoftmaxGenerator
from .strategy_generator import StrategyGenerator
from .sum_generator import SumGenerator
@@ -26,7 +32,8 @@
'StrategyGenerator', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator', 'LinearProjectionStrategyGenerator',
'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator', 'UnaryElementwiseGenerator',
'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator', 'TensorTupleStrategyGenerator',
- 'LayerNormGenerator', 'ReshapeGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator',
- 'ReshapeGenerator', 'NormalPoolStrategyGenerator', 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator',
- 'TensorConstructorGenerator', 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator'
+ 'LayerNormGenerator', 'PlaceholderGenerator', 'OutputGenerator', 'WhereGenerator', 'NormalPoolStrategyGenerator',
+ 'BinaryElementwiseStrategyGenerator', 'GetattrGenerator', 'TensorConstructorGenerator',
+ 'EmbeddingStrategyGenerator', 'SumGenerator', 'SoftmaxGenerator', 'ViewGenerator', 'PermuteGenerator',
+ 'TransposeGenerator', 'SplitGenerator', 'DefaultReshapeGenerator'
]
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index fa2246f952a9..5d70e131d1e9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -3,6 +3,7 @@
from functools import reduce
from typing import List
+from colossalai.auto_parallel.tensor_shard.options import SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
@@ -209,9 +210,14 @@ def collate_strategies(self) -> List[ShardingStrategy]:
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
- def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
+ def __init__(self,
+ operation_data_mapping,
+ device_mesh,
+ linear_projection_type='linear',
+ solver_perference=SolverPerference.STANDARD):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
+ self.solver_perference = solver_perference
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
@@ -231,16 +237,22 @@ def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
- def collate_strategies(self) -> List[ShardingStrategy]:
+ def dp_strategies(self) -> List[ShardingStrategy]:
strategies = []
- # SS = SR x RS
- strategies.append(self.split_lhs_space_rhs_space(0, 1))
- strategies.append(self.split_lhs_space_rhs_space(1, 0))
+ # S01R = S01R x RR
+ strategies.append(self.split_lhs_1st_dim_1d(0, 1))
- # SR = SS x SR
- strategies.append(self.split_lhs_space_both_contract(0, 1))
- strategies.append(self.split_lhs_space_both_contract(1, 0))
+ return strategies
+
+ def tp_strategies(self) -> List[ShardingStrategy]:
+ strategies = []
+
+ # RR = RS01 x S01R
+ strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
+
+ # RS01 = RR x RS01
+ strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1))
@@ -254,20 +266,38 @@ def collate_strategies(self) -> List[ShardingStrategy]:
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))
- # S01R = S01R x RR
- strategies.append(self.split_lhs_1st_dim_1d(0, 1))
+ return strategies
- # RR = RS01 x S01R
- strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
+ def mix_strategies(self) -> List[ShardingStrategy]:
+ strategies = []
- # RS01 = RR x RS01
- strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
+ # SS = SR x RS
+ strategies.append(self.split_lhs_space_rhs_space(0, 1))
+ strategies.append(self.split_lhs_space_rhs_space(1, 0))
+
+ # SR = SS x SR
+ strategies.append(self.split_lhs_space_both_contract(0, 1))
+ strategies.append(self.split_lhs_space_both_contract(1, 0))
# RR = RR x RR
strategies.append(self.non_split())
return strategies
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ strategies = []
+
+ if self.solver_perference == SolverPerference.STANDARD:
+ strategies.extend(self.dp_strategies())
+ strategies.extend(self.tp_strategies())
+ strategies.extend(self.mix_strategies())
+ elif self.solver_perference == SolverPerference.DP:
+ strategies.extend(self.dp_strategies())
+ elif self.solver_perference == SolverPerference.TP:
+ strategies.extend(self.tp_strategies())
+
+ return strategies
+
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
index 0b3506c27e4c..24f75e352935 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
@@ -1,6 +1,7 @@
import copy
from typing import List
+from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
@@ -8,17 +9,20 @@
ShardingStrategy,
TrainCycleItem,
)
+from colossalai.auto_parallel.tensor_shard.utils import (
+ check_keep_sharding_status,
+ detect_reshape_mapping,
+ infer_output_dim_partition_dict,
+)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
-from .strategy_generator import FollowingStrategyGenerator
-
-__all__ = ['ReshapeGenerator']
+__all__ = ['ReshapeGenerator', 'ViewGenerator', 'PermuteGenerator', 'TransposeGenerator', 'SplitGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
- ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
+ ReshapeGenerator is the base class for all the reshape operation.
"""
def validate(self) -> bool:
@@ -57,11 +61,255 @@ def update_memory_cost(self, strategy: ShardingStrategy):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ return super().collate_strategies()
+
+
+class ViewGenerator(ReshapeGenerator):
+ """
+ ViewGenerator deals with the sharding strategies of view op.
+ """
+
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ strategy_list = []
+ for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+ dim_partition_dict_mapping = {}
+ communication_action_mapping = {}
+ input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+
+ origin_shape = self.op_data['input'].data.shape
+ tgt_shape = self.op_data['tgt_shape'].data
+
+ reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
+
+ dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+ keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
+
+ if keep_sharding_status:
+ dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
+ reshape_mapping_dict)
+ else:
+ dim_partition_dict_for_output = {}
+
+ dim_partition_dict_mapping = {
+ "input": dim_partition_dict_for_input,
+ "output": dim_partition_dict_for_output,
+ }
+ sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+ # add index into name to pass the duplicated check
+ # we keep same strategies with different name for node merging, and it will not increase the searching space,
+ # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+ if keep_sharding_status:
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+ else:
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
+
+ # add comm action for converting input to fully replicated
+ total_mesh_dim_list = []
+ for mesh_dim_list in dim_partition_dict_for_input.values():
+ total_mesh_dim_list.extend(mesh_dim_list)
+ # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
+ if len(total_mesh_dim_list) == 1:
+ total_mesh_dim_list = total_mesh_dim_list[0]
+ # the total mesh dim list only has one element, so the shard dim has only one element as well.
+ shard_dim = list(dim_partition_dict_for_input.keys())[0]
+ input_comm_action = self.get_communication_action(
+ sharding_spec=sharding_spec_mapping["input"],
+ communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+ logical_process_axis=total_mesh_dim_list,
+ comm_type=CommType.BEFORE,
+ arg_index=0)
+ # it will gather the input through gather_dim during forward phase.
+ input_comm_action.comm_spec.gather_dim = shard_dim
+ # it will split the input activation grad through shard_dim during backward phase.
+ input_comm_action.comm_spec.shard_dim = shard_dim
+
+ elif len(total_mesh_dim_list) >= 2:
+ source_spec = sharding_spec_mapping["input"]
+ target_spec = ShardingSpec(device_mesh=self.device_mesh,
+ entire_shape=source_spec.entire_shape,
+ dim_partition_dict={})
+ comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
+
+ else:
+ input_comm_action = None
+
+ if input_comm_action is not None:
+ communication_action_mapping["input"] = input_comm_action
+
+ strategy = self.get_sharding_strategy(name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping)
+ strategy_list.append(strategy)
+
+ return strategy_list
+
+
+class PermuteGenerator(ReshapeGenerator):
+ """
+ PermuteGenerator deals with the sharding strategies of permute op.
+ """
+
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ strategy_list = []
+ for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+ dim_partition_dict_mapping = {}
+ communication_action_mapping = {}
+ input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+
+ permute_dims = self.op_data['permute_dims'].data
+ dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+ dim_partition_dict_for_output = {}
+ for dim_index, permute_dim in enumerate(permute_dims):
+ if permute_dim in dim_partition_dict_for_input:
+ dim_partition_dict_for_output[dim_index] = dim_partition_dict_for_input[permute_dim]
+
+ dim_partition_dict_mapping = {
+ "input": dim_partition_dict_for_input,
+ "output": dim_partition_dict_for_output,
+ }
+ sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+ # add index into name to pass the duplicated check
+ # we keep same strategies with different name for node merging, and it will not increase the searching space,
+ # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+
+ strategy = self.get_sharding_strategy(name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping)
+ strategy_list.append(strategy)
+
+ return strategy_list
+
+
+class TransposeGenerator(ReshapeGenerator):
+ """
+ TransposeGenerator deals with the sharding strategies of permute op.
+ """
+
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
- # For reshape function, to keep the computing correctness we keep the sharding
- # spec of input is fully replicated. In addition, we will keep the output in
- # replica status and let the successor node choose the way to resharding the
+ for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+ dim_partition_dict_mapping = {}
+ communication_action_mapping = {}
+ input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+ dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+ dim_partition_dict_for_output = {}
+
+ transpose_dims = self.op_data['transpose_dims'].data
+ dim_0 = transpose_dims[0]
+ dim_1 = transpose_dims[1]
+ for dim, sharded_dims in dim_partition_dict_for_input.items():
+ if dim == dim_0:
+ dim_partition_dict_for_output[dim_1] = dim_partition_dict_for_input[dim_0]
+ elif dim == dim_1:
+ dim_partition_dict_for_output[dim_0] = dim_partition_dict_for_input[dim_1]
+ else:
+ dim_partition_dict_for_output[dim] = sharded_dims
+
+ dim_partition_dict_mapping = {
+ "input": dim_partition_dict_for_input,
+ "output": dim_partition_dict_for_output,
+ }
+ sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+ # add index into name to pass the duplicated check
+ # we keep same strategies with different name for node merging, and it will not increase the searching space,
+ # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+ name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+
+ strategy = self.get_sharding_strategy(name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping)
+ strategy_list.append(strategy)
+
+ return strategy_list
+
+
+class SplitGenerator(ReshapeGenerator):
+ """
+ SplitGenerator deals with the sharding strategies of split op.
+ """
+
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ strategy_list = []
+ for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+ recover_dims = None
+ dim_partition_dict_mapping = {}
+ communication_action_mapping = {}
+ input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+ dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
+ split_size, split_dim = self.op_data['split_info'].data
+
+ if split_dim in dim_partition_dict_for_input:
+ recover_dims = dim_partition_dict_for_input.pop(split_dim)
+
+ dim_partition_dict_for_output = [
+ copy.deepcopy(dim_partition_dict_for_input) for _ in range(len(self.op_data["output"].data))
+ ]
+ assert len(dim_partition_dict_for_output) >= 2
+ dim_partition_dict_mapping = {
+ "input": dim_partition_dict_for_input,
+ "output": dim_partition_dict_for_output,
+ }
+ sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+ # add index into name to pass the duplicated check
+ # we keep same strategies with different name for node merging, and it will not increase the searching space,
+ # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+ name = f'{sharding_spec_mapping["input"].sharding_sequence}_{index}'
+
+ # add comm action if the input need to be recovered to replica in the split dimension.
+ if recover_dims:
+ # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
+ if len(recover_dims) == 1:
+ recover_dims = recover_dims[0]
+ input_comm_action = self.get_communication_action(
+ sharding_spec=sharding_spec_mapping["input"],
+ communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+ logical_process_axis=recover_dims,
+ comm_type=CommType.BEFORE,
+ arg_index=0)
+ # it will gather the input through gather_dim during forward phase.
+ input_comm_action.comm_spec.gather_dim = split_dim
+ # it will split the input activation grad through split_dim during backward phase.
+ input_comm_action.comm_spec.shard_dim = split_dim
+
+ elif len(recover_dims) >= 2:
+ # original sharding spec
+ source_spec = input_sharding_spec
+ # target sharding spec
+ target_spec = sharding_spec_mapping["input"]
+ comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+ input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
+
+ else:
+ input_comm_action = None
+
+ if input_comm_action is not None:
+ communication_action_mapping["input"] = input_comm_action
+
+ strategy = self.get_sharding_strategy(name=name,
+ sharding_spec_mapping=sharding_spec_mapping,
+ communication_action_mapping=communication_action_mapping)
+ strategy_list.append(strategy)
+
+ return strategy_list
+
+
+class DefaultReshapeGenerator(ReshapeGenerator):
+ """
+ DefaultReshapeGenerator which deals with the sharding strategies of Reshape Op which have to recover the tensor
+ to Replica status.
+ """
+
+ def collate_strategies(self) -> List[ShardingStrategy]:
+ strategy_list = []
+ # For default reshape strategy, to keep the computing correctness we keep the
+ # sharding spec of input is fully replicated. In addition, we will keep the output
+ # in replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
@@ -95,6 +343,7 @@ def collate_strategies(self) -> List[ShardingStrategy]:
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+ input_comm_action.comm_spec.shard_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
@@ -114,9 +363,4 @@ def collate_strategies(self) -> List[ShardingStrategy]:
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
- for strategy in strategy_list:
- self.update_communication_cost(strategy)
- self.update_compute_cost(strategy)
- self.update_memory_cost(strategy)
-
return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
similarity index 90%
rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
index 3c7336a93167..7a9d37726490 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/transpose_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/transpose_handler.py
@@ -2,11 +2,10 @@
import torch
-from ...sharding_strategy import OperationData, OperationDataType
-from ..node_handler import NodeHandler
-from ..registry import operator_registry
-from ..strategy import StrategyGenerator
-from .reshape_generator import TransposeGenerator
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import StrategyGenerator, TransposeGenerator
__all__ = ['TransposeHandler']
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
similarity index 88%
rename from colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
rename to colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
index 6be634593510..7dff89d1d7a3 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/view_handler.py
@@ -2,11 +2,10 @@
import torch
-from ...sharding_strategy import OperationData, OperationDataType
-from ..node_handler import NodeHandler
-from ..registry import operator_registry
-from ..strategy import StrategyGenerator
-from .reshape_generator import ViewGenerator
+from ..sharding_strategy import OperationData, OperationDataType
+from .node_handler import NodeHandler
+from .registry import operator_registry
+from .strategy import StrategyGenerator, ViewGenerator
__all__ = ['ViewHandler']
diff --git a/colossalai/auto_parallel/tensor_shard/options.py b/colossalai/auto_parallel/tensor_shard/options.py
new file mode 100644
index 000000000000..f0ea502a6f0e
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/options.py
@@ -0,0 +1,49 @@
+from dataclasses import dataclass
+from enum import Enum
+
+__all__ = ['SolverOptions', 'SolverPerference', 'DataloaderOption', 'ShardOption']
+
+
+class SolverPerference(Enum):
+ """
+ This enum class is to define the solver preference.
+ """
+ STANDARD = 0
+ DP = 1
+ TP = 2
+
+
+class ShardOption(Enum):
+ """
+ This enum class is to define the shard level required in node strategies.
+
+ Notes:
+ STANDARD: We do not add any extra shard requirements.
+ SHARD: We require the node to be shard using at least one device mesh axis.
+ SHARD_ONE_AXIS: We require the node to be shard using the last device mesh axis.
+ FULL_SHARD: We require the node to be shard using all device mesh axes.
+ TP_SHARD: We require the node to be shard using tensor parallel strategies on last device mesh axis.
+ TP_FULL_SHARD: We require the node to be shard using tensor parallel strategies on all device mesh axes.
+ """
+ STANDARD = 0
+ SHARD = 1
+ SHARD_LAST_AXIS = 2
+ FULL_SHARD = 3
+
+
+class DataloaderOption(Enum):
+ """
+ This enum class is to define the dataloader option.
+ """
+ REPLICATED = 0
+ DISTRIBUTED = 1
+
+
+@dataclass
+class SolverOptions:
+ """
+ SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
+ """
+ solver_perference: SolverPerference = SolverPerference.STANDARD
+ dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
+ shard_option: ShardOption = ShardOption.STANDARD
diff --git a/colossalai/auto_parallel/tensor_shard/solver/__init__.py b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
index e9f9ba8814a7..f9e6bd923921 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/__init__.py
@@ -1,7 +1,6 @@
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
-from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
-__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
+__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
diff --git a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
index 038e56547b96..74290453ca0c 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
@@ -62,9 +62,6 @@ def _build_cost_graph(self):
else:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
- # add parents and children attribute to node
- # parent_nodes = [node for node in strategies_vector.predecessor_nodes]
- # children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []
diff --git a/colossalai/auto_parallel/tensor_shard/solver/options.py b/colossalai/auto_parallel/tensor_shard/solver/options.py
deleted file mode 100644
index b52e55708dfd..000000000000
--- a/colossalai/auto_parallel/tensor_shard/solver/options.py
+++ /dev/null
@@ -1,30 +0,0 @@
-from dataclasses import dataclass
-from enum import Enum
-
-__all__ = ['SolverOptions']
-
-
-class SolverPerference(Enum):
- """
- This enum class is to define the solver preference.
- """
- STANDARD = 0
- DP = 1
- TP = 2
-
-
-class DataloaderOption(Enum):
- """
- This enum class is to define the dataloader option.
- """
- REPLICATED = 0
- DISTRIBUTED = 1
-
-
-@dataclass
-class SolverOptions:
- """
- SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
- """
- solver_perference: SolverPerference = SolverPerference.STANDARD
- dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
diff --git a/colossalai/auto_parallel/tensor_shard/solver/solver.py b/colossalai/auto_parallel/tensor_shard/solver/solver.py
index 89d0da2235a2..f5c6663dce80 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/solver.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/solver.py
@@ -1,3 +1,7 @@
+"""This code is adapted from Alpa
+ https://github.com/alpa-projects/alpa/
+ with some changes. """
+
import multiprocessing
import time
import warnings
@@ -28,12 +32,12 @@ def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
- graph_analyser: GraphAnalyser,
+ graph_analyser: GraphAnalyser = None,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
- verbose=True):
+ verbose=False):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
@@ -59,7 +63,10 @@ def __init__(self,
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
- self.liveness_list = self.graph_analyser.liveness_analysis()
+ # temporarily we use all nodes as liveness list, we count the backward memory cost together with
+ # forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
+ # self.liveness_list = self.graph_analyser.liveness_analysis()
+ self.liveness_list = self.nodes
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
@@ -136,7 +143,7 @@ def _prepare_data_for_solver(self):
liveness_set = self.liveness_list
# omit alias_set now
- alias_set = None
+ alias_set = self.strategies_constructor.alias_set
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
@@ -226,6 +233,7 @@ def get_non_zero_index(binary_vector):
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
+ s_alias = alias_set
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
@@ -290,8 +298,11 @@ def get_non_zero_index(binary_vector):
if strategies_len[i] == 1:
s.append([1])
else:
- num_nodes += 1
- s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+ if i not in s_alias:
+ num_nodes += 1
+ s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
+ else:
+ s.append(s[s_alias[i]])
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
@@ -307,15 +318,20 @@ def get_non_zero_index(binary_vector):
#############################
e = []
num_edges = 0
+ map_edge_to_idx = {}
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
- num_edges += 1
- e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
+ if i in s_alias and j in s_alias and (s_alias[i], s_alias[j]) in map_edge_to_idx:
+ e.append(e[map_edge_to_idx[(s_alias[i], s_alias[j])]])
+ else:
+ num_edges += 1
+ e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
+ map_edge_to_idx[(i, j)] = idx
for element in s:
assert len(element) > 0
# 2. Set initial value
@@ -367,13 +383,12 @@ def get_non_zero_index(binary_vector):
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
- for liveness_stage in liveness_set:
- mem = 0
- for live_variable in liveness_stage.unique_live_vars:
- if live_variable.node not in self.node_index_dict:
- continue
- node_index = self.node_index_dict[live_variable.node]
- mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
+ mem = 0
+ for node in liveness_set:
+ if node not in self.node_index_dict:
+ continue
+ node_index = self.node_index_dict[node]
+ mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 042b9bb4b0d1..59ead1ca8fac 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -15,9 +15,10 @@
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
from colossalai.device.device_mesh import DeviceMesh
-from .options import DataloaderOption, SolverOptions
+from ..options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
@@ -42,6 +43,7 @@ def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: Solver
self.strategy_map = {}
self.solver_options = solver_options
self.no_strategy_nodes = []
+ self.alias_set = None
def remove_duplicated_strategy(self, strategies_vector):
'''
@@ -59,6 +61,22 @@ def remove_duplicated_strategy(self, strategies_vector):
for strategy in remove_list:
strategies_vector.remove(strategy)
+ def generate_alias_set(self):
+
+ node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
+ common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)
+
+ repeat_block_nums = len(common_blocks)
+ alias_set = {}
+
+ if repeat_block_nums == 0:
+ return alias_set
+
+ for index, common_node in enumerate(common_blocks[0]):
+ for i in range(1, repeat_block_nums):
+ alias_set[node_list.index(common_blocks[i][index])] = node_list.index(common_node)
+ return alias_set
+
def build_strategies_and_cost(self):
"""
This method is to build the strategy vector for each node in the computation graph.
@@ -101,7 +119,11 @@ def _check_no_strategy_for_data(data):
# get_attr node
elif node.op == 'get_attr':
- getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
+ getattr_handler = GetattrHandler(node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference)
getattr_handler.register_strategy()
# call_module node
@@ -109,7 +131,11 @@ def _check_no_strategy_for_data(data):
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
- handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
+ handler = operator_registry.get(submod_type)(node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
@@ -118,7 +144,11 @@ def _check_no_strategy_for_data(data):
# call_function node
elif node.op == 'call_function':
target = node.target
- handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
+ handler = operator_registry.get(target)(node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
@@ -127,7 +157,11 @@ def _check_no_strategy_for_data(data):
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
- handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
+ handler = operator_registry.get(method)(node,
+ self.device_mesh,
+ strategies_vector,
+ shard_option=self.solver_options.shard_option,
+ solver_perference=self.solver_options.solver_perference)
handler.register_strategy()
# attach metainfo_vector to node
if hasattr(handler, 'metainfo_vector'):
@@ -159,3 +193,6 @@ def _check_no_strategy_for_data(data):
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
+
+ alias_set = self.generate_alias_set()
+ self.alias_set = alias_set
diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py
index fd3ba3d41c30..05331e560001 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/factory.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py
@@ -1,13 +1,16 @@
+import copy
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
+from torch.fx.node import Node
+from torch.utils._pytree import tree_map
+
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx.node import Node
from ..constants import INFINITY_COST
@@ -18,7 +21,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
-
+
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
@@ -59,7 +62,7 @@ def generate_resharding_costs(nodes: List[Node],
nodes (List[Node]): a list of nodes
sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
- dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
+ dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
@@ -88,3 +91,116 @@ def generate_resharding_costs(nodes: List[Node],
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
+
+
+def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
+ '''
+ Find the largest repeat blocks in the graph, whose length is larger than the threshold.
+
+ Args:
+ gm (GraphModule): the graph module to be analyzed.
+ common_length_threshold (int): the threshold of the repeat block length.
+ '''
+
+ # graph = gm.graph
+
+ def _process_args(args):
+ new_args = []
+ for arg in args:
+ if hasattr(arg, '_meta_data'):
+ meta_data = arg._meta_data
+ else:
+ meta_data = arg
+
+ def _process_arg(data):
+ if isinstance(data, torch.Tensor):
+ data = data.size()
+ elif isinstance(data, slice):
+ data = (data.start, data.step, data.stop)
+ return data
+
+ new_meta_data = tree_map(_process_arg, meta_data)
+ new_args.append(new_meta_data)
+
+ return new_args
+
+ def _all_equal(check_list, check_fn):
+ base_value = check_list[-1]
+ for e in check_list:
+ if not check_fn(e, base_value):
+ return False
+ return True
+
+ def _check_node_list_equal(l1, l2):
+ if len(l1) != len(l2):
+ return False
+ for node1, node2 in zip(l1, l2):
+ if hash(node1.hash_key) != hash(node2.hash_key):
+ return False
+ return True
+
+ def _check_node_equal(node1, node2):
+ if hash(node1.hash_key) == hash(node2.hash_key):
+ return True
+ return False
+
+ for index, node in enumerate(node_list):
+ if node.op == 'call_module':
+ target = node.target
+ submod = root_module.get_submodule(target)
+ submod_type = type(submod)
+ target = submod_type
+ else:
+ target = node.target
+
+ new_args = _process_args(node.args)
+
+ if node.op != 'get_attr':
+ hash_key = (node.op, target, *new_args)
+ else:
+ hash_key = (node.op,)
+
+ setattr(node, 'hash_key', hash_key)
+
+ hash_value_to_node_dict = {}
+
+ for index, node in enumerate(node_list):
+ hash_value = hash(node.hash_key)
+ if hash_value not in hash_value_to_node_dict:
+ hash_value_to_node_dict[hash_value] = []
+ hash_value_to_node_dict[hash_value].append(index)
+
+ # node_list = list(graph.nodes)
+
+ node_list_start = 0
+ max_common_length = common_length_threshold
+ common_blocks_index = []
+ for index, node in enumerate(node_list):
+ # the comparison will be triggered if a common node appears
+ if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
+ start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
+ check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
+
+ common_label = True
+ if not _all_equal(check_block_list, _check_node_list_equal):
+ common_label = False
+
+ if common_label:
+ common_blocks_index = copy.deepcopy(start_index_list)
+ max_step = len(node_list) - common_blocks_index[-1] - max_common_length - 1
+
+ for i in range(max_step):
+ # add assertion to avoid out of index
+ next_node_list = [node_list[index + max_common_length + i] for index in start_index_list]
+ if not _all_equal(next_node_list, _check_node_equal):
+ max_step = i
+ break
+ max_common_length += max_step
+ node_list_start += max_common_length
+
+ # recover common subgraph from the index
+ common_blocks = []
+ for start in common_blocks_index:
+ common_blocks.append(node_list[start:start + max_common_length])
+
+ return common_blocks
diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
new file mode 100644
index 000000000000..2cbc6c9221aa
--- /dev/null
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -0,0 +1,561 @@
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+
+import colossalai
+from colossalai.fx._compatibility import is_compatible_with_meta
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+
+AUTOCHUNK_AVAILABLE = CODEGEN_AVAILABLE and is_compatible_with_meta()
+
+if AUTOCHUNK_AVAILABLE:
+ from torch.fx.graph import CodeGen, PythonCode, _custom_builtins, _CustomBuiltin, _format_target, _is_from_torch, _Namespace, _origin_type_map, inplace_methods, magic_methods
+
+from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg
+
+from .search_chunk import SearchChunk
+from .utils import delete_free_var_from_last_use, get_logger, get_node_name, get_node_shape
+
+
+def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
+ """
+ Generate chunk slice string, eg. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]
+
+ Args:
+ chunk_dim (int)
+ chunk_indice_name (str): chunk indice name
+ shape (List): node shape
+
+ Returns:
+ new_shape (str): return slice
+ """
+ new_shape = "["
+ for idx, _ in enumerate(shape):
+ if idx == chunk_dim:
+ new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
+ else:
+ new_shape += ":"
+ new_shape += ", "
+ new_shape = new_shape[:-2] + "]"
+ return new_shape
+
+
+def _gen_loop_start(chunk_input: List[Node], chunk_output: List[Node], chunk_ouput_dim: int, chunk_size=2) -> str:
+ """
+ Generate chunk loop start
+
+ eg. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
+ chunk_size = 32
+ for chunk_idx in range(0, 100, 32):
+ ......
+
+ Args:
+ chunk_input (List[Node]): chunk input node
+ chunk_output (Node): chunk output node
+ chunk_ouput_dim (int): chunk output node chunk dim
+ chunk_size (int): chunk size. Defaults to 2.
+
+ Returns:
+ context (str): generated str
+ """
+ input_node = chunk_input[0]
+
+ context = ""
+ for i in range(len(chunk_output)):
+ shape_str = str(list(get_node_shape(chunk_output[i])))
+ if get_node_name(chunk_output[i]) in ["split", "unbind"]:
+ tensor_str = "torch.empty(%s, dtype=%s.dtype, device=%s.device), " % (shape_str, input_node.name,
+ input_node.name)
+ tensor_str = tensor_str * len(chunk_output[i].meta['tensor_meta'])
+ tensor_str = "[" + tensor_str[:-2] + "]"
+ context += "%s = %s; " % (chunk_output[i].name, tensor_str)
+ else:
+ context += "%s = torch.empty(%s, dtype=%s.dtype, device=%s.device); " % (chunk_output[i].name, shape_str,
+ input_node.name, input_node.name)
+
+ out_shape = get_node_shape(chunk_output[0])
+ chunk_shape = out_shape[chunk_ouput_dim[0]]
+ context += "chunk_size = %d\nfor chunk_idx in range(0, %d, chunk_size):\n" % (chunk_size, chunk_shape)
+ return context
+
+
+def _gen_loop_end(chunk_inputs: List[Node], chunk_non_compute_inputs: List[Node], node_list: List[Node],
+ chunk_outputs_idx: int, chunk_outputs_non_tensor: List[Node], search_chunk: SearchChunk) -> str:
+ """
+ Generate chunk loop end
+
+ eg. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
+ output_node = chunk_result; xx = None; xx = None
+
+ Args:
+ chunk_inputs (List[Node]): chunk input node
+ chunk_non_compute_inputs (List[Node]): input node without chunk
+ chunk_outputs (Node): chunk output node
+ chunk_outputs_dim (int): chunk output node chunk dim
+ node_list (List)
+
+ Returns:
+ context (str): generated str
+ """
+ context = "chunk_size = None"
+ # determine if its the last use for chunk input
+ for chunk_input in chunk_inputs + chunk_non_compute_inputs:
+ if all([search_chunk.node_mgr.find_node_idx(user) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
+ context += "; %s = None" % chunk_input.name
+ for chunk_output_non_tensor, chunk_output_non_tensor_val in chunk_outputs_non_tensor.items():
+ context += "; %s = %s" % (chunk_output_non_tensor.name, chunk_output_non_tensor_val)
+ context += "\n"
+ return context
+
+
+def _replace_name(context: str, name_from: str, name_to: str) -> str:
+ """
+ replace node name
+ """
+ patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
+ for p in patterns:
+ source = p[0] + name_from + p[1]
+ target = p[0] + name_to + p[1]
+ if source in context:
+ context = context.replace(source, target)
+ break
+ return context
+
+
+def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
+ """
+ replace reshape size, some may have changed due to chunk
+ """
+ if node_name not in reshape_size_dict:
+ return context
+ context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
+ return context
+
+
+def _replace_new_tensor_like_shape(
+ search_chunk: SearchChunk,
+ chunk_infos: List[Dict],
+ region_idx: int,
+ node_idx: int,
+ node: Node,
+ body: List[str],
+) -> List[str]:
+ """
+ add chunk slice for new tensor op such as ones like
+ """
+ if get_node_name(node) in ["ones_like", "zeros_like", "empty_like"]:
+ meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
+ chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
+ if get_node_shape(meta_node)[chunk_dim] != 1:
+ source_node = meta_node.args[0].args[0]
+ if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
+ or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
+ chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
+ body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
+ return body
+
+
+def _replace_new_tensor_shape(
+ search_chunk: SearchChunk,
+ chunk_infos: List[Dict],
+ region_idx: int,
+ node_idx: int,
+ node: Node,
+ body: List[str],
+) -> List[str]:
+ """
+ add chunk slice for new tensor op such as ones
+ """
+ if get_node_name(node) in ["ones", "zeros", "empty"]:
+ meta_node = search_chunk.node_mgr.get_node_by_idx(node_idx)
+ chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
+ if chunk_dim is None:
+ return
+ if get_node_shape(meta_node)[chunk_dim] == 1:
+ return
+ origin_shape = str(node.args)
+ new_shape = list(node.args)
+ new_shape[chunk_dim] = "min(chunk_size, %d - chunk_idx)" % get_node_shape(meta_node)[chunk_dim]
+ new_shape = str(new_shape)
+ new_shape = new_shape.replace("'", "")
+ body[-1] = _replace_name(body[-1], origin_shape[1:-1], new_shape[1:-1])
+ return body
+
+
+def _add_node_slice(
+ chunk_nodes: List[Node],
+ region_idx: int,
+ chunk_nodes_dim: Dict,
+ node_idx: int,
+ body: List[str],
+ node: Node,
+) -> List[str]:
+ """
+ add chunk slice for input nodes
+ """
+ for chunk_node_idx, chunk_node in enumerate(chunk_nodes[region_idx]):
+ # inputs node
+ if isinstance(chunk_nodes_dim[region_idx][chunk_node_idx], dict):
+ for idx, dim in chunk_nodes_dim[region_idx][chunk_node_idx].items():
+ if idx == node_idx:
+ chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(chunk_node))
+ body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
+ # outputs node
+ else:
+ if chunk_node.name == node.name or (chunk_node.name in [i.name for i in node.all_input_nodes]):
+ chunk_slice = _gen_chunk_slice_dim(chunk_nodes_dim[region_idx][chunk_node_idx], "chunk_idx",
+ get_node_shape(chunk_node))
+ if get_node_name(chunk_node) in ["split", "unbind"]:
+ split_chunk_slice = ""
+ for i in range(len(chunk_node.meta['tensor_meta'])):
+ split_chunk_slice += "%s[%d]%s, " % (chunk_node.name, i, chunk_slice)
+ split_chunk_slice = split_chunk_slice[:-2]
+ body[-1] = _replace_name(body[-1], chunk_node.name, split_chunk_slice)
+ else:
+ body[-1] = _replace_name(body[-1], chunk_node.name, chunk_node.name + chunk_slice)
+ return body
+
+
+def emit_code_with_chunk(body: List[str],
+ nodes: Iterable[Node],
+ emit_node_func: Callable,
+ delete_unused_value_func: Callable,
+ search_chunk: SearchChunk,
+ chunk_infos: List,
+ eval_mem: bool = False):
+ """
+ Emit code with chunk according to chunk_infos.
+
+ It will generate a for loop in chunk regions, and
+ replace inputs and outputs of regions with chunked variables.
+
+ Args:
+ body: forward code
+ nodes: graph.nodes
+ emit_node_func: function to emit node
+ delete_unused_value_func: function to remove the unused value
+ search_chunk: the class to search all chunks
+ chunk_infos: store all information about all chunks.
+ """
+ node_list = list(nodes)
+
+ # chunk region
+ chunk_starts = [i["region"][0] for i in chunk_infos]
+ chunk_ends = [i["region"][1] for i in chunk_infos]
+
+ # chunk inputs
+ chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
+ chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
+ chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
+ chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
+
+ # chunk outputs
+ chunk_outputs = [i["outputs"] for i in chunk_infos]
+ chunk_outputs_non_tensor = [i["outputs_non_tensor"] for i in chunk_infos]
+ chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]
+
+ node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
+ node_idx = 0
+ region_idx = 0
+ within_chunk_region = False
+
+ if eval_mem:
+ body.append("init_memory = torch.cuda.memory_allocated() / 1024**2\n")
+
+ while node_idx < len(node_list):
+ node = node_list[node_idx]
+
+ # if is chunk start, generate for loop start
+ if node_idx in chunk_starts:
+ within_chunk_region = True
+ region_idx = chunk_starts.index(node_idx)
+ body.append(
+ _gen_loop_start(
+ chunk_inputs[region_idx],
+ chunk_outputs[region_idx],
+ chunk_outputs_dim[region_idx],
+ chunk_infos[region_idx]["chunk_size"],
+ ))
+
+ if within_chunk_region:
+ emit_node_func(node, body)
+ # replace input var with chunk var
+ body = _add_node_slice(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body, node)
+ # replace output var with chunk var
+ body = _add_node_slice(chunk_outputs, region_idx, chunk_outputs_dim, node_idx, body, node)
+ # new tensor like
+ body = _replace_new_tensor_like_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+ # new tensor
+ body = _replace_new_tensor_shape(search_chunk, chunk_infos, region_idx, node_idx, node, body)
+ # reassgin reshape size
+ body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
+ body[-1] = " " + body[-1]
+ delete_unused_value_func(node, body, chunk_inputs_names)
+ if eval_mem:
+ body.append(
+ " if chunk_idx == 0:\n print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+ % (node.name))
+ else:
+ emit_node_func(node, body)
+ if node_idx not in chunk_inputs:
+ delete_unused_value_func(node, body, chunk_inputs_names)
+ if eval_mem:
+ body.append(
+ "print('%s', torch.cuda.max_memory_allocated() / 1024**2 - init_memory); torch.cuda.reset_peak_memory_stats()\n"
+ % (node.name))
+
+ # generate chunk region end
+ if node_idx in chunk_ends:
+ body.append(
+ _gen_loop_end(chunk_inputs[region_idx], chunk_inputs_non_chunk[region_idx], node_list,
+ chunk_ends[region_idx], chunk_outputs_non_tensor[region_idx], search_chunk))
+ within_chunk_region = False
+
+ node_idx += 1
+
+
+if AUTOCHUNK_AVAILABLE:
+
+ class AutoChunkCodeGen(CodeGen):
+
+ def __init__(self,
+ meta_graph,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ eval_mem: bool = False) -> None:
+ super().__init__()
+ self.eval_mem = eval_mem
+ # find the chunk regions
+ self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem, print_progress)
+ self.chunk_infos = self.search_chunk.search_region()
+ if print_progress:
+ get_logger().info("AutoChunk start codegen")
+
+ def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+ free_vars: List[str] = []
+ body: List[str] = []
+ globals_: Dict[str, Any] = {}
+ wrapped_fns: Dict[str, None] = {}
+
+ # Wrap string in list to pass by reference
+ maybe_return_annotation: List[str] = [""]
+
+ def add_global(name_hint: str, obj: Any):
+ """Add an obj to be tracked as a global.
+
+ We call this for names that reference objects external to the
+ Graph, like functions or types.
+
+ Returns: the global name that should be used to reference 'obj' in generated source.
+ """
+ if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
+ # HACK: workaround for how torch custom ops are registered. We
+ # can't import them like normal modules so they must retain their
+ # fully qualified name.
+ return _get_qualified_name(obj)
+
+ # normalize the name hint to get a proper identifier
+ global_name = namespace.create_name(name_hint, obj)
+
+ if global_name in globals_:
+ assert globals_[global_name] is obj
+ return global_name
+ globals_[global_name] = obj
+ return global_name
+
+ # set _custom_builtins here so that we needn't import colossalai in forward
+ _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+
+ # Pre-fill the globals table with registered builtins.
+ for name, (_, obj) in _custom_builtins.items():
+ add_global(name, obj)
+
+ def type_repr(o: Any):
+ if o == ():
+ # Empty tuple is used for empty tuple type annotation Tuple[()]
+ return "()"
+
+ typename = _type_repr(o)
+
+ if hasattr(o, "__origin__"):
+ # This is a generic type, e.g. typing.List[torch.Tensor]
+ origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+ origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+ if hasattr(o, "__args__"):
+ # Assign global names for each of the inner type variables.
+ args = [type_repr(arg) for arg in o.__args__]
+
+ if len(args) == 0:
+ # Bare type, such as `typing.Tuple` with no subscript
+ # This code-path used in Python < 3.9
+ return origin_typename
+
+ return f'{origin_typename}[{",".join(args)}]'
+ else:
+ # Bare type, such as `typing.Tuple` with no subscript
+ # This code-path used in Python 3.9+
+ return origin_typename
+
+ # Common case: this is a regular module name like 'foo.bar.baz'
+ return add_global(typename, o)
+
+ def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
+
+ def _get_repr(arg):
+ # Handle NamedTuples (if it has `_fields`) via add_global.
+ if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+ qualified_name = _get_qualified_name(type(arg))
+ global_name = add_global(qualified_name, type(arg))
+ return f"{global_name}{repr(tuple(arg))}"
+ return repr(arg)
+
+ args_s = ", ".join(_get_repr(a) for a in args)
+ kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
+ if args_s and kwargs_s:
+ return f"{args_s}, {kwargs_s}"
+ return args_s or kwargs_s
+
+ # Run through reverse nodes and record the first instance of a use
+ # of a given node. This represents the *last* use of the node in the
+ # execution order of the program, which we will use to free unused
+ # values
+ node_to_last_use: Dict[Node, Node] = {}
+ user_to_last_uses: Dict[Node, List[Node]] = {}
+
+ def register_last_uses(n: Node, user: Node):
+ if n not in node_to_last_use:
+ node_to_last_use[n] = user
+ user_to_last_uses.setdefault(user, []).append(n)
+
+ for node in reversed(nodes):
+ map_arg(node.args, lambda n: register_last_uses(n, node))
+ map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+ delete_free_var_from_last_use(user_to_last_uses)
+
+ # NOTE: we add a variable to distinguish body and ckpt_func
+ def delete_unused_values(user: Node, body, to_keep=[]):
+ """
+ Delete values after their last use. This ensures that values that are
+ not used in the remainder of the code are freed and the memory usage
+ of the code is optimal.
+ """
+ if user.op == "placeholder":
+ return
+ if user.op == "output":
+ body.append("\n")
+ return
+ nodes_to_delete = user_to_last_uses.get(user, [])
+ nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
+ if len(nodes_to_delete):
+ to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
+ body.append(f"; {to_delete_str}\n")
+ else:
+ body.append("\n")
+
+ # NOTE: we add a variable to distinguish body and ckpt_func
+ def emit_node(node: Node, body):
+ maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
+ if node.op == "placeholder":
+ assert isinstance(node.target, str)
+ maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
+ free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
+ raw_name = node.target.replace("*", "")
+ if raw_name != repr(node):
+ body.append(f"{repr(node)} = {raw_name}\n")
+ return
+ elif node.op == "call_method":
+ assert isinstance(node.target, str)
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+ f"({_format_args(node.args[1:], node.kwargs)})")
+ return
+ elif node.op == "call_function":
+ assert callable(node.target)
+ # pretty print operators
+ if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
+ assert isinstance(node.args, tuple)
+ body.append(f"{repr(node)}{maybe_type_annotation} = "
+ f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
+ return
+
+ # pretty print inplace operators; required for jit.script to work properly
+ # not currently supported in normal FX graphs, but generated by torchdynamo
+ if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
+ body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+ f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
+ return
+
+ qualified_name = _get_qualified_name(node.target)
+ global_name = add_global(qualified_name, node.target)
+ # special case for getattr: node.args could be 2-argument or 3-argument
+ # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+ if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
+ and node.args[1].isidentifier() and len(node.args) == 2):
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
+ return
+ body.append(
+ f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
+ if node.meta.get("is_wrapped", False):
+ wrapped_fns.setdefault(global_name)
+ return
+ elif node.op == "call_module":
+ assert isinstance(node.target, str)
+ body.append(f"{repr(node)}{maybe_type_annotation} = "
+ f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
+ return
+ elif node.op == "get_attr":
+ assert isinstance(node.target, str)
+ body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
+ return
+ elif node.op == "output":
+ if node.type is not None:
+ maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+ body.append(self.generate_output(node.args[0]))
+ return
+ raise NotImplementedError(f"node: {node.op} {node.target}")
+
+ # Modified for activation checkpointing
+ ckpt_func = []
+
+ # if any node has a list of labels for activation_checkpoint, we
+ # will use nested type of activation checkpoint codegen
+ emit_code_with_chunk(body, nodes, emit_node, delete_unused_values, self.search_chunk, self.chunk_infos,
+ self.eval_mem)
+
+ if len(body) == 0:
+ # If the Graph has no non-placeholder nodes, no lines for the body
+ # have been emitted. To continue to have valid Python code, emit a
+ # single pass statement
+ body.append("pass\n")
+
+ if len(wrapped_fns) > 0:
+ wrap_name = add_global("wrap", torch.fx.wrap)
+ wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+ else:
+ wrap_stmts = ""
+
+ if self._body_transformer:
+ body = self._body_transformer(body)
+
+ for name, value in self.additional_globals():
+ add_global(name, value)
+
+ # as we need colossalai.utils.checkpoint, we need to import colossalai
+ # in forward function
+ prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+ prologue = "".join(ckpt_func) + prologue
+ prologue = prologue
+
+ code = "".join(body)
+ code = "\n".join(" " + line for line in code.split("\n"))
+ fn_code = f"""
+{wrap_stmts}
+
+{prologue}
+{code}"""
+ # print(fn_code)
+ return PythonCode(fn_code, globals_)
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
new file mode 100644
index 000000000000..08a55f9aa04a
--- /dev/null
+++ b/colossalai/autochunk/estimate_memory.py
@@ -0,0 +1,240 @@
+import copy
+from typing import Any, Callable, Dict, Iterable, List, Tuple
+
+import torch
+from torch.fx.node import Node
+
+from colossalai.fx.profiler import activation_size, parameter_size
+
+from .utils import NodeMgr, get_node_shape, is_non_memory_node
+
+
+class EstimateMemory(object):
+ """
+ Estimate memory with chunk
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def _get_node_size(self, x: Node) -> float:
+ """
+ return node size in MB
+ """
+ x = x.meta["tensor_meta"]
+ if not hasattr(x, "numel"):
+ out = sum([i.numel * torch.tensor([], dtype=i.dtype).element_size() for i in x])
+ else:
+ out = x.numel * torch.tensor([], dtype=x.dtype).element_size()
+ out = float(out) / 1024**2
+ return out
+
+ def _add_active_node(self, n: Node, active_nodes: Dict, chunk_ratio: float) -> None:
+ """
+ add an active node and its shape to active node dict
+ """
+ if get_node_shape(n) is None:
+ return
+ if n.op == "placeholder":
+ return
+ if n not in active_nodes:
+ node_size = self._get_node_size(n) * chunk_ratio
+ active_nodes[n] = node_size
+
+ def _build_delete_node_dict(self, node_mgr: NodeMgr) -> Dict:
+ """
+ build delete node dict, means node should be deleted at what time
+ """
+ delete_node_dict = {}
+ for idx, node in enumerate(node_mgr.get_node_list()):
+ # skip non shape node
+ if get_node_shape(node) is None:
+ continue
+ # dont remove free nodes
+ elif node.op == "placeholder":
+ delete_node_dict[node] = len(node_mgr.get_node_list())
+ # node no user
+ elif len(node.users) == 0:
+ delete_node_dict[node] = idx
+ # log max use
+ else:
+ node_user_idx = [node_mgr.find_node_idx(i) for i in node.users.keys()]
+ delete_node_dict[node] = max(node_user_idx)
+ return delete_node_dict
+
+ def _remove_deactive_node(self,
+ user_idx: int,
+ user: Node,
+ active_nodes: List,
+ delete_node_dict: List,
+ kept_nodes: List = None) -> None:
+ """
+ remove deactivate nodes from active nodes
+ """
+ if kept_nodes is None:
+ kept_nodes = []
+ if user.op in ("output",):
+ return
+
+ for node in list(active_nodes.keys()):
+ # dont delete kept nodes
+ if node in kept_nodes:
+ continue
+ # should be deleted
+ if delete_node_dict[node] <= user_idx:
+ active_nodes.pop(node)
+
+ def _get_tmp_memory(self, node, not_contiguous_list, delete=False):
+ mem = 0
+ not_contiguous_ops = ["permute"]
+
+ if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
+ for n in node.args:
+ if n in not_contiguous_list:
+ # matmul won't change origin tensor, but create a tmp copy
+ mem += self._get_node_size(n)
+ elif node.op == "call_module":
+ for n in node.args:
+ if n in not_contiguous_list:
+ # module will just make origin tensor to contiguous
+ if delete:
+ not_contiguous_list.remove(n)
+ elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
+ if node not in not_contiguous_list:
+ not_contiguous_list.append(node)
+ return mem
+
+ def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
+ if node not in chunk_node_dim:
+ return 1.0
+ node_shape = get_node_shape(node)
+ chunk_dim = chunk_node_dim[node]["chunk_dim"]
+ if chunk_dim is None:
+ return 1.0
+ else:
+ return chunk_size / float(node_shape[chunk_dim])
+
+ def _print_compute_op_mem_log(self, log, nodes, title=None):
+ if title:
+ print(title)
+ for idx, (l, n) in enumerate(zip(log, nodes)):
+ if n.op in ["placeholder", "get_attr", "output"]:
+ continue
+ if any(i in n.name for i in ["getitem", "getattr"]):
+ continue
+ print("%s:%.2f \t" % (n.name, l), end="")
+ if (idx + 1) % 3 == 0:
+ print("")
+ print("\n")
+
+ def _add_active_nodes_from_list(self, active_nodes: List, nodes: List) -> List:
+ """
+ add active nodes from nodes
+ """
+ for n in nodes:
+ self._add_active_node(n, active_nodes, 1)
+
+ def _get_memory_from_active_nodes(self, active_nodes: Dict) -> float:
+ """
+ sum all memory of active nodes
+ """
+ out = [i for i in active_nodes.values()]
+ out = sum(out)
+ return out
+
+ def estimate_chunk_inference_mem(self, node_list: List, chunk_infos: Dict = None, print_mem: bool = False):
+ """
+ Estimate inference memory with chunk
+
+ Args:
+ node_list (List): _description_
+ chunk_infos (Dict): Chunk information. Defaults to None.
+ print_mem (bool): Wether to print peak memory of every node. Defaults to False.
+
+ Returns:
+ act_memory_peak_log (List): peak memory of every node
+ act_memory_after_node_log (List): memory after excuting every node
+ active_node_list_log (List): active nodes of every node. active nodes refer to
+ nodes generated but not deleted.
+ """
+ act_memory = 0.0
+ act_memory_peak_log = []
+ act_memory_after_node_log = []
+ active_nodes = {}
+ active_nodes_log = []
+ not_contiguous_list = []
+ node_mgr = NodeMgr(node_list)
+ delete_node_dict = self._build_delete_node_dict(node_mgr)
+
+ use_chunk = True if chunk_infos is not None else False
+ chunk_within = False
+ chunk_region_idx = None
+ chunk_ratio = 1 # use it to estimate chunk mem
+ chunk_inputs_all = []
+
+ if use_chunk:
+ chunk_regions = [i["region"] for i in chunk_infos]
+ chunk_starts = [i[0] for i in chunk_regions]
+ chunk_ends = [i[1] for i in chunk_regions]
+ chunk_inputs = [i["inputs"] for i in chunk_infos]
+ chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
+ chunk_inputs_all = [j for i in chunk_inputs for j in i] + [j for i in chunk_inputs_non_chunk for j in i]
+ chunk_outputs = [i["outputs"] for i in chunk_infos]
+ chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
+ chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
+
+ for idx, node in enumerate(node_mgr.get_node_list()):
+
+ # if node in chunk start nodes, change chunk ratio and add chunk_tensor
+ if use_chunk and idx in chunk_starts:
+ chunk_within = True
+ chunk_region_idx = chunk_starts.index(idx)
+ self._add_active_nodes_from_list(active_nodes, chunk_outputs[chunk_region_idx])
+
+ # determine chunk ratio for current node
+ if chunk_within:
+ chunk_ratio = self._get_chunk_ratio(node, chunk_node_dim[chunk_region_idx],
+ chunk_sizes[chunk_region_idx])
+
+ # add current node as active node
+ self._add_active_node(node, active_nodes, chunk_ratio)
+ act_memory = self._get_memory_from_active_nodes(active_nodes)
+
+ # if node is placeholder, just add the size of the node
+ if node.op == "placeholder":
+ act_memory_peak_log.append(act_memory)
+ # skip output
+ elif node.op == "output":
+ continue
+ # no change for non compute node
+ elif is_non_memory_node(node):
+ act_memory_peak_log.append(act_memory)
+ # node is a compute op, calculate tmp
+ else:
+ # forward memory
+ # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
+ tmp_memory = self._get_tmp_memory(node, not_contiguous_list, delete=True) * chunk_ratio
+ # record max act memory
+ act_memory_peak_log.append(act_memory + tmp_memory)
+
+ # remove_deactive_node
+ self._remove_deactive_node(idx, node, active_nodes, delete_node_dict, kept_nodes=chunk_inputs_all)
+
+ # if node in chunk end nodes, restore chunk settings
+ if use_chunk and idx in chunk_ends:
+ self._remove_deactive_node(idx, node, active_nodes, delete_node_dict) # dont provide kept nodes now
+ chunk_within = False
+ chunk_ratio = 1
+ chunk_region_idx = None
+
+ act_memory = self._get_memory_from_active_nodes(active_nodes)
+ act_memory_after_node_log.append(act_memory)
+ active_nodes_log.append(active_nodes.copy())
+
+ if print_mem:
+ print("with chunk" if use_chunk else "without chunk")
+ self._print_compute_op_mem_log(act_memory_peak_log, node_mgr.get_node_list(), "peak")
+
+ # param_memory = parameter_size(gm)
+ # all_memory = act_memory + param_memory
+ return act_memory_peak_log, act_memory_after_node_log, active_nodes_log
diff --git a/colossalai/autochunk/reorder_graph.py b/colossalai/autochunk/reorder_graph.py
new file mode 100644
index 000000000000..3b00d47fb955
--- /dev/null
+++ b/colossalai/autochunk/reorder_graph.py
@@ -0,0 +1,111 @@
+from .trace_indice import TraceIndice
+from .utils import NodeMgr
+
+
+class ReorderGraph(object):
+ """
+ Reorder node list and indice trace list
+ """
+
+ def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
+ self.trace_indice = trace_indice
+ self.node_mgr = node_mgr
+ self.all_reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
+
+ def _get_reorder_map(self, chunk_info):
+ reorder_map = {i: i for i in range(len(self.node_mgr.get_node_list()))}
+
+ chunk_region_start = chunk_info["region"][0]
+ chunk_region_end = chunk_info["region"][1]
+ chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
+ chunk_prepose_nodes_idx = [self.node_mgr.find_node_idx(i) for i in chunk_prepose_nodes]
+ # put prepose nodes ahead
+ for idx, n in enumerate(chunk_prepose_nodes):
+ n_idx = chunk_prepose_nodes_idx[idx]
+ reorder_map[n_idx] = chunk_region_start + idx
+ # put other nodes after prepose nodes
+ for n in self.node_mgr.get_node_slice_by_idx(chunk_region_start, chunk_region_end + 1):
+ if n in chunk_prepose_nodes:
+ continue
+ n_idx = self.node_mgr.find_node_idx(n)
+ pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
+ reorder_map[n_idx] = n_idx + pos
+
+ return reorder_map
+
+ def _reorder_chunk_info(self, chunk_info, reorder_map):
+ # update chunk info
+ chunk_info["region"] = (
+ chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
+ chunk_info["region"][1],
+ )
+ new_inputs_dim = []
+ for _, input_dim in enumerate(chunk_info["inputs_dim"]):
+ new_input_dim = {}
+ for k, v in input_dim.items():
+ new_input_dim[reorder_map[k]] = v
+ new_inputs_dim.append(new_input_dim)
+ chunk_info["inputs_dim"] = new_inputs_dim
+ return chunk_info
+
+ def _update_all_reorder_map(self, reorder_map):
+ for origin_idx, map_idx in self.all_reorder_map.items():
+ self.all_reorder_map[origin_idx] = reorder_map[map_idx]
+
+ def _reorder_self_node_list(self, reorder_map):
+ new_node_list = [None for _ in range(len(self.node_mgr.get_node_list()))]
+ for old_idx, new_idx in reorder_map.items():
+ new_node_list[new_idx] = self.node_mgr.get_node_by_idx(old_idx)
+ self.node_mgr.update_node_list(new_node_list)
+
+ def _reorder_idx_trace(self, reorder_map):
+ # reorder list
+ new_idx_trace_list = [None for _ in range(len(self.trace_indice.indice_trace_list))]
+ for old_idx, new_idx in reorder_map.items():
+ new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
+ self.trace_indice.indice_trace_list = new_idx_trace_list
+ # update compute
+ for idx_trace in self.trace_indice.indice_trace_list:
+ compute = idx_trace["compute"]
+ for dim_compute in compute:
+ for idx, i in enumerate(dim_compute):
+ dim_compute[idx] = reorder_map[i]
+ # update source
+ for idx_trace in self.trace_indice.indice_trace_list:
+ source = idx_trace["source"]
+ for dim_idx, dim_source in enumerate(source):
+ new_dim_source = {}
+ for k, v in dim_source.items():
+ new_dim_source[reorder_map[k]] = v
+ source[dim_idx] = new_dim_source
+
+ def reorder_all(self, chunk_info):
+ if chunk_info is None:
+ return chunk_info
+ if len(chunk_info["args"]["prepose_nodes"]) == 0:
+ return chunk_info
+ reorder_map = self._get_reorder_map(chunk_info)
+ self._update_all_reorder_map(reorder_map)
+ self._reorder_idx_trace(reorder_map)
+ self._reorder_self_node_list(reorder_map)
+ chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
+ return chunk_info
+
+ def reorder_node_list(self, node_list):
+ new_node_list = [None for _ in range(len(node_list))]
+ for old_idx, new_idx in self.all_reorder_map.items():
+ new_node_list[new_idx] = node_list[old_idx]
+ return new_node_list
+
+ def tmp_reorder(self, node_list, chunk_info):
+ if len(chunk_info["args"]["prepose_nodes"]) == 0:
+ return node_list, chunk_info
+ reorder_map = self._get_reorder_map(chunk_info)
+
+ # new tmp node list
+ new_node_list = [None for _ in range(len(node_list))]
+ for old_idx, new_idx in reorder_map.items():
+ new_node_list[new_idx] = node_list[old_idx]
+
+ chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
+ return new_node_list, chunk_info
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
new file mode 100644
index 000000000000..326445ee9f12
--- /dev/null
+++ b/colossalai/autochunk/search_chunk.py
@@ -0,0 +1,293 @@
+import copy
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
+from .estimate_memory import EstimateMemory
+from .reorder_graph import ReorderGraph
+from .select_chunk import SelectChunk
+from .trace_flow import TraceFlow
+from .trace_indice import TraceIndice
+from .utils import NodeMgr, get_logger, get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
+
+
+class SearchChunk(object):
+ """
+ This is the core class for AutoChunk.
+
+ It defines the framework of the strategy of AutoChunk.
+ Chunks will be selected one by one utill search stops.
+
+ The chunk search is as follows:
+ 1. find the peak memory node
+ 2. find the max chunk region according to the peak memory node
+ 3. find all possible chunk regions in the max chunk region
+ 4. find the best chunk region for current status
+ 5. goto 1
+
+ Attributes:
+ gm: graph model
+ print_mem (bool): print estimated memory
+ trace_index: trace the flow of every dim of every node to find all free dims
+ trace_flow: determine the region chunk strategy
+ reorder_graph: reorder nodes to improve chunk efficiency
+ estimate_memory: estimate memory with chunk
+ select_chunk: select the best chunk region
+
+ Args:
+ gm: graph model
+ max_memory (int): max memory in MB
+ print_mem (bool): print estimated memory
+ """
+
+ def __init__(self, gm, max_memory=None, print_mem=False, print_progress=False) -> None:
+ self.print_mem = print_mem
+ self.max_memory = max_memory
+ self.print_progress = print_progress
+ self.node_mgr = NodeMgr(list(gm.graph.nodes))
+ self.trace_indice = TraceIndice(self.node_mgr)
+ self.estimate_memory = EstimateMemory()
+ self._init_trace()
+ self.trace_flow = TraceFlow(self.trace_indice, self.node_mgr)
+ self.reorder_graph = ReorderGraph(self.trace_indice, self.node_mgr)
+ self.select_chunk = SelectChunk(
+ self.trace_indice,
+ self.estimate_memory,
+ self.reorder_graph,
+ self.node_mgr,
+ max_memory=max_memory,
+ )
+
+ def _init_trace(self) -> None:
+ """
+ find the max trace range for every node
+ reduce the computation complexity of trace_indice
+ """
+ # find all max ranges
+ active_nodes = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())[2]
+ # set trace range and do the trace
+ if self.print_progress:
+ get_logger().info("AutoChunk start tracing indice")
+ self.trace_indice.set_active_nodes(active_nodes)
+ self.trace_indice.trace_indice()
+
+ def _find_peak_region(self, mem_peak: List) -> int:
+ """
+ find peak node, along with its neighbour nodes exceeds max mem
+ """
+ max_value = max(mem_peak)
+ max_idx = mem_peak.index(max_value)
+ peak_region = [max_idx, max_idx]
+ if self.max_memory is None:
+ return peak_region
+
+ # to left
+ count = 0
+ for i in range(max_idx - 1, -1, -1):
+ if mem_peak[i] > self.max_memory:
+ peak_region[0] = i
+ else:
+ count += 1
+ if count >= 3:
+ break
+ # to right
+ count = 0
+ for i in range(max_idx + 1, len(mem_peak) - 1):
+ if mem_peak[i] > self.max_memory:
+ peak_region[1] = i
+ count = 0
+ else:
+ count += 1
+ if count >= 3:
+ break
+
+ return peak_region
+
+ def _search_max_chunk_region(self, active_node: List, peak_region: int, chunk_regions: List = None) -> Tuple:
+ """
+ Search max chunk region according to peak memory node
+
+ Chunk region starts extending from the peak node, stops where free var num is min
+
+ Args:
+ active_node (List): active node status for every node
+ peak_node_idx (int): peak memory node idx
+ chunk_regions (List): chunk region infos
+
+ Returns:
+ chunk_region_start (int)
+ chunk_region_end (int)
+ """
+ # check if peak node already in chunkinfo
+ if chunk_regions is not None:
+ for i in chunk_regions:
+ if i["region"][0] < peak_region[0] <= i["region"][1] or \
+ i["region"][0] < peak_region[1] <= i["region"][1]:
+ return None
+
+ active_node_num = [len(i) for i in active_node]
+ window_size = 100
+ # search min for start
+ min_num = 1e4
+ for i in range(peak_region[0], max(peak_region[0] - window_size, -1), -1):
+ if active_node_num[i] < min_num:
+ min_num = active_node_num[i]
+ chunk_region_start = i
+ # search min for end
+ min_num = 1e4
+ for i in range(peak_region[1], min(peak_region[1] + window_size, len(active_node_num))):
+ if active_node_num[i] < min_num:
+ min_num = active_node_num[i]
+ chunk_region_end = i
+
+ # avoid chunk regions overlap
+ if chunk_regions is not None:
+ for i in chunk_regions:
+ region = i["region"]
+ if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
+ return None
+ elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
+ chunk_region_start = region[1] + 1
+ elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
+ chunk_region_end = region[0] - 1
+ return chunk_region_start, chunk_region_end
+
+ def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
+ """
+ Find chunk info for a region.
+
+ We are given the region start and region end, and need to find out all chunk info for it.
+ We first loop every dim of start node and end node, to see if we can find dim pair,
+ which is linked in a flow and not computed.
+ If found, we then search flow in the whole region to find out all chunk infos.
+
+ Args:
+ input_trace (List): node's input trace in region
+ output_trace (List): node's output trace in region
+ start_idx (int): region start node index
+ end_idx (int): region end node index
+
+ Returns:
+ chunk_infos: possible regions found
+ """
+ start_traces = input_trace[start_idx]
+ if len(start_traces) > 1: # TODO need to be removed
+ return []
+ end_trace = output_trace[end_idx]
+ end_node = self.node_mgr.get_node_by_idx(end_idx)
+
+ chunk_infos = []
+ for end_dim, _ in enumerate(end_trace["indice"]):
+ for start_node, start_trace in start_traces.items():
+ for start_dim, _ in enumerate(start_trace["indice"]):
+ if not self.trace_flow.check_region_start_end(start_node, start_dim, start_idx, end_node, end_dim,
+ end_idx):
+ continue
+ # flow search
+ chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
+ if chunk_info is None:
+ continue
+ chunk_infos.append(chunk_info)
+ return chunk_infos
+
+ def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_region: Node) -> List:
+ """
+ Search every possible region within the max chunk region.
+
+ Args:
+ max_chunk_region (Tuple)
+ peak_node (Node): peak memory node
+
+ Returns:
+ possible_chunk_region (List)
+ """
+ possible_chunk_region = []
+ output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
+ input_trace = [] # trace of a node's input nodes
+ for _, n in enumerate(self.node_mgr.get_node_list()):
+ cur_trace = {}
+ for arg in n.args:
+ if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
+ cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
+ input_trace.append(cur_trace)
+
+ for start_idx in range(max_chunk_region[0], peak_region[0] + 1):
+ for end_idx in range(peak_region[1], max_chunk_region[1] + 1):
+ # skip non compute nodes
+ if is_non_compute_node(self.node_mgr.get_node_by_idx(start_idx)) or is_non_compute_node(
+ self.node_mgr.get_node_by_idx(end_idx)):
+ continue
+ # select free dim
+ chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
+ if len(chunk_info) > 0:
+ possible_chunk_region.extend(chunk_info)
+ return possible_chunk_region
+
+ def _step_search(
+ self,
+ mem_peak: List[float],
+ active_node: List[List[Node]],
+ chunk_infos: List[Dict],
+ ) -> Dict:
+ """
+ Find one chunk region
+
+ The chunk search is as follows:
+ 1. find the peak memory node
+ 2. find the max chunk region according to the peak memory node
+ 3. find all possible chunk regions in the max chunk region
+ 4. find the best chunk region for current status
+
+ Args:
+ mem_peak (List): peak memory for every node
+ active_node (List[List[Node]]): active node for every node
+ chunk_infos (List[Dict]): all chunk info
+
+ Returns:
+ best_chunk_region (Dict)
+ """
+ peak_region = self._find_peak_region(mem_peak)
+ max_chunk_region = self._search_max_chunk_region(active_node, peak_region, chunk_infos)
+ if max_chunk_region == None:
+ return None
+ possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_region)
+ best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
+ best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
+ return best_chunk_region
+
+ def search_region(self) -> Dict:
+ """
+ Search all chunk regions:
+ 1. Estimate current memory
+ 2. Find best chunk for current memory
+ 3. goto 1
+
+ Returns:
+ chunk_infos (Dict)
+ """
+ if self.print_progress:
+ get_logger().info("AutoChunk start searching chunk regions")
+
+ chunk_infos = []
+ init_mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list())
+ mem_peak = init_mem_peak
+
+ while True:
+ chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
+ if chunk_info is None:
+ break
+ chunk_infos.append(chunk_info)
+
+ mem_peak, _, active_node = self.estimate_memory.estimate_chunk_inference_mem(
+ self.node_mgr.get_node_list(), chunk_infos)
+
+ if self.print_progress:
+ get_logger().info("AutoChunk find chunk region %d = (%d, %d)" %
+ (len(chunk_infos), chunk_info["region"][0], chunk_info["region"][1]))
+
+ if self.print_mem:
+ self.print_mem = False
+ self.estimate_memory.estimate_chunk_inference_mem(self.node_mgr.get_node_list(),
+ chunk_infos,
+ print_mem=True)
+ return chunk_infos
diff --git a/colossalai/autochunk/select_chunk.py b/colossalai/autochunk/select_chunk.py
new file mode 100644
index 000000000000..94a29bfd5691
--- /dev/null
+++ b/colossalai/autochunk/select_chunk.py
@@ -0,0 +1,181 @@
+from .estimate_memory import EstimateMemory
+from .reorder_graph import ReorderGraph
+from .trace_indice import TraceIndice
+from .utils import NodeMgr, is_non_compute_node
+
+
+class SelectChunk(object):
+
+ def __init__(
+ self,
+ trace_indice: TraceIndice,
+ estimate_memory: EstimateMemory,
+ reorder_graph: ReorderGraph,
+ node_mgr: NodeMgr,
+ max_memory=None,
+ ):
+ self.trace_indice = trace_indice
+ self.estimate_memory = estimate_memory
+ self.reorder_graph = reorder_graph
+ self.node_mgr = node_mgr
+ if max_memory is not None:
+ self.stratge = "fit_memory"
+ self.max_memory = max_memory # MB
+ else:
+ self.stratge = "min_memory"
+
+ def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
+ if self.stratge == "min_memory":
+ best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
+ elif self.stratge == "fit_memory":
+ best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak)
+ else:
+ raise RuntimeError()
+ return best_region
+
+ def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak):
+ # stop chunk if max memory satisfy memory limit
+ if max(mem_peak) < self.max_memory:
+ return None
+
+ # remove illegal regions
+ illegal_regions = []
+ for i in possible_chunk_regions:
+ if not self._is_legal_region(i, chunk_infos):
+ illegal_regions.append(i)
+ for i in illegal_regions:
+ if i in possible_chunk_regions:
+ possible_chunk_regions.remove(i)
+
+ if len(possible_chunk_regions) == 0:
+ return None
+
+ # get mem for chunk region
+ regions_dict = []
+ for region in possible_chunk_regions:
+ cur_region = region.copy()
+ cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
+ cur_chunk_infos = chunk_infos + [cur_region]
+ cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+ cur_chunk_region_peak = cur_mem[cur_region["region"][0]:cur_region["region"][1] + 1]
+ cur_chunk_region_max_peak = max(cur_chunk_region_peak)
+ if cur_chunk_region_max_peak < self.max_memory:
+ regions_dict.append({
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ })
+ # no region found
+ if len(regions_dict) == 0:
+ raise RuntimeError("Search failed. Try a larger memory threshold.")
+
+ # select the min chunk len
+ chunk_len = [i["chunk_len"] for i in regions_dict]
+ best_region_idx = chunk_len.index(min(chunk_len))
+ best_region = regions_dict[best_region_idx]
+
+ # get max chunk size
+ best_region = self._get_fit_chunk_size(best_region, chunk_infos)
+ return best_region
+
+ def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
+ chunk_size = 1
+ reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
+ reorder_chunk_info["chunk_size"] = chunk_size
+ cur_chunk_max_mem = 0
+ # search a region
+ while cur_chunk_max_mem < self.max_memory:
+ chunk_size *= 2
+ reorder_chunk_info["chunk_size"] = chunk_size
+ cur_chunk_infos = chunk_infos + [reorder_chunk_info]
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+ cur_chunk_infos)[0]
+ cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0]:reorder_chunk_info["region"][1] + 1])
+ # search exact size
+ chunk_info = chunk_region_dict["chunk_info"]
+ chunk_info["chunk_size"] = self._chunk_size_binary_search(chunk_size // 2, chunk_size, chunk_region_dict,
+ chunk_infos)
+ return chunk_info
+
+ def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
+ if left >= 16:
+ gap = 4
+ else:
+ gap = 1
+ chunk_info = chunk_region_dict["reorder_chunk_info"]
+ while right >= left + gap:
+ mid = int((left + right) / 2 + 0.5)
+ chunk_info["chunk_size"] = mid
+ cur_chunk_infos = chunk_infos + [chunk_info]
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(chunk_region_dict["reorder_node_list"],
+ cur_chunk_infos)[0]
+ cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0]:chunk_info["region"][1] + 1])
+ if cur_chunk_max_mem >= self.max_memory:
+ right = mid - gap
+ else:
+ left = mid + gap
+ return left
+
+ def _get_compute_node_num(self, start, end):
+ count = 0
+ for i in self.node_mgr.get_node_slice_by_idx(start, end + 1):
+ if not is_non_compute_node(i):
+ count += 1
+ return count
+
+ def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
+ # remove illegal regions
+ illegal_regions = []
+ for i in possible_chunk_regions:
+ if not self._is_legal_region(i, chunk_infos):
+ illegal_regions.append(i)
+ for i in illegal_regions:
+ if i in possible_chunk_regions:
+ possible_chunk_regions.remove(i)
+
+ if len(possible_chunk_regions) == 0:
+ return None
+
+ # get max possible chunk region
+ max_possible_chunk_region = (min([i["region"][0] for i in possible_chunk_regions]),
+ max([i["region"][1] for i in possible_chunk_regions]))
+
+ # get mem for chunk region
+ regions_dict_list = []
+ for region in possible_chunk_regions:
+ cur_region = region.copy()
+ cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region)
+ cur_chunk_infos = chunk_infos + [cur_region]
+ cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0]
+ cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0]:max_possible_chunk_region[1] + 1]
+ cur_chunk_region_max_peak = max(cur_chunk_region_peak)
+ regions_dict_list.append({
+ "chunk_info": region,
+ "chunk_max_mem": cur_chunk_region_max_peak,
+ "chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]),
+ "reorder_chunk_info": cur_region,
+ "reorder_node_list": cur_node_list,
+ })
+
+ # select the min mem
+ chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list]
+ best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
+ best_region = regions_dict_list[best_region_idx]["chunk_info"]
+ if best_region is not None:
+ best_region["chunk_size"] = 1
+ return best_region
+
+ def _is_legal_region(self, cur_chunk_info, chunk_infos):
+ (chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
+ if cur_chunk_info in chunk_infos:
+ return False
+ if chunk_region_end < chunk_region_start:
+ return False
+ for i in chunk_infos:
+ region = i["region"]
+ if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) or
+ (chunk_region_start < region[0] and chunk_region_end < region[0])):
+ return False
+ return True
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
new file mode 100644
index 000000000000..16815215f52b
--- /dev/null
+++ b/colossalai/autochunk/trace_flow.py
@@ -0,0 +1,485 @@
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
+from .trace_indice import TraceIndice
+from .utils import (
+ NodeMgr,
+ find_chunk_all_input_nodes,
+ find_chunk_compute_input_and_output_nodes,
+ find_tensor_shape_node,
+ flat_list,
+ get_node_name,
+ get_node_shape,
+ is_non_compute_node,
+)
+
+
+class TraceFlow(object):
+
+ def __init__(self, trace_indice: TraceIndice, node_mgr: NodeMgr) -> None:
+ self.trace_indice = trace_indice
+ self.node_mgr = node_mgr
+
+ def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
+ """
+ Check 2 given index: one index should be source of the other
+ Args:
+ start_idx(int): start node chunk dim
+ start_node(node): start node
+ end_idx(int): end node chunk dim
+ end_node(node): end node
+
+ Returns:
+ bool: True if check pass
+ """
+ # we use start_node_idx instead of real chunk index
+ start_node_idx = self.node_mgr.find_node_idx(start_node)
+ end_node_trace = self.trace_indice._find_trace_from_node(end_node)
+ end_node_trace_source = end_node_trace["source"][end_dim]
+ sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
+ for node_idx, node_dim in sorted_source:
+ if node_idx == start_node_idx and start_dim in node_dim:
+ return True
+ # it means we meet a node outside the loop, and the node is not input node
+ if node_idx < start_node_idx:
+ return False
+ return False
+
+ def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
+ """
+ Check 2 given index: check they haven't been computed in the source trace.
+ Args:
+ start_idx(int): start node chunk dim
+ start_node(node): start node
+ end_idx(int): end node chunk dim
+ end_node(node): end node
+
+ Returns:
+ bool: True if check pass
+ """
+ end_node_trace = self.trace_indice._find_trace_from_node(end_node)
+ end_node_compute = end_node_trace["compute"][end_dim]
+ if any(start_idx <= i <= end_idx for i in end_node_compute):
+ return False
+ return True
+
+ def _assgin_single_node_flow(
+ self,
+ arg_node: Node,
+ start_idx: int,
+ end_idx: int,
+ cur_node: Node,
+ cur_node_dim: int,
+ cur_node_compute: Dict,
+ cur_node_source: Dict,
+ cur_node_fix_dim: List,
+ all_node_info: Dict,
+ next_node_list: List,
+ ) -> bool:
+ """
+ Given the current node and one of its arg node,
+ this function finds out arg node's chunk dim and fix dim
+
+ Args:
+ arg_node (Node): input node
+ start_idx (int): chunk region start
+ end_idx (int): chunk region end
+ cur_node_dim (int): current node chunk dim
+ cur_node_compute (Dict): current node compute dict
+ cur_node_source (Dict): current node source dict
+ cur_node_fix_dim (List): current node fix dim
+ all_node_info (Dict): all node chunk info in the chunk region
+ next_node_list (List)
+
+ Returns:
+ bool: True if this node can be added to the flow, vice versa.
+ """
+ arg_idx = self.node_mgr.find_node_idx(arg_node)
+ # arg in chunk range or be inputs
+ if not (start_idx <= arg_idx < end_idx):
+ return True
+
+ # get fix dim
+ arg_fix_dim = []
+ if cur_node_dim is not None:
+ for i in cur_node_fix_dim:
+ fix_dim_source = cur_node_source[i]
+ if arg_idx in fix_dim_source:
+ arg_fix_dim.append(fix_dim_source[arg_idx][0])
+ if arg_node in all_node_info:
+ arg_fix_dim = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
+
+ # find arg dim
+ if cur_node_dim is not None:
+ # dim is computed
+ if arg_idx in cur_node_compute[cur_node_dim]:
+ return False
+ if arg_idx not in cur_node_source[cur_node_dim]:
+ arg_dim = None
+ else:
+ arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
+ # chunk dim cannot be in fix dims
+ if arg_dim in arg_fix_dim:
+ return False
+ # chunk dim should be None if shape size is 1
+ if get_node_shape(arg_node)[arg_dim] == 1:
+ arg_dim = None
+ # chunk shape should equal cur node
+ elif get_node_shape(arg_node)[arg_dim] != 1:
+ if cur_node_dim is not None and get_node_shape(cur_node)[cur_node_dim] != 1:
+ if get_node_shape(arg_node)[arg_dim] != get_node_shape(cur_node)[cur_node_dim]:
+ return False
+ else:
+ arg_dim = None
+
+ # add arg rest dim as fix dim
+ arg_fix_dim = list(range(len(get_node_shape(arg_node))))
+ if arg_dim is not None:
+ arg_fix_dim.remove(arg_dim)
+
+ # if already in node_info, arg dim must be same
+ if arg_node in all_node_info:
+ if all_node_info[arg_node]["chunk_dim"] != arg_dim:
+ return False
+ all_node_info[arg_node]["fix_dim"] = arg_fix_dim
+ # else add it to list
+ else:
+ all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
+
+ next_node_list.append(arg_node)
+ return True
+
+ def _get_all_node_info(self, end_dim, start_idx, end_idx):
+ cur_node_list = [self.node_mgr.get_node_by_idx(end_idx)] # start from the last node
+ all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
+
+ while len(cur_node_list) > 0:
+ next_node_list = []
+
+ for cur_node in cur_node_list:
+ # get cur node info
+ cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
+ cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
+ if cur_node_chunk_dim is not None:
+ cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
+ cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
+ else:
+ cur_node_compute = cur_node_source = None
+
+ # get all valid args
+ arg_list = []
+ for arg in cur_node.all_input_nodes:
+ if type(arg) != type(cur_node):
+ continue
+ if is_non_compute_node(arg):
+ continue
+ if get_node_shape(arg) is None:
+ continue
+ arg_list.append(arg)
+ flow_flag = self._assgin_single_node_flow(
+ arg,
+ start_idx,
+ end_idx,
+ cur_node,
+ cur_node_chunk_dim,
+ cur_node_compute,
+ cur_node_source,
+ cur_node_fix_dim,
+ all_node_info,
+ next_node_list,
+ )
+ if flow_flag == False:
+ return None
+
+ cur_node_list = next_node_list
+ return all_node_info
+
+ def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+ """
+ Get chunk dim for every input node for their every entry, remove unchunked nodes
+
+ Args:
+ inputs (List[Node]): input nodes
+ all_node_info (Dict): describe all node's chunk dim and fix dim
+ start_idx (int): chunk start idx
+ end_idx (int): chunk end idx
+
+ Returns:
+ inputs (List(Node)): new inputs
+ inputs_dim (List): chunk dim for inputs
+ """
+ inputs_dim = []
+ remove_inputs = []
+ for input_node in inputs:
+ input_dict = {}
+ input_node_idx = self.node_mgr.find_node_idx(input_node)
+ for user in input_node.users.keys():
+ # skip non compute
+ if is_non_compute_node(user):
+ continue
+ # untraced node, mostly non compute
+ if user not in all_node_info:
+ continue
+ user_idx = self.node_mgr.find_node_idx(user)
+ if start_idx <= user_idx <= end_idx:
+ chunk_dim = all_node_info[user]["chunk_dim"]
+ if chunk_dim is not None:
+ user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
+ if input_node_idx in user_source:
+ if get_node_shape(input_node)[user_source[input_node_idx][0]] == 1:
+ input_dict[user_idx] = [None]
+ else:
+ input_dict[user_idx] = user_source[input_node_idx]
+ else:
+ return None, None
+ if len(input_dict) == 0:
+ remove_inputs.append(input_node)
+ else:
+ inputs_dim.append(input_dict)
+ # remove unchunked inputs
+ for i in remove_inputs:
+ if i in inputs:
+ inputs.remove(i)
+ return inputs, inputs_dim
+
+ def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int, chunk_info) -> List[Node]:
+ """
+ get all useless nodes in chunk region and prepose them
+
+ Args:
+ all_node_info (Dict): describe all node's chunk dim and fix dim
+ start_idx (int): chunk start idx
+ end_idx (int): chunk end idx
+
+ Returns:
+ List[Node]: all nodes to be preposed
+ """
+ # get all possible prepose nodes
+ maybe_prepose_nodes = []
+ for node, node_info in all_node_info.items():
+ if node_info["chunk_dim"] is None:
+ maybe_prepose_nodes.append(node)
+ for node in self.node_mgr.get_node_slice_by_idx(start_idx, end_idx):
+ if node not in all_node_info and node not in chunk_info["outputs"]:
+ maybe_prepose_nodes.append(node)
+ maybe_prepose_nodes.sort(
+ key=lambda x: self.node_mgr.find_node_idx(x),
+ reverse=True,
+ ) # from last node to first node
+ prepose_nodes = []
+ # set every node as root, search its args, if all legal, turn root and args as prepose nodes
+ while len(maybe_prepose_nodes) > 0:
+ tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
+ tmp_cur_related_prepose_nodes = []
+ prepose_flag = True
+
+ # loop cur node's all arg until out of chunk
+ while len(tmp_cur_prepose_nodes) > 0:
+ if prepose_flag == False:
+ break
+ tmp_next_prepose_nodes = []
+ tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
+ for cur_prepose_node in tmp_cur_prepose_nodes:
+ if prepose_flag == False:
+ break
+ for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
+ if type(cur_prepose_node_arg) != type(cur_prepose_node):
+ continue
+ # out of loop
+ if not (start_idx <= self.node_mgr.find_node_idx(cur_prepose_node_arg) < end_idx):
+ continue
+ # compute op in loop
+ elif cur_prepose_node_arg in all_node_info:
+ if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
+ tmp_next_prepose_nodes.append(cur_prepose_node_arg)
+ else:
+ prepose_flag = False
+ break
+ # non compute op
+ else:
+ tmp_next_prepose_nodes.append(cur_prepose_node_arg)
+ tmp_cur_prepose_nodes = tmp_next_prepose_nodes
+
+ if prepose_flag == False:
+ maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
+ continue
+ else:
+ for n in tmp_cur_related_prepose_nodes:
+ if n not in prepose_nodes:
+ prepose_nodes.append(n)
+ if n in maybe_prepose_nodes:
+ maybe_prepose_nodes.remove(n)
+ # sort by index
+ prepose_nodes.sort(key=lambda x: self.node_mgr.find_node_idx(x))
+ chunk_info["args"]["prepose_nodes"] = prepose_nodes
+
+ def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
+ # we need to log input nodes to avoid deleteing them in the loop
+ chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
+ # also need to get some prepose node's arg out of non_chunk_inputs
+ for n in chunk_info["args"]["prepose_nodes"]:
+ chunk_node_list.remove(n)
+ non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
+ for i in non_chunk_inputs:
+ if i not in chunk_info["inputs"]:
+ chunk_info["inputs_non_chunk"].append(i)
+ return chunk_info
+
+ def flow_search(self, start_idx, start_dim, end_idx, end_dim):
+ inputs, outputs = find_chunk_compute_input_and_output_nodes(
+ self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1))
+
+ # get every node's chunk dim and fix dim
+ all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
+ if all_node_info is None:
+ return None
+
+ chunk_info = {
+ "region": (start_idx, end_idx),
+ "inputs": [],
+ "inputs_non_chunk": [],
+ "inputs_dim": [],
+ "outputs": [self.node_mgr.get_node_by_idx(end_idx)],
+ "outputs_non_tensor": {},
+ "outputs_dim": [end_dim],
+ "node_chunk_dim": all_node_info,
+ "args": {},
+ }
+
+ # find chunk info for other outputs
+ if len(find_tensor_shape_node(outputs)) > 1:
+ chunk_info = self._get_other_output_info(outputs, start_idx, start_dim, end_idx, end_dim, chunk_info)
+ if chunk_info is None:
+ return None
+
+ # get input nodes' chunk dim
+ inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
+ if inputs is None:
+ return None
+ chunk_info["inputs"] = inputs
+ chunk_info["inputs_dim"] = inputs_dim
+
+ # move useless nodes ahead of loop
+ self._get_prepose_nodes(all_node_info, start_idx, end_idx, chunk_info)
+
+ # find non chunk inputs
+ chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
+
+ # reassgin reshape size, some size may have changed due to chunk
+ chunk_info = self._reassgin_reshape_size(chunk_info)
+
+ return chunk_info
+
+ def _get_other_output_info(self, outputs: List[Node], start_idx: int, start_dim: int, end_idx: int, end_dim: int,
+ chunk_info: Dict):
+ start_node = self.node_mgr.get_node_by_idx(start_idx)
+ # loop all outputs
+ for output in outputs:
+ output_legal = False
+ output_idx = self.node_mgr.find_node_idx(output)
+ # skip the origin output
+ if output_idx == end_idx:
+ continue
+ # skip non tensor
+ if get_node_shape(output) is None:
+ # log shape tensor
+ if len(output.meta['fwd_out']) > 0 and isinstance(output.meta['fwd_out'][0], int):
+ chunk_info["outputs_non_tensor"][output] = str(output.meta['fwd_out'])
+ continue
+ # loop every dim of outputs, try to find a legal one
+ for output_dim in range(len(get_node_shape(output))):
+ if not self.check_region_start_end(start_node, start_dim, start_idx, output, output_dim, output_idx):
+ continue
+ new_all_node_info = self._get_all_node_info(output_dim, start_idx, output_idx)
+ if new_all_node_info is None:
+ continue
+ # check node info legal
+ if self._update_chunk_info(chunk_info, new_all_node_info, output, output_dim) == True:
+ output_legal = True
+ break
+ # not legal
+ if output_legal == False:
+ return None
+ return chunk_info
+
+ def _update_chunk_info(self, chunk_info: Dict, new_all_node_info: Dict, output: Node, output_dim: int) -> bool:
+ """
+ check if there is conflict between new node info and old chunk info. If not, update old chunk info
+ """
+ # check if conflict
+ overlap_flag = False
+ for k, v in new_all_node_info.items():
+ if k in chunk_info["node_chunk_dim"]:
+ overlap_flag = True
+ if chunk_info["node_chunk_dim"][k]["chunk_dim"] != v["chunk_dim"]:
+ return False
+ # if no overlap, we just consider them as prepose nodes, instead of new output
+ if overlap_flag == False:
+ return True
+ # update chunk info
+ for k, v in new_all_node_info.items():
+ if k in chunk_info["node_chunk_dim"]:
+ chunk_info["node_chunk_dim"][k]["fix_dim"] = list(
+ set(chunk_info["node_chunk_dim"][k]["fix_dim"] + v["fix_dim"]))
+ else:
+ chunk_info["node_chunk_dim"][k] = v
+ chunk_info["outputs"].append(output)
+ chunk_info["outputs_dim"].append(output_dim)
+ return True
+
+ def _reassgin_reshape_size(self, chunk_info):
+ """
+ Some shape args in reshape may have changed due to chunk
+ reassgin those changed shape
+ """
+ chunk_region = chunk_info["region"]
+ reshape_size = {}
+ chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"][0]]
+ for node in self.node_mgr.get_node_slice_by_idx(chunk_region[0], chunk_region[1] + 1):
+ if any(i == get_node_name(node) for i in ["reshape", "view"]):
+ if node in chunk_info["args"]["prepose_nodes"]:
+ continue
+ if node.args[0] in chunk_info["inputs_non_chunk"]:
+ continue
+ reshape_args = flat_list(node.args[1:])
+ if len(reshape_args) == 1 and get_node_shape(reshape_args[0]) is None and len(
+ reshape_args[0].meta['fwd_out']) > 1:
+ continue
+ chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
+ new_shape = ""
+ for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
+ if reshape_arg_dim == chunk_dim:
+ new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
+ else:
+ if isinstance(reshape_arg, int):
+ new_shape += "%s, " % str(reshape_arg)
+ else:
+ new_shape += "%s, " % reshape_arg.name
+ new_shape = new_shape[:-2]
+ origin_shape = str(reshape_args)[1:-1]
+ reshape_size[node.name] = [origin_shape, new_shape]
+ chunk_info["reshape_size"] = reshape_size
+ return chunk_info
+
+ def check_region_start_end(self, start_node: Node, start_dim: int, start_idx: int, end_node: Node, end_dim: int,
+ end_idx: int) -> bool:
+ """
+ check if region start and end is legal
+ """
+ # dim cannot be None
+ if (get_node_shape(end_node) is None or get_node_shape(start_node) is None):
+ return False
+ # dim size cannot be 1
+ if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
+ return False
+ # must have users
+ if len(end_node.users) == 0:
+ return False
+ # check index source align
+ if not self.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
+ return False
+ # check index copmute
+ if not self.check_index_compute(start_idx, end_dim, end_node, end_idx):
+ return False
+ return True
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
new file mode 100644
index 000000000000..307f4de326d7
--- /dev/null
+++ b/colossalai/autochunk/trace_indice.py
@@ -0,0 +1,930 @@
+import copy
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
+from .utils import NodeMgr, find_first_tensor_arg, flat_list, get_module_node_name, get_node_name, get_node_shape
+
+
+class TraceIndice(object):
+ """
+ Trace all indice infomation for every node.
+
+ Indice is a logical concept. Equal dims can been treated as one indice.
+ eg. dim(x1) = [a, b, c]
+ dim(x2) = [d, e, f]
+ and we have x3 = x1 * x2.
+ then a=d, b=e, c=f, due to the broadcast property,
+ dim(x1)=dim(x2)=dim(x3)=[a, b, c]
+ This class will record every node's dims' indice, compute and source.
+
+ Attibutes:
+ node_list (List)
+ indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
+ indice_view_list (Dict): not used for now
+ indice_count (int): record indice number
+
+ Args:
+ node_list (List)
+ """
+
+ def __init__(self, node_mgr: NodeMgr) -> None:
+ self.node_mgr = node_mgr
+ self.indice_trace_list = self._init_indice_trace_list()
+ self.indice_view_list = {}
+ self.indice_count = -1
+ self.active_node_list = []
+
+ def _init_indice_trace_list(self) -> List:
+ indice_trace_list = []
+ for n in self.node_mgr.get_node_list():
+ if get_node_shape(n) != None:
+ cur_trace = {
+ "indice": [None for _ in range(len(get_node_shape(n)))],
+ "compute": [[] for _ in range(len(get_node_shape(n)))],
+ "source": [{} for _ in range(len(get_node_shape(n)))],
+ }
+ else:
+ cur_trace = {"indice": [], "compute": [], "source": []}
+ indice_trace_list.append(cur_trace)
+ return indice_trace_list
+
+ def set_active_nodes(self, active_node_list: List) -> None:
+ self.active_node_list = active_node_list
+
+ def _add_indice(self) -> int:
+ """
+ Update the count and return it. To record the idx number.
+
+ Returns:
+ indice_count: int
+ """
+ self.indice_count += 1
+ return self.indice_count
+
+ def _del_dim(self, idx: int, dim_idx: int) -> None:
+ """
+ delete a dim for indice, compute and source
+ """
+ self.indice_trace_list[idx]["indice"].pop(dim_idx)
+ self.indice_trace_list[idx]["compute"].pop(dim_idx)
+ self.indice_trace_list[idx]["source"].pop(dim_idx)
+
+ def _add_dim(self, node_idx: int, dim_idx: int) -> None:
+ """
+ add a dim for indice, compute and source
+ """
+ # need to remap if dim_idx < 0, e.g. -1
+ if dim_idx < 0:
+ dim_idx = list(range(len(self.indice_trace_list[node_idx]["indice"]) + 1))[dim_idx]
+ self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
+ self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
+ self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
+
+ def _add_source(
+ self,
+ node_from: Node,
+ node_from_dim: int,
+ node_to: Node,
+ node_to_dim: int,
+ init=False,
+ ) -> None:
+ node_from_dim = self._transform_indice(node_from, node_from_dim)
+ node_from_trace_source = self._find_source_trace_from_node(node_from)
+ node_to_dim = self._transform_indice(node_to, node_to_dim)
+ node_to_trace_source = self._find_source_trace_from_node(node_to)
+ node_from_idx = self.node_mgr.find_node_idx(node_from)
+ if init:
+ node_to_trace_source[node_to_dim] = {}
+ # add dim to cur new source
+ if node_from_idx not in node_to_trace_source[node_to_dim]:
+ node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
+ else:
+ if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
+ node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
+ # update inputs source
+ for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
+ if node_idx not in node_to_trace_source[node_to_dim]:
+ node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
+ else:
+ for d in node_dim:
+ if d not in node_to_trace_source[node_to_dim][node_idx]:
+ node_to_trace_source[node_to_dim][node_idx].append(d)
+
+ def _transform_indice(self, node: Node, node_dim: int) -> int:
+ node_idx = self._find_indice_trace_from_node(node)
+ dims = list(range(len(node_idx)))
+ return dims[node_dim]
+
+ def _inherit_indice(
+ self,
+ node_from: Node,
+ node_from_dim: int,
+ node_to: Node,
+ node_to_dim: int,
+ init: bool = True,
+ ) -> None:
+ """
+ node_to's node_to_dim inherit node_from's node_from_dim by indice, compute and source
+ """
+ node_from_dim = self._transform_indice(node_from, node_from_dim)
+ node_to_dim = self._transform_indice(node_to, node_to_dim)
+ node_from_trace = self._find_trace_from_node(node_from)
+ node_to_trace = self._find_trace_from_node(node_to)
+ if init:
+ node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
+ node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
+ else:
+ for j in node_from_trace["compute"][node_from_dim]:
+ if j not in node_to_trace["compute"][node_to_dim]:
+ node_to_trace["compute"][node_to_dim].append(j)
+ self._add_source(node_from, node_from_dim, node_to, node_to_dim, init)
+
+ def _inherit_all_indice(self, node_from: Node, node_to: Node) -> None:
+ """
+ inherit all dims with init
+ """
+ # find indice just for assert length
+ node_from_indice = self._find_indice_trace_from_node(node_from)
+ node_to_indice = self._find_indice_trace_from_node(node_to)
+ assert len(node_from_indice) == len(node_to_indice)
+ for i in range(len(node_from_indice)):
+ self._inherit_indice(node_from, i, node_to, i, init=True)
+
+ def _inherit_more_indice_from_node_with_exclude(self, node_from: Node, node_to: Node, exclude: List = None) -> None:
+ """
+ inheirt indice from node without init
+ """
+ if exclude == None:
+ exclude = []
+ else:
+ exclude = [self._transform_indice(node_to, i) for i in exclude]
+ node_from_compute = self._find_compute_trace_from_node(node_from)
+ node_to_compute = self._find_compute_trace_from_node(node_to)
+ # assert len(node_from_compute) == len(node_to_compute)
+ for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
+ if self._transform_indice(node_to, i) in exclude:
+ continue
+ self._inherit_indice(node_from, i, node_to, i, init=False)
+
+ def _mark_computation(self, node: Node, idx: int, dim: int) -> None:
+ """
+ Mark some dims of node as computed.
+
+ Args:
+ node (node)
+ idx (int): node index
+ dim (list or int): dims to be marked as computed
+ """
+ if isinstance(dim, int):
+ dim = [dim]
+ dims = list(range(len(get_node_shape(node))))
+ for d in dim:
+ cur_dim = dims[d]
+ if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
+ self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
+
+ def _find_trace_from_node(self, node: Node) -> Dict:
+ """
+ Find node idx and compute trace by the node.
+
+ Args:
+ node (node)
+ Returns:
+ idx (list): idx of the node
+ compute (list): computed idx of the node.
+ """
+ node_idx = self.node_mgr.find_node_idx(node)
+ node_dict = self.indice_trace_list[node_idx]
+ return node_dict
+
+ def _find_source_trace_from_node(self, node: Node) -> List:
+ """
+ Find node source trace by the node.
+
+ Args:
+ node (node)
+ Returns:
+ idx (list): idx of the node
+ compute (list): computed idx of the node.
+ """
+ node_idx = self.node_mgr.find_node_idx(node)
+ node_dict = self.indice_trace_list[node_idx]
+ return node_dict["source"]
+
+ def _find_indice_trace_from_node(self, node) -> List:
+ """
+ Find node idx trace by the node.
+
+ Args:
+ node (node)
+ Returns:
+ idx (list): idx of the node
+ """
+ node_idx = self.node_mgr.find_node_idx(node)
+ return self.indice_trace_list[node_idx]["indice"]
+
+ def _find_compute_trace_from_node(self, node: Node) -> List:
+ """
+ Find node compute trace by the node.
+
+ Args:
+ node (node)
+ Returns:
+ compute (list): computed idx of the node.
+ """
+ node_idx = self.node_mgr.find_node_idx(node)
+ return self.indice_trace_list[node_idx]["compute"]
+
+ def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None) -> None:
+ """
+ Assign node's trace as its input node.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ if input_node == None:
+ input_node = find_first_tensor_arg(node)
+ self._inherit_all_indice(input_node, node)
+
+ def _assign_all_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Add new indice for all node's dims.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ shape = node.meta["tensor_meta"].shape
+ if shape is None:
+ return
+ new_trace = []
+ for _ in shape:
+ new_trace.append(self._add_indice())
+ self.indice_trace_list[node_idx]["indice"] = new_trace
+
+ def _assign_transpose_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for transpose op.
+ 1. swap input's dim according to transpose args
+ 2. inherit input's computation
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ input_node = node.args[0]
+ tranpose_dim = node.args[1:]
+
+ self._assign_indice_as_input(node, node_idx, input_node)
+ self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
+ self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
+
+ def _assign_permute_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for permute op.
+ 1. swap input's dim according to permute args
+ 2. inherit input's computation
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ permute_dim = flat_list(node.args[1:])
+ input_node = node.args[0]
+
+ self._assign_indice_as_input(node, node_idx, input_node)
+ for idx, d in enumerate(permute_dim):
+ self._inherit_indice(input_node, d, node, idx)
+
+ def _assign_linear_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for linear op.
+ 1. copy trace from input node and change last indice accroding to weight
+ 2. mark equal for input node last indice, weight first dim and bias dim.
+ 3. inherit input's computation, mark computation for last dim.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._assign_indice_as_input(node, node_idx)
+
+ if len(node.args) >= 2:
+ weight = node.args[1]
+ self._inherit_indice(weight, 1, node, -1)
+ else:
+ self._del_dim(node_idx, -1)
+ self._add_dim(node_idx, -1)
+ self._mark_computation(node, node_idx, [-1])
+
+ def _assign_addmm_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for addmm op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ bias, input_node, weight = node.args
+ assert len(get_node_shape(bias)) == 1 and len(get_node_shape(weight)) == 2
+ self._assign_indice_as_input(node, node_idx, input_node)
+ self._inherit_indice(weight, 1, node, -1)
+ self._inherit_more_indice_from_node_with_exclude(bias, node)
+
+ self._mark_computation(node, node_idx, [-1])
+
+ def _assign_baddbmm_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for baddbmm(batch add and batch matmul) op.
+ add, matmul_left, matmul_right = args
+ out = add + (matmul_left x matmul_right)
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ add, matmul_left, matmul_right = node.args
+
+ assert get_node_shape(add) == get_node_shape(node)
+ assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
+ self._assign_indice_as_input(node, node_idx, matmul_left)
+ # matmul
+ self._inherit_indice(matmul_right, -1, node, -1)
+ self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-2, -1])
+ self._mark_computation(node, node_idx, [-1])
+ # add
+ self._inherit_more_indice_from_node_with_exclude(add, node)
+
+ def _assign_matmul_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for matmul op.
+ 1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
+ 2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
+ 3. inherit matmul_left and matmul_right computation, mark computation for last dim.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ matmul_left, matmul_right = node.args
+
+ assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
+ self._assign_indice_as_input(node, node_idx, matmul_left)
+
+ self._inherit_indice(matmul_right, -1, node, -1)
+ self._inherit_more_indice_from_node_with_exclude(matmul_right, node, [-1, -2])
+ self._mark_computation(node, node_idx, [-1])
+
+ def _assign_conv2d_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for conv2d op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ # get conv module
+ node_targets = node.target.split(".")
+ conv_module = node.graph.owning_module
+ for i in node_targets:
+ conv_module = getattr(conv_module, i)
+ assert conv_module.dilation == (1, 1), "dilation for conv2d not implemented"
+
+ # get conv input
+ assert len(node.args) == 1
+ input_node = node.args[0]
+ assert len(get_node_shape(input_node)) == 4
+
+ # assgin index
+ self._assign_indice_as_input(node, node_idx, input_node)
+ self._del_dim(node_idx, 1)
+ self._add_dim(node_idx, 1)
+ self._mark_computation(node, node_idx, [1, 2, 3])
+
+ def _assign_interpolate_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for interpolate op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ # get conv input
+ assert node.kwargs['size'] is None
+ assert len(get_node_shape(node)) == 4
+
+ # assgin index
+ self._assign_indice_as_input(node, node_idx)
+ self._mark_computation(node, node_idx, [-1, -2])
+
+ def _assign_layernorm_indice(self, node, idx):
+ """
+ Assign indice for layernorm op.
+ 1. assign indice as input node
+ 2. inherit computation and mark last 2 dims as computed.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._assign_indice_as_input(node, idx)
+ self._mark_computation(node, idx, [-1])
+
+ def _assign_groupnorm_indice(self, node, idx):
+ """
+ Assign indice for groupnorm op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ assert len(get_node_shape(node)) == 4
+ self._assign_indice_as_input(node, idx)
+ self._mark_computation(node, idx, [-1, -2, -3])
+
+ def _assign_elementwise_indice(self, node, idx):
+ """
+ Assign indice for element-wise op (eg. relu sigmoid add mul).
+ 1. assign indice as input node
+ 2. inherit computation from all input nodes.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._assign_indice_as_input(node, idx)
+ nodes_in = []
+ for node_in in node.args:
+ if type(node_in) == type(node):
+ nodes_in.append(node_in)
+ self._inherit_more_indice_from_node_with_exclude(node_in, node)
+
+ def _assgin_no_change_indice(self, node, idx):
+ self._assign_indice_as_input(node, idx)
+ for node_in in node.args:
+ if type(node_in) == type(node):
+ self._inherit_more_indice_from_node_with_exclude(node_in, node)
+
+ def _assign_einsum_indice(self, node, idx):
+ """
+ Assign indice for einsum op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ patterns = node.args[0]
+ input_nodes = node.args[1:]
+
+ patterns = patterns.replace(" ", "")
+ left, right = patterns.split("->")
+ left = left.split(",")
+
+ if "..." in right:
+ replace_list = "!@#$%^&*"
+ target_len = len(get_node_shape(node))
+ add_len = target_len - len(right) + 3
+ replace_str = replace_list[:add_len]
+ right = right.replace("...", replace_str)
+ for ll in range(len(left)):
+ left[ll] = left[ll].replace("...", replace_str)
+
+ all_index = []
+ for i in left:
+ for c in i:
+ all_index.append(c)
+ all_index = set(all_index)
+
+ for right_idx, right_indice in enumerate(right):
+ for left_idx, left_str in enumerate(left):
+ if right_indice in left_str:
+ source_idx = left_str.index(right_indice)
+ self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
+
+ def _assign_softmax_indice(self, node, idx):
+ """
+ Assign indice for softmax op.
+ 1. assign indice as input node
+ 2. inherit computation and mark softmax dim as computed.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._assign_indice_as_input(node, idx)
+ self._mark_computation(node, idx, [node.kwargs["dim"]])
+
+ def _assign_split_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for split op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._assign_indice_as_input(node, node_idx)
+ dim_idx = node.kwargs["dim"]
+ self._del_dim(node_idx, dim_idx)
+ self._add_dim(node_idx, dim_idx)
+
+ def _assign_unsqueeze_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for unsqueeze op.
+ 1. assign new indice for unsqueeze dim
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._del_dim(node_idx, -1)
+ self._assign_indice_as_input(node, node_idx)
+ dim_idx = node.args[1]
+ # unsqueeze(-1) = unsqueeze(shape_num + 1)
+ if dim_idx < 0:
+ dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
+ self._add_dim(node_idx, dim_idx)
+
+ def _assign_cat_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for cat op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ nodes_in = flat_list(node.args[0])
+ self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+ for n in nodes_in[1:]:
+ self._inherit_more_indice_from_node_with_exclude(n, node)
+ cat_dim = node.kwargs["dim"]
+ self._del_dim(node_idx, cat_dim)
+ self._add_dim(node_idx, cat_dim)
+
+ def _assign_sum_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for sum op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ nodes_in = flat_list(node.args[0])
+ self._add_dim(node_idx, 0)
+ self._assign_indice_as_input(node, node_idx, input_node=nodes_in[0])
+ for n in nodes_in[1:]:
+ self._inherit_more_indice_from_node_with_exclude(n, node)
+ cat_dim = node.kwargs["dim"]
+ self._del_dim(node_idx, cat_dim)
+
+ def _assign_flatten_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for flatten op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ nodes_in = node.args[0]
+ nodes_in_shape = get_node_shape(nodes_in)
+ flatten_start_dim = node.args[1]
+ flatten_dim_num = len(nodes_in_shape) - flatten_start_dim - 1
+ assert flatten_dim_num > 0
+ for _ in range(flatten_dim_num):
+ self._add_dim(node_idx, 0)
+ self._assign_indice_as_input(node, node_idx, nodes_in)
+ for _ in range(flatten_dim_num + 1):
+ self._del_dim(node_idx, -1)
+ self._add_dim(node_idx, -1)
+
+ def _assign_expand_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for expand op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ expand_shape = node.args[1:]
+ node_in_shape = get_node_shape(node.args[0])
+ assert len(expand_shape) == len(node_in_shape)
+ self._assign_indice_as_input(node, node_idx)
+ for i in range(len(node_in_shape)):
+ if expand_shape[i] == node_in_shape[i] or expand_shape[i] == -1:
+ continue
+ elif expand_shape[i] > node_in_shape[i]:
+ self._del_dim(node_idx, i)
+ self._add_dim(node_idx, i)
+ else:
+ raise RuntimeError()
+
+ def _assign_unbind_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for unbind op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ unbind_dim = node.args[1]
+ self._add_dim(node_idx, unbind_dim)
+ self._assign_indice_as_input(node, node_idx)
+ self._del_dim(node_idx, unbind_dim)
+
+ def _assign_embedding_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for embedding op.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ self._del_dim(node_idx, -1)
+ self._assign_indice_as_input(node, node_idx)
+ self._add_dim(node_idx, -1)
+
+ def _assign_getitem_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for getitem.
+ getitem can act like slice sometimes
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ node_args = flat_list(node.args[1:])
+
+ # deal with split
+ if get_node_name(node.args[0]) == "split":
+ self._assign_indice_as_input(node, node_idx)
+ self._del_dim(node_idx, node.args[0].kwargs["dim"])
+ self._add_dim(node_idx, node.args[0].kwargs["dim"])
+ return
+
+ # skip non tensor
+ if get_node_shape(node) is None:
+ return
+
+ # find if slice
+ flag = False
+ for node_arg in node_args:
+ node_arg_str = str(node_arg)
+ if any(i == node_arg_str for i in ["None", "Ellipsis"]):
+ flag = True
+ break
+ if "slice" in node_arg_str:
+ flag = True
+ break
+ if flag == False:
+ return
+
+ # node args should be like [Ellipsis, slice(start, step, end), None]
+ node_shape = get_node_shape(node)
+ origin_idx_count = 0
+ new_idx_count = 0
+ new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
+ for _ in range(new_dim_num):
+ self._del_dim(node_idx, 0)
+ delete_dim_num = sum([1 if str(i) == "0" else 0 for i in node_args])
+ for _ in range(delete_dim_num):
+ self._add_dim(node_idx, 0)
+ self._assign_indice_as_input(node, node_idx)
+
+ for _, node_arg in enumerate(node_args):
+ node_arg_str = str(node_arg)
+ # Ellipsis means [..., ]
+ if "Ellipsis" == node_arg_str:
+ shape_gap = len(node_shape) - len(node_args) + 1
+ origin_idx_count += shape_gap
+ new_idx_count += shape_gap
+ # slice(None, None, None) means all indexes
+ elif "slice" in node_arg_str:
+ if "slice(None, None, None)" != node_arg_str:
+ self._del_dim(node_idx, new_idx_count)
+ self._add_dim(node_idx, new_idx_count)
+ origin_idx_count += 1
+ new_idx_count += 1
+ # None means a new dim
+ elif "None" == node_arg_str:
+ self._add_dim(node_idx, new_idx_count)
+ new_idx_count += 1
+ elif "0" == node_arg_str:
+ self._del_dim(node_idx, new_idx_count)
+ origin_idx_count += 1
+ else:
+ raise NotImplementedError()
+
+ def _assign_view_reshape_indice(self, node: Node, node_idx: int) -> None:
+ """
+ Assign indice for view and reshape op.
+ 1. get origin shape and target shape by meta info.
+ 2. compute the real value of -1 in target shape.
+ 3. determine changed dim, and assgin indice for generated dim.
+ 4. log changed dim and generated dim for restore
+ 5. inherit computation.
+ 6. look into view list to see whether the view is associated with other,
+ if so assgin equal dim according to previous view.
+
+ Args:
+ node (node)
+ node_idx (int)
+ """
+ # get data, turn into number
+ origin_node = node.args[0]
+ origin_shape = origin_node.meta["tensor_meta"].shape
+ target_shape = []
+ unflated_args = flat_list(node.args)
+ for i in range(1, len(unflated_args)):
+ if isinstance(unflated_args[i], int):
+ target_shape.append(unflated_args[i])
+ else:
+ target_shape.extend(unflated_args[i].meta["fwd_out"])
+
+ # compute the value of -1
+ if -1 in target_shape:
+ origin_product = 1
+ for i in origin_shape:
+ origin_product *= i
+ target_product = -1
+ for i in target_shape:
+ target_product *= i
+ shape_idx = target_shape.index(-1)
+ target_shape[shape_idx] = origin_product // target_product
+
+ # find same dim
+ dim_to_same_dim = []
+ dim_from_same_dim = []
+ for i in range(len(origin_shape)):
+ if origin_shape[i] == target_shape[i]:
+ dim_to_same_dim.append(i)
+ dim_from_same_dim.append(i)
+ else:
+ break
+ for i in range(-1, -len(origin_shape), -1):
+ if origin_shape[i] == target_shape[i]:
+ dim_to_same_dim.append(len(target_shape) + i)
+ dim_from_same_dim.append(len(origin_shape) + i)
+ else:
+ break
+
+ dim_from = list(set(range(len(origin_shape))) - set(dim_from_same_dim))
+ dim_to = list(set(range(len(target_shape))) - set(dim_to_same_dim))
+ assert len(dim_from) == 1 or len(dim_to) == 1 or len(dim_from) == len(dim_to)
+
+ dim_diff = len(dim_from) - len(dim_to)
+ if dim_diff > 0:
+ # dim merge
+ for i in range(dim_diff):
+ self._add_dim(node_idx, -1)
+ elif dim_diff < 0:
+ # dim expand
+ for i in range(-dim_diff):
+ self._del_dim(node_idx, -1)
+
+ # get new indice
+ origin_trace = self._find_indice_trace_from_node(origin_node)
+ self._assign_indice_as_input(node, node_idx, origin_node)
+ dim_from.reverse()
+ for i in dim_from:
+ self._del_dim(node_idx, i)
+ for i in dim_to:
+ self._add_dim(node_idx, i)
+ dim_from.reverse()
+
+ # inheirt indice from current node
+ if len(dim_from) != 0 and len(dim_to) != 0:
+ if dim_diff == 1:
+ if origin_shape[dim_from[0]] == 1:
+ self._inherit_indice(origin_node, dim_from[1], node, dim_to[0], init=False)
+ elif origin_shape[dim_from[1]] == 1:
+ self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+ elif dim_diff == -1:
+ if target_shape[dim_to[0]] == 1:
+ self._inherit_indice(origin_node, dim_from[0], node, dim_to[1], init=False)
+ elif target_shape[dim_to[1]] == 1:
+ self._inherit_indice(origin_node, dim_from[0], node, dim_to[0], init=False)
+
+ # log view, not used now
+ view_dict = {
+ "idx_from": [origin_trace[i] for i in dim_from],
+ "dim_from": dim_from,
+ "idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
+ "dim_to": dim_to,
+ }
+ self.indice_view_list[node] = view_dict
+
+ def _clear_trace(self, node_idx: int) -> None:
+ """
+ clear too far trace to speed up computation
+ """
+ trace_barrier = max(node_idx - 100, 0)
+ active_nodes = self.active_node_list[trace_barrier]
+ active_nodes = [self.node_mgr.find_node_idx(i) for i in active_nodes.keys()]
+
+ trace = self.indice_trace_list[node_idx]
+ # clear compute
+ for dim_compute in trace["compute"]:
+ for i in range(len(dim_compute) - 1, -1, -1):
+ if (dim_compute[i] < trace_barrier and dim_compute[i] not in active_nodes):
+ dim_compute.pop(i)
+ continue
+ # clear source
+ for dim_source in trace["source"]:
+ for k in list(dim_source.keys()):
+ if k < trace_barrier and k not in active_nodes:
+ dim_source.pop(k)
+
+ def trace_indice(self) -> None:
+ for idx, node in enumerate(self.node_mgr.get_node_list()):
+ node_name = get_node_name(node)
+ if node.op == "placeholder":
+ self._assign_all_indice(node, idx)
+ elif node.op == "call_method":
+ if "transpose" == node_name:
+ self._assign_transpose_indice(node, idx)
+ elif "permute" == node_name:
+ self._assign_permute_indice(node, idx)
+ elif "view" == node_name or "reshape" == node_name:
+ self._assign_view_reshape_indice(node, idx)
+ elif "unsqueeze" == node_name:
+ self._assign_unsqueeze_indice(node, idx)
+ elif "split" == node_name:
+ self._assign_split_indice(node, idx)
+ elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
+ self._assgin_no_change_indice(node, idx)
+ elif "new_ones" == node_name:
+ self._assign_all_indice(node, idx)
+ elif "flatten" == node_name:
+ self._assign_flatten_indice(node, idx)
+ elif "expand" == node_name:
+ self._assign_expand_indice(node, idx)
+ elif "unbind" == node_name:
+ self._assign_unbind_indice(node, idx)
+ elif "softmax" == node_name:
+ self._assign_softmax_indice(node, idx)
+ elif any(i == node_name for i in ["size"]):
+ continue
+ else:
+ raise NotImplementedError(node_name, "method not implemented yet!")
+ elif node.op == "call_function":
+ if "linear" == node_name:
+ self._assign_linear_indice(node, idx)
+ elif "cat" == node_name:
+ self._assign_cat_indice(node, idx)
+ elif any(n == node_name for n in ["matmul", "bmm"]):
+ self._assign_matmul_indice(node, idx)
+ elif "softmax" == node_name:
+ self._assign_softmax_indice(node, idx)
+ elif any(n == node_name for n in [
+ "mul", "add", "sigmoid", "relu", "sub", "truediv", "pow", "dropout", "where", "tanh", "exp",
+ "sin", "cos"
+ ]):
+ self._assign_elementwise_indice(node, idx)
+ elif "einsum" == node_name:
+ self._assign_einsum_indice(node, idx)
+ elif "sum" == node_name:
+ self._assign_sum_indice(node, idx)
+ elif "layer_norm" == node_name:
+ self._assign_layernorm_indice(node, idx)
+ elif "getitem" == node_name:
+ self._assign_getitem_indice(node, idx)
+ elif "addmm" == node_name:
+ self._assign_addmm_indice(node, idx)
+ elif "baddbmm" == node_name:
+ self._assign_baddbmm_indice(node, idx)
+ elif "interpolate" == node_name:
+ self._assign_interpolate_indice(node, idx)
+ elif any(i == node_name for i in ["arange", "ones", "ones_like", "tensor", "empty"]):
+ self._assign_all_indice(node, idx)
+ elif any(i == node_name for i in ["getattr", "eq", "_assert_is_none", "_assert", "finfo"]):
+ continue
+ else:
+ raise NotImplementedError(node_name, "function not implemented yet!")
+ elif node.op == "call_module":
+ node_name = get_module_node_name(node)
+ if "layernorm" == node_name:
+ self._assign_layernorm_indice(node, idx)
+ elif "groupnorm" == node_name:
+ self._assign_groupnorm_indice(node, idx)
+ elif "embedding" == node_name:
+ self._assign_embedding_indice(node, idx)
+ elif "linear" == node_name:
+ self._assign_linear_indice(node, idx)
+ elif "conv2d" == node_name:
+ self._assign_conv2d_indice(node, idx)
+ elif "identity" == node_name:
+ self._assgin_no_change_indice(node, idx)
+ elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
+ self._assign_elementwise_indice(node, idx)
+ else:
+ raise NotImplementedError(node_name, "module not implemented yet!")
+ elif node.op == "get_attr":
+ self._assign_all_indice(node, idx) # get param
+ elif node.op == "output":
+ continue
+ else:
+ raise NotImplementedError(node.op, "op not implemented yet!")
+
+ # limit trace range
+ self._clear_trace(idx)
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
new file mode 100644
index 000000000000..064baa047155
--- /dev/null
+++ b/colossalai/autochunk/utils.py
@@ -0,0 +1,244 @@
+from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+
+from torch.fx.node import Node
+
+from colossalai.logging import get_dist_logger
+
+NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
+NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
+logger = get_dist_logger()
+
+
+class NodeMgr(object):
+
+ def __init__(self, nodes_list: List[Node]) -> None:
+ self._node_list = nodes_list
+ self._node_dict = {}
+ self._set_node_dict()
+
+ def _set_node_dict(self) -> None:
+ """
+ create a dict {node_name: node_idx}
+ """
+ self._node_dict.clear()
+ for idx, node in enumerate(self._node_list):
+ self._node_dict[node.name] = idx
+
+ def find_node_idx(self, node: Node) -> int:
+ """
+ find node's index
+ """
+ return self._node_dict[node.name]
+
+ def find_node_idx_by_name(self, node_name: str) -> int:
+ """
+ find node's index
+ """
+ return self._node_dict[node_name]
+
+ def get_node_by_idx(self, idx: int) -> Node:
+ """
+ get a node by index
+ """
+ return self._node_list[idx]
+
+ def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
+ """
+ get a slice of node by index
+ """
+ return self._node_list[start:end]
+
+ def get_node_list(self) -> List:
+ """
+ get full node list
+ """
+ return self._node_list
+
+ def update_node_list(self, node_list: List) -> None:
+ """
+ update node list, reset node dict
+ """
+ self._node_list = node_list
+ self._set_node_dict()
+
+
+def get_logger() -> Any:
+ return logger
+
+
+def flat_list(inputs: Any) -> List:
+ """
+ flat a list by recursion
+ """
+ if not (isinstance(inputs, list) or isinstance(inputs, set) or isinstance(inputs, tuple)):
+ return [inputs]
+ res = []
+ for i in inputs:
+ if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
+ res.extend(flat_list(i))
+ elif isinstance(i, dict):
+ res.extend(flat_list(list(i.keys())))
+ else:
+ res.append(i)
+ return res
+
+
+def find_first_tensor_arg(node: Node) -> Node:
+ """
+ Find the first input tensor arg for a node
+ """
+ for arg in node.args:
+ if type(arg) == type(node):
+ return arg
+ raise RuntimeError()
+
+
+def is_non_compute_node(node: Node) -> bool:
+ if any(i == node.op for i in NON_COMPUTE_OP) or any(i == get_node_name(node) for i in NON_COMPUTE_NAME):
+ return True
+ if "getitem" in node.name:
+ if get_node_shape(node) is not None:
+ return False
+ node_args = flat_list(node.args[1:])
+ for node_arg in node_args:
+ if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
+ return False
+ if "slice" in str(node_arg):
+ return False
+ return True
+ return False
+
+
+def get_node_shape(node: Node) -> Any:
+ """
+ return node data shape
+ """
+ if get_node_name(node) in ["split", "unbind"]:
+ return node.meta["tensor_meta"][0].shape
+ if hasattr(node.meta["tensor_meta"], "shape"):
+ return node.meta["tensor_meta"].shape
+ return None
+
+
+def is_non_memory_node(node: Node) -> bool:
+ if "getitem" in node.name:
+ return True
+ if "output" in node.op:
+ return True
+ return is_non_compute_node(node)
+
+
+def is_non_compute_node_except_placeholder(node: Node) -> bool:
+ if "placeholder" in node.op:
+ return False
+ return is_non_compute_node(node)
+
+
+def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
+ if "output" in node.op:
+ return False
+ return is_non_compute_node_except_placeholder(node)
+
+
+def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
+ for key, value in user_to_last_uses.items():
+ for n in value:
+ if n.op == "placeholder":
+ user_to_last_uses[key].remove(n)
+
+
+def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
+ """
+ Find non-compute input and output node names.
+ input nodes are nodes used in the list
+ output nodes are nodes will use nodes in the list
+ """
+ input_nodes = []
+ for node in nodes:
+ for input_node in node._input_nodes.keys():
+ if input_node not in nodes and input_node not in input_nodes:
+ input_nodes.append(input_node)
+ return input_nodes
+
+
+def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Union[List, List]:
+ """
+ Find non-compute input and output node names.
+ input nodes are nodes used in the list
+ output nodes are nodes will use nodes in the list
+ """
+ input_nodes = []
+ output_nodes = []
+
+ # if a node has an input node which is not in the node list
+ # we treat that input node as the input of the checkpoint function
+ for node in nodes:
+ for input_node in node._input_nodes.keys():
+ if (input_node not in nodes and input_node not in input_nodes
+ and not is_non_compute_node_except_placeholder(input_node)):
+ input_nodes.append(input_node)
+
+ # if a node has a user node which is not in the node list
+ # we treat that user node as the node receiving the current node output
+ for node in nodes:
+ for output_node in node.users.keys():
+ if (output_node not in nodes and node not in output_nodes
+ and not is_non_compute_node_except_placeholder_output(output_node)):
+ output_nodes.append(node)
+
+ return input_nodes, output_nodes
+
+
+def get_module_node_name(node: Node) -> str:
+ """
+ get module class name
+ """
+ node_targets = node.target.split(".")
+ module = node.graph.owning_module
+ for i in node_targets:
+ module = getattr(module, i)
+ module_name = str(module.__class__).split(".")[-1][:-2]
+ module_name = module_name.lower()
+ return module_name
+
+
+def get_node_name(node: Node) -> str:
+ """
+ get node name
+ """
+ node_name = node.name
+ if "_" in node_name:
+ for i in range(len(node_name) - 1, -1, -1):
+ if node_name[i] == "_":
+ node_name = node_name[:i]
+ break
+ elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
+ continue
+ else:
+ break
+ return node_name
+
+
+def find_tensor_node(node_list: List[Node]) -> List[Node]:
+ """
+ find tensor nodes from a node list
+ """
+ out = []
+ for node in node_list:
+ if get_node_shape(node) is not None:
+ out.append(node)
+ return out
+
+
+def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
+ """
+ find tensor and shape nodes from a node list
+ """
+ out = []
+ for node in node_list:
+ if get_node_shape(node) is not None:
+ out.append(node)
+ elif len(node.meta['fwd_out']) > 0 and isinstance(node.meta['fwd_out'], list) and isinstance(
+ node.meta['fwd_out'][0], int):
+ out.append(node)
+ return out
diff --git a/colossalai/booster/__init__.py b/colossalai/booster/__init__.py
new file mode 100644
index 000000000000..3b3f45bb0fe2
--- /dev/null
+++ b/colossalai/booster/__init__.py
@@ -0,0 +1,4 @@
+from .accelerator import Accelerator
+from .booster import Booster
+from .environment_table import EnvironmentTable
+from .plugin import Plugin
diff --git a/colossalai/booster/accelerator.py b/colossalai/booster/accelerator.py
new file mode 100644
index 000000000000..fc2c4a40068b
--- /dev/null
+++ b/colossalai/booster/accelerator.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+
+__all__ = ['Accelerator']
+
+_supported_devices = [
+ 'cpu',
+ 'cuda',
+
+ # To be supported
+ # 'xpu',
+ # 'npu',
+ # 'tpu',
+]
+
+
+class Accelerator:
+ """
+ Accelerator is an abstraction for the hardware device that is used to run the model.
+
+ Args:
+ device (str): The device to be used. Currently only support 'cpu' and 'gpu'.
+ """
+
+ def __init__(self, device: str):
+ self.device = device
+
+ assert self.device in _supported_devices, f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
+
+ def bind(self):
+ """
+ Set the default device for the current process.
+ """
+ if self.device == 'cpu':
+ pass
+ elif self.device == 'cuda':
+ # TODO(FrankLeeeee): use global environment to check if it is a dist job
+ # if is_distributed:
+ # local_rank = EnvTable().get_local_rank()
+ # torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
+ torch.cuda.set_device(torch.device('cuda'))
+ pass
+ else:
+ raise ValueError(f"Device {self.device} is not supported yet")
+
+ def configure_model(self, model: nn.Module) -> nn.Module:
+ """
+ Move the model to the device.
+
+ Args:
+ model (nn.Module): The model to be moved.
+ """
+ model = model.to(torch.device(self.device))
+ return model
diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py
new file mode 100644
index 000000000000..230c65a9e0a1
--- /dev/null
+++ b/colossalai/booster/booster.py
@@ -0,0 +1,157 @@
+import warnings
+from contextlib import contextmanager
+from typing import Callable, Iterator, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+
+from .accelerator import Accelerator
+from .mixed_precision import MixedPrecision, mixed_precision_factory
+from .plugin import Plugin
+
+__all__ = ['Booster']
+
+
+class Booster:
+ """
+ Booster is a high-level API for training neural networks. It provides a unified interface for
+ training with different precisio, accelerator, and plugin.
+
+ Examples:
+ >>> colossalai.launch(...)
+ >>> plugin = GeminiPlugin(stage=3, ...)
+ >>> booster = Booster(precision='fp16', plugin=plugin)
+ >>>
+ >>> model = GPT2()
+ >>> optimizer = Adam(model.parameters())
+ >>> dataloader = Dataloader(Dataset)
+ >>> lr_scheduler = LinearWarmupScheduler()
+ >>> criterion = GPTLMLoss()
+ >>>
+ >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
+ >>>
+ >>> for epoch in range(max_epochs):
+ >>> for input_ids, attention_mask in dataloader:
+ >>> outputs = model(input_ids, attention_mask)
+ >>> loss = criterion(outputs.logits, input_ids)
+ >>> booster.backward(loss, optimizer)
+ >>> optimizer.step()
+ >>> lr_scheduler.step()
+ >>> optimizer.zero_grad()
+
+
+ Args:
+ device (str or torch.device): The device to run the training. Default: 'cuda'.
+ mixed_precision (str or MixedPrecision): The mixed precision to run the training. Default: None.
+ If the argument is a string, it can be 'fp16', 'fp16_apex', 'bf16', or 'fp8'.
+ 'fp16' would use PyTorch AMP while `fp16_apex` would use Nvidia Apex.
+ plugin (Plugin): The plugin to run the training. Default: None.
+ """
+
+ def __init__(self,
+ device: str = 'cuda',
+ mixed_precision: Union[MixedPrecision, str] = None,
+ plugin: Optional[Plugin] = None) -> None:
+ if plugin is not None:
+ assert isinstance(
+ plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
+ self.plugin = plugin
+
+ # set accelerator
+ if self.plugin and self.plugin.control_device:
+ self.accelerator = None
+ warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
+ else:
+ self.accelerator = Accelerator(device)
+
+ # set precision
+ if mixed_precision is None or (self.plugin and self.plugin.control_precision):
+ self.mixed_precision = None
+ warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
+ else:
+ # validate and set precision
+ if isinstance(MixedPrecision, str):
+ # the user will take the default arguments for amp training
+ self.mixed_precision = mixed_precision_factory(mixed_precision)
+ elif isinstance(mixed_precision, MixedPrecision):
+ # the user can customize the arguments by passing the precision object
+ self.mixed_precision = mixed_precision
+ else:
+ raise ValueError(
+ f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
+ )
+
+ def boost(
+ self,
+ model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Callable = None,
+ dataloader: DataLoader = None,
+ lr_scheduler: LRScheduler = None,
+ ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
+ """
+ Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
+
+ Args:
+ model (nn.Module): The model to be boosted.
+ optimizer (Optimizer): The optimizer to be boosted.
+ criterion (Callable): The criterion to be boosted.
+ dataloader (DataLoader): The dataloader to be boosted.
+ lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
+ """
+ # TODO(FrankLeeeee): consider multi-model and multi-optimizer case
+ # TODO(FrankLeeeee): consider multi-dataloader case
+ # transform model for mixed precision
+ if self.plugin:
+ model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
+ model, optimizer, criterion, dataloader, lr_scheduler)
+
+ if self.plugin and not self.plugin.control_device:
+ # transform model for accelerator
+ model = self.accelerator.configure(model)
+
+ if self.mixed_precision and self.plugin and not self.plugin.control_precision:
+ # transform model for mixed precision
+ model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
+
+ return model, optimizer, criterion, dataloader, lr_scheduler
+
+ def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
+ # TODO: implement this method with plugin
+ optimizer.backward(loss)
+
+ def execute_pipeline(self,
+ data_iter: Iterator,
+ model: nn.Module,
+ criterion: Callable[[torch.Tensor], torch.Tensor],
+ optimizer: Optimizer,
+ return_loss: bool = True,
+ return_outputs: bool = False) -> Tuple[Optional[torch.Tensor], ...]:
+ # TODO: implement this method
+ # run pipeline forward backward pass
+ # return loss or outputs if needed
+ pass
+
+ def no_sync(self, model: nn.Module) -> contextmanager:
+ assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
+ assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+ return self.plugin.no_sync(model)
+
+ def save(self,
+ obj: Union[nn.Module, Optimizer, LRScheduler],
+ path_like: str,
+ plan: str = 'torch',
+ **kwargs) -> None:
+ # TODO: implement this method
+ pass
+
+ def load(self,
+ obj: Union[nn.Module, Optimizer, LRScheduler],
+ path_like: str,
+ plan: str = 'torch',
+ **kwargs) -> None:
+ # TODO: implement this method
+ pass
diff --git a/colossalai/booster/environment_table.py b/colossalai/booster/environment_table.py
new file mode 100644
index 000000000000..4b16f120c1b9
--- /dev/null
+++ b/colossalai/booster/environment_table.py
@@ -0,0 +1,18 @@
+from typing import List
+
+__all__ = ['EnvironmentTable']
+
+
+class EnvironmentTable:
+
+ def __init__(self, intra_op_world_sizes: List[int]):
+ # TODO: implement this method
+ pass
+
+ @property
+ def is_master(self) -> bool:
+ # TODO: implement this method
+ pass
+
+ # TODO: implement more utility methods as given in
+ # https://github.com/hpcaitech/ColossalAI/issues/3051
diff --git a/colossalai/booster/interface/__init__.py b/colossalai/booster/interface/__init__.py
new file mode 100644
index 000000000000..8892a13e1814
--- /dev/null
+++ b/colossalai/booster/interface/__init__.py
@@ -0,0 +1,3 @@
+from .optimizer import OptimizerWrapper
+
+__all__ = ['OptimizerWrapper']
diff --git a/colossalai/booster/interface/optimizer.py b/colossalai/booster/interface/optimizer.py
new file mode 100644
index 000000000000..dd9acab17584
--- /dev/null
+++ b/colossalai/booster/interface/optimizer.py
@@ -0,0 +1,121 @@
+from typing import Union
+
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+
+
+class OptimizerWrapper:
+ """
+ A standard interface for optimizers wrapped by the Booster.
+
+ Args:
+ optim (Optimizer): The optimizer to be wrapped.
+ """
+
+ def __init__(self, optim: Optimizer):
+ self.optim = optim
+
+ @property
+ def parameters(self):
+ params = []
+
+ for group in self.param_groups:
+ params += group['params']
+ return params
+
+ @property
+ def param_groups(self):
+ return self.optim.param_groups
+
+ @property
+ def defaults(self):
+ return self.optim.defaults
+
+ def add_param_group(self, *args, **kwargs):
+ return self.optim.add_param_group(*args, **kwargs)
+
+ def step(self, *args, **kwargs):
+ """
+ Performs a single optimization step.
+ """
+ return self.optim.step(*args, **kwargs)
+
+ def zero_grad(self, *args, **kwargs):
+ """
+ Clears the gradients of all optimized `torch.Tensor`.
+ """
+ self.optim.zero_grad(*args, **kwargs)
+
+ def backward(self, loss: Tensor, *args, **kwargs):
+ """
+ Performs a backward pass on the loss.
+ """
+ loss.backward(*args, **kwargs)
+
+ def state_dict(self):
+ """
+ Returns the optimizer state.
+ """
+ return self.optim.state_dict()
+
+ def load_state_dict(self, *args, **kwargs):
+ """
+ Loads the optimizer state.
+ """
+ self.optim.load_state_dict(*args, **kwargs)
+
+ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
+ """
+ Clips gradient of an iterable of parameters at specified min and max values.
+
+ Args:
+ clip_value (float or int): maximum allowed value of the gradients. Gradients are clipped in the range
+
+ Note:
+ In PyTorch Torch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_value_ to use the
+ faster implementation. Please refer to the PyTorch documentation for more details.
+ """
+ nn.utils.clip_grad_value_(self.parameters, clip_value, *args, **kwargs)
+
+ def clip_grad_by_norm(self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs) -> Tensor:
+ """
+ Clips gradient norm of an iterable of parameters.
+
+ Args:
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+ error_if_nonfinite (bool): if True, an error is raised if the total norm is non-finite. Default: False
+
+ Note:
+ In PyTorch Torch 2.0 and above, you can pass in foreach=True as kwargs to clip_grad_norm_ to use the
+ faster implementation. Please refer to the PyTorch documentation for more details.
+ """
+ norm = nn.utils.clip_grad_norm_(self.parameters, max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
+ return norm
+
+ def scale_loss(self, loss: Tensor):
+ """
+ Scales the loss for mixed precision training.
+
+ Note: Only available for optimizers with mixed precision training.
+
+ Args:
+ loss (Tensor): The loss to be scaled.
+ """
+ raise NotImplementedError(
+ "The method scale_loss is only available for optimizers with mixed precision training")
+
+ def unscale_grad(self):
+ """
+ Unscale the gradients for mixed precision training.
+
+ Note: Only available for optimizers with mixed precision training.
+ """
+ raise NotImplementedError(
+ "The method unscale_grad is only available for optimizers with mixed precision training")
diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py
new file mode 100644
index 000000000000..3cf0ad28cdbe
--- /dev/null
+++ b/colossalai/booster/mixed_precision/__init__.py
@@ -0,0 +1,33 @@
+from .bf16 import BF16MixedPrecision
+from .fp8 import FP8MixedPrecision
+from .fp16_apex import FP16ApexMixedPrecision
+from .fp16_torch import FP16TorchMixedPrecision
+from .mixed_precision_base import MixedPrecision
+
+__all__ = [
+ 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision',
+ 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision'
+]
+
+_mixed_precision_mapping = {
+ 'fp16': FP16TorchMixedPrecision,
+ 'fp16_apex': FP16ApexMixedPrecision,
+ 'bf16': BF16MixedPrecision,
+ 'fp8': FP8MixedPrecision
+}
+
+
+def mixed_precision_factory(mixed_precision_type: str) -> MixedPrecision:
+ """
+ Factory method to create mixed precision object
+
+ Args:
+ mixed_precision_type (str): mixed precision type, including None, 'fp16', 'fp16_apex', 'bf16', and 'fp8'.
+ """
+
+ if mixed_precision_type in _mixed_precision_mapping:
+ return _mixed_precision_mapping[mixed_precision_type]()
+ else:
+ raise ValueError(
+ f'Mixed precision type {mixed_precision_type} is not supported, support types include {list(_mixed_precision_mapping.keys())}'
+ )
diff --git a/colossalai/booster/mixed_precision/bf16.py b/colossalai/booster/mixed_precision/bf16.py
new file mode 100644
index 000000000000..4a840fea69ea
--- /dev/null
+++ b/colossalai/booster/mixed_precision/bf16.py
@@ -0,0 +1,5 @@
+from .mixed_precision_base import MixedPrecision
+
+
+class BF16MixedPrecision(MixedPrecision):
+ pass
diff --git a/colossalai/booster/mixed_precision/fp16_apex.py b/colossalai/booster/mixed_precision/fp16_apex.py
new file mode 100644
index 000000000000..266a750734b1
--- /dev/null
+++ b/colossalai/booster/mixed_precision/fp16_apex.py
@@ -0,0 +1,5 @@
+from .mixed_precision_base import MixedPrecision
+
+
+class FP16ApexMixedPrecision(MixedPrecision):
+ pass
diff --git a/colossalai/booster/mixed_precision/fp16_torch.py b/colossalai/booster/mixed_precision/fp16_torch.py
new file mode 100644
index 000000000000..054f78d2e226
--- /dev/null
+++ b/colossalai/booster/mixed_precision/fp16_torch.py
@@ -0,0 +1,122 @@
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.optim import Optimizer
+
+from ..interface import OptimizerWrapper
+from .mixed_precision_base import MixedPrecision
+
+__all__ = ['FP16_Torch_MixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
+
+
+class TorchAMPOptimizer(OptimizerWrapper):
+ """
+ Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.
+
+ Args:
+ optim (Optimizer): Optimizer to wrap.
+ init_scale (float): Initial scale factor. Default: 2**16.
+ growth_factor (float): Factor by which the scale is multiplied during
+ :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
+ this iteration. Default: 2.0.
+ backoff_factor (float): Factor by which the scale is multiplied during
+ :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
+ this iteration. Default: 0.5.
+ growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
+ calls that may cause the scale to increase. Default: 2000.
+ """
+
+ def __init__(self,
+ optim: Optimizer,
+ init_scale: float = 2.**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000) -> None:
+ super().__init__(optim)
+ self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval)
+
+ def backward(self, loss: Tensor, *args, **kwargs) -> None:
+ scaled_loss = self.scale_loss(loss)
+ scaled_loss.backward(*args, **kwargs)
+
+ def step(self, *args, **kwargs) -> Optional[float]:
+ return self.scaler.step(self.optim, *args, **kwargs)
+
+ def scale_loss(self, loss: Tensor) -> Tensor:
+ return self.scaler.scale(loss)
+
+ def unscale_grad(self) -> None:
+ self.scaler.unscale_(self.optim)
+
+ def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
+ self.unscale_grad()
+ super().clip_grad_by_value(clip_value, *args, **kwargs)
+
+ def clip_grad_by_norm(self,
+ max_norm: Union[float, int],
+ norm_type: Union[float, int] = 2.0,
+ error_if_nonfinite: bool = False,
+ *args,
+ **kwargs) -> None:
+ self.unscale_grad()
+ super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
+
+
+class TorchAMPModule(nn.Module):
+ """
+ Module wrapper for mixed precision training in FP16 using PyTorch AMP.
+
+ Args:
+ module (nn.Module): Module to wrap.
+ """
+
+ def __init__(self, module: nn.Module):
+ super().__init__()
+ self.module = module
+
+ def forward(self, *args, **kwargs):
+ with torch.cuda.amp.autocast():
+ return self.module(*args, **kwargs)
+
+
+class FP16TorchMixedPrecision(MixedPrecision):
+ """
+ Precision for mixed precision training in FP16 using PyTorch AMP.
+
+ Args:
+ init_scale (float): Initial scale factor. Default: 2**16.
+ growth_factor (float): Factor by which the scale is multiplied during
+ :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
+ this iteration. Default: 2.0.
+ backoff_factor (float): Factor by which the scale is multiplied during
+ :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
+ this iteration. Default: 0.5.
+ growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
+ calls that may cause the scale to increase. Default: 2000.
+ """
+
+ def __init__(self,
+ init_scale: float = 2.**16,
+ growth_factor: float = 2.0,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 2000) -> None:
+ super().__init__()
+ self.torch_amp_kwargs = dict(init_scale=init_scale,
+ growth_factor=growth_factor,
+ backoff_factor=backoff_factor,
+ growth_interval=growth_interval)
+
+ def configure(self,
+ model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ model = TorchAMPModule(model)
+ optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
+ if criterion is not None:
+ criterion = TorchAMPModule(criterion)
+ return model, optimizer, criterion
diff --git a/colossalai/booster/mixed_precision/fp8.py b/colossalai/booster/mixed_precision/fp8.py
new file mode 100644
index 000000000000..28847345d91d
--- /dev/null
+++ b/colossalai/booster/mixed_precision/fp8.py
@@ -0,0 +1,5 @@
+from .mixed_precision_base import MixedPrecision
+
+
+class FP8MixedPrecision(MixedPrecision):
+ pass
diff --git a/colossalai/booster/mixed_precision/mixed_precision_base.py b/colossalai/booster/mixed_precision/mixed_precision_base.py
new file mode 100644
index 000000000000..d1e8acc82cc6
--- /dev/null
+++ b/colossalai/booster/mixed_precision/mixed_precision_base.py
@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+from typing import Callable, Tuple
+
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from ..interface import OptimizerWrapper
+
+
+class MixedPrecision(ABC):
+ """
+ An abstract class for mixed precision training.
+ """
+
+ @abstractmethod
+ def configure(self,
+ model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Callable = None) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
+ # TODO: implement this method
+ pass
diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py
new file mode 100644
index 000000000000..3328fe2b9627
--- /dev/null
+++ b/colossalai/booster/plugin/__init__.py
@@ -0,0 +1,4 @@
+from .plugin_base import Plugin
+from .torch_ddp_plugin import TorchDDPPlugin
+
+__all__ = ['Plugin', 'TorchDDPPlugin']
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
new file mode 100644
index 000000000000..3c347cb4252d
--- /dev/null
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -0,0 +1,51 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Tuple, Union
+
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+
+from colossalai.booster.interface import OptimizerWrapper
+
+__all__ = ['Plugin']
+
+
+class Plugin(ABC):
+
+ @property
+ @abstractmethod
+ def supported_devices(self) -> List[str]:
+ pass
+
+ @property
+ @abstractmethod
+ def supported_precisions(self) -> List[str]:
+ pass
+
+ @property
+ @abstractmethod
+ def control_precision(self) -> bool:
+ pass
+
+ @property
+ @abstractmethod
+ def control_device(self) -> bool:
+ pass
+
+ @property
+ @abstractmethod
+ def support_no_sync(self) -> bool:
+ pass
+
+ @abstractmethod
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Callable = None,
+ dataloader: DataLoader = None,
+ lr_scheduler: LRScheduler = None,
+ ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+ # implement this method
+ pass
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
new file mode 100644
index 000000000000..07d6be8c748d
--- /dev/null
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -0,0 +1,147 @@
+import random
+from typing import Callable, List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from colossalai.booster.interface import OptimizerWrapper
+
+from .plugin_base import Plugin
+
+__all__ = ['TorchDDPPlugin']
+
+
+class TorchDDPPlugin(Plugin):
+ """
+ Plugin for PyTorch DDP.
+
+ Example:
+ >>> from colossalai.booster import Booster
+ >>> from colossalai.booster.plugin import TorchDDPPlugin
+ >>>
+ >>> model, train_dataset, optimizer, criterion = ...
+ >>> plugin = TorchDDPPlugin()
+
+ >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
+ >>> booster = Booster(plugin=plugin)
+ >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
+
+ Args:
+ broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
+ bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25.
+ find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False.
+ check_reduction (bool, optional): Whether to check reduction. Defaults to False.
+ gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
+ static_graph (bool, optional): Whether to use static graph. Defaults to False.
+ """
+
+ def __init__(self,
+ broadcast_buffers: bool = True,
+ bucket_cap_mb: int = 25,
+ find_unused_parameters: bool = False,
+ check_reduction: bool = False,
+ gradient_as_bucket_view: bool = False,
+ static_graph: bool = False) -> None:
+
+ assert dist.is_initialized(
+ ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
+ bucket_cap_mb=bucket_cap_mb,
+ find_unused_parameters=find_unused_parameters,
+ check_reduction=check_reduction,
+ gradient_as_bucket_view=gradient_as_bucket_view,
+ static_graph=static_graph)
+
+ def support_no_sync(self) -> bool:
+ return True
+
+ def control_precision(self) -> bool:
+ return False
+
+ def supported_precisions(self) -> List[str]:
+ return ['fp16', 'fp16_apex', 'bf16', 'fp8']
+
+ def control_device(self) -> bool:
+ return True
+
+ def supported_devices(self) -> List[str]:
+ return ['cuda']
+
+ def prepare_train_dataloader(self,
+ dataset,
+ batch_size,
+ shuffle=False,
+ seed=1024,
+ drop_last=False,
+ pin_memory=False,
+ num_workers=0,
+ **kwargs):
+ r"""
+ Prepare a dataloader for distributed training. The dataloader will be wrapped by
+ `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+ Note:
+ 1. Evaluation datasets should not be passed to this function.
+
+ Args:
+ dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+ seed (int, optional): Random worker seed for sampling, defaults to 1024.
+ add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+ drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+ is not divisible by the batch size. If False and the size of dataset is not divisible by
+ the batch size, then the last batch will be smaller, defaults to False.
+ pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+ num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+ kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+ `DataLoader `_.
+
+ Returns:
+ :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+ """
+ _kwargs = kwargs.copy()
+ sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
+
+ # Deterministic dataloader
+ def seed_worker(worker_id):
+ worker_seed = seed
+ np.random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+ random.seed(worker_seed)
+
+ return DataLoader(dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ worker_init_fn=seed_worker,
+ drop_last=drop_last,
+ pin_memory=pin_memory,
+ num_workers=num_workers,
+ **_kwargs)
+
+ def configure(
+ self,
+ model: nn.Module,
+ optimizer: Optimizer,
+ criterion: Callable = None,
+ dataloader: DataLoader = None,
+ lr_scheduler: LRScheduler = None,
+ ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
+ # cast model to cuda
+ model = model.cuda()
+
+ # wrap the model with PyTorch DDP
+ model = DDP(model, **self.ddp_kwargs)
+
+ if not isinstance(optimizer, OptimizerWrapper):
+ optimizer = OptimizerWrapper(optimizer)
+
+ return model, optimizer, criterion, dataloader, lr_scheduler
diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
new file mode 100644
index 000000000000..3cec630b2f86
--- /dev/null
+++ b/colossalai/checkpoint_io/__init__.py
@@ -0,0 +1,4 @@
+from .checkpoint_io_base import CheckpointIO, ShardCheckpointIndexFile
+from .general_checkpoint_io import GeneralCheckpointIO
+
+__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile', 'GeneralCheckpointIO']
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
new file mode 100644
index 000000000000..00a65424bece
--- /dev/null
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -0,0 +1,374 @@
+import json
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any
+
+import torch
+import torch.nn as nn
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+__all__ = ['CheckpointIO', 'ShardCheckpointIndexFile']
+
+
+class CheckpointIO(ABC):
+ """
+ CheckpointIO is the base class for all checkpoint IO classes. It defines the interface for checkpoint IO.
+
+
+ Examples:
+ >>> from colossalai.checkpoint_io import GeneralCheckpointIO
+ >>> checkpoint_io = CheckpointIO()
+ >>>
+ >>> # load model from checkpoint
+ >>> model = checkpoint_io.load_model(model, 'model.pt')
+ >>>
+ >>> # save model to checkpoint
+ >>> checkpoint_io.save_model(model, 'model.pt')
+ >>>
+ >>> # save model to sharded checkpoints
+ >>> checkpoint_io.save_model(model, './checkpoints/', shard=True)
+ >>>
+ >>> # load model from sharded checkpoints
+ >>> model = checkpoint_io.load_model(model, './checkpoints/')
+ >>>
+ >>> # load optimizer from checkpoint
+ >>> optimizer = checkpoint_io.load_optimizer(optimizer, 'optimizer.pt')
+ >>>
+ >>> # save optimizer to checkpoint
+ >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
+
+ """
+
+ # ======================================
+ # Abstract methods for implementation
+ # ======================================
+
+ @abstractmethod
+ def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ """
+ Load model from checkpoint.
+
+ Args:
+ model (nn.Module): model to be loaded.
+ checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the
+ mainstream model zoos such as Hugging Face and TIMM. The checkpoint path can be:
+ 1. a file path, e.g. 'model.pt'
+ 2. a path to a json file which defines the index to the sharded checkpoint
+ 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+ strict (bool): whether to strictly enforce that the param name in
+ the checkpoint match the keys returned by this module's.
+ """
+ pass
+
+ @abstractmethod
+ def save_model(self,
+ model: nn.Module,
+ checkpoint: str,
+ prefix: str = None,
+ shard: bool = False,
+ size_per_shard: int = 1024):
+ """
+ Save model to checkpoint.
+
+ Examples:
+ >>> from colossalai.checkpoint_io import GeneralCheckpointIO
+ >>> checkpoint_io = CheckpointIO()
+ >>>
+ >>> # save model to a single file
+ >>> save_model(model, 'model.pt')
+ >>>
+ >>> # save model to a sharded checkpoint
+ >>> save_model(model, './checkpoints/', shard=True)
+
+ Args:
+ model (nn.Module): model to be saved.
+ checkpoint: checkpoint path. The checkpoint path can be :
+ 1. a file path, e.g. 'model.pt'
+ 2. a directory path to save the sharded checkpoint, e.g. './checkpoints/' when shard = True.
+ shard: whether to shard the checkpoint. Default: False. If set to True, the checkpoint will be sharded into
+ multiple files. The model shards will be specificed by a `model.index.json` file. When shard = True, please ensure
+ that the checkpoint path is a directory path instead of a file path.
+ size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
+ """
+ pass
+
+ @abstractmethod
+ def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ """
+ Load optimizer from checkpoint.
+
+ Args:
+ optimizer (Optimizer): optimizer to be loaded.
+ checkpoint (str): checkpoint path. This value is made compatiblity with the model checkpoints in the
+ """
+ pass
+
+ @abstractmethod
+ def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+ """
+ Save optimizer to checkpoint.
+
+ Args:
+ optimizer (Optimizer): optimizer to be saved.
+ checkpoint: checkpoint path. The checkpoint path can be :
+ 1. a file path, e.g. 'model.pt'
+ 2. a path to a json file which defines the index to the sharded checkpoint for the optimizer
+ 3. a path to a folder containing a unique .index.json file for sharded checkpoint
+ """
+ pass
+
+ # ============================================
+ # methods for loading and saving lr scheduler
+ # as this is quite standard, there is no need
+ # to make them abstract
+ # ============================================
+
+ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Save lr scheduler to checkpoint.
+
+ Args:
+ lr_scheduler (LRScheduler): lr scheduler to be saved.
+ checkpoint: checkpoint path. The checkpoint path can only be a file path.
+ """
+ torch.save(lr_scheduler.state_dict(), checkpoint)
+
+ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+ """
+ Load lr scheduler from checkpoint.
+
+ Args:
+ lr_scheduler (LRScheduler): lr scheduler to be loaded.
+ checkpoint (str): the path for a single checkpoint file.
+ """
+ state_dict = torch.load(checkpoint)
+ lr_scheduler.load_state_dict(state_dict)
+
+ # ========================================
+ # Helper functions for loading state dict
+ # ========================================
+
+ def get_sharded_checkpoint_index_file(self, checkpoint_path: Path):
+ """
+ Get the index file path for a sharded checkpoint.
+
+ Args:
+ checkpoint_path (Path): path to the checkpoint.
+
+ Returns:
+ Path: path to the index file.
+ """
+ if checkpoint_path.is_file():
+ # check if it is .index.json
+ if checkpoint_path.name.endswith('.index.json'):
+ return checkpoint_path
+ else:
+ raise ValueError(f'Invalid checkpoint path: {checkpoint_path}. ')
+ elif checkpoint_path.is_dir():
+ # check if there is only one a file ending with .index.json in this directory
+ index_files = list(checkpoint_path.glob('*.index.json'))
+ if len(index_files) == 1:
+ return index_files[0]
+ else:
+ raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
+
+ def is_sharded_checkpoint(self, checkpoint_path: Path):
+ """
+ Check whether the checkpoint is sharded.
+
+ Args:
+ checkpoint (str): checkpoint path.
+
+ Returns:
+ bool: whether the checkpoint is sharded.
+ """
+ if checkpoint_path.is_file():
+ # check if it is .index.json
+ if checkpoint_path.name.endswith('.index.json'):
+ return True
+ else:
+ return False
+ elif checkpoint_path.is_dir():
+ # check if there is only one a file ending with .index.json in this directory
+ index_files = list(checkpoint_path.glob('*.index.json'))
+ if len(index_files) == 1:
+ return True
+ else:
+ raise ValueError(f'Found {len(index_files)} index files in {checkpoint_path}. ')
+
+ def get_checkpoint_shard_filenames(self, index_file_path: Path):
+ """
+ Get checkpoint shard filenames from a json file.
+
+ Args:
+ index_file_path (Path): path to the json file.
+
+ Returns:
+ list: checkpoint shard filenames.
+ """
+ with open(str(index_file_path), 'r') as f:
+ shard_filenames = json.load(f)
+
+ if "weight_map" in index:
+ index = index["weight_map"]
+
+ checkpoint_root_path = index_file_path.absolute().parent
+
+ # read the checkpoint file list from the json file and get a list of unique file names
+ checkpoint_files = sorted(list(set(index.values())))
+
+ # get the absolute paths for all checkpoint files
+ checkpoint_files = [checkpoint_root_path.joinpath(f) for f in checkpoint_files]
+ return shard_filenames
+
+ def load_safetensors_state_dict(self, *args, **kwargs):
+ """
+ Load safetensors state dict from checkpoint.
+ """
+ # TODO(FrankLeeeee): support huggingface safetensors
+ raise NotImplementedError("This method is not implemented to support safe tensors")
+
+ def load_state_dict(self, checkpoint_file_path: Path):
+ """
+ Load state dict from checkpoint.
+
+ Args:
+ checkpoint_file_path (Path): path to the checkpoint file.
+
+ Returns:
+ dict: state dict.
+ """
+ return torch.load(str(checkpoint_file_path))
+
+ # ======================================
+ # Helper functions for saving state dict
+ # ======================================
+
+ def save_safetensors_state_dict(self, *args, **kwargs):
+ """
+ Save safetensors state dict to checkpoint.
+ """
+ # TODO(FrankLeeeee): support huggingface safetensors
+ raise NotImplementedError("This method is not implemented to support safe tensors")
+
+ def generate_checkpoint_shard_file_name(self, index: int, total_number: int, prefix: str = None):
+ """
+ Generate checkpoint shard file name.
+
+ Args:
+ index (int): index of the shard.
+ total_number (int): total number of shards.
+ prefix (str): prefix of the shard file name. Default: None.
+ """
+ if prefix is None:
+ return f"{index}-of-{total_number}.bin"
+ else:
+ return f"{prefix}-{index}-of-{total_number}.bin"
+
+ def save_checkpoint(self, state_dict: dict, checkpoint_file_path: Path):
+ """
+ Save state dict to checkpoint.
+
+ Args:
+ state_dict (dict): state dict.
+ checkpoint_file_path (Path): path to the checkpoint file.
+ """
+ torch.save(state_dict, str(checkpoint_file_path))
+
+ def save_state_dict_as_shard(self, state_dict: dict, index: int, total_number: int, prefix: str,
+ checkpoint_path: Path):
+ """
+ Save state dict as shard.
+
+ Args:
+ state_dict (dict): state dict.
+ checkpoint_path (Path): path to the checkpoint file.
+ """
+ # generate the shard name
+ shard_file_name = self.generate_checkpoint_shard_file_name(index, total_number, prefix)
+ shard_file_path = checkpoint_path.joinpath(shard_file_name)
+
+ # save the shard
+ self.save_checkpoint(state_dict, shard_file_path)
+
+ def calculate_param_size(self, param: torch.Tensor):
+ """
+ Calculate the size of a parameter in MB. Used to compute whether a group of params exceed the shard size.
+ If so, a new shard should be created.
+
+ ArgsL
+ param (torch.Tensor): parameter tensor.
+ """
+ # TODO(FrankLeeeee): check if this tensor is a DTensor, compute its global size if so
+ return param.numel() * param.element_size() / 1024 / 1024
+
+
+class ShardCheckpointIndexFile:
+ """
+ This class is a data structure to keep the content in the index.json file for sharded checkpoint.
+
+ Example:
+ >>> index = ShardCheckpointIndexFile()
+ >>> index.load('index.json')
+ >>> index.append_metadata('model_type', 'bert')
+ >>> index.append_weight_map('bert.embeddings.word_embeddings.weight', 'bert.embeddings.word_embeddings.weight-0-of-2.bin')
+ >>> index.export('index.json')
+ """
+
+ def __init__(self) -> None:
+ self.metadata: dict = dict()
+ self.weight_map: dict = dict()
+
+ def load(self, json_path: str):
+ """
+ Load the index file from a json file.
+
+ Args:
+ json_path (str): path to the json file.
+ """
+ # load the json file
+ with open(json_path, 'r') as f:
+ index = json.load(f)
+
+ # assign attributes if exists
+ if "metadata" in index:
+ self.metadata = index["metadata"]
+ if "weight_map" in index:
+ self.weight_map = index["weight_map"]
+
+ def export(self, json_path: str):
+ """
+ Export the index file to a json file.
+
+ Args:
+ json_path (str): path to the json file.
+ """
+ # create the index file
+ index = dict()
+ index["metadata"] = self.metadata
+ index["weight_map"] = self.weight_map
+
+ # export the index file
+ with open(json_path, 'w') as f:
+ json.dump(index, f, indent=4)
+
+ def append_weight_map(self, param_name: str, shard_file: str):
+ """
+ Append a weight map entry to the index file.
+
+ Args:
+ param_name (str): name of the parameter.
+ shard_file (str): name of the shard file.
+ """
+ self.weight_map[param_name] = shard_file
+
+ def append_meta_data(self, name: str, val: Any):
+ """
+ Append a metadata entry to the index file.
+
+ Args:
+ name (str): name of the metadata.
+ val (Any): value of the metadata.
+ """
+ self.metadata[name] = val
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
new file mode 100644
index 000000000000..0a3636655530
--- /dev/null
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+
+import torch.nn as nn
+from torch.optim import Optimizer
+
+from .checkpoint_io_base import CheckpointIO
+
+__all__ = ['GeneralCheckpointIO']
+
+
+class GeneralCheckpointIO(CheckpointIO):
+
+ def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+ checkpoint = Path(checkpoint)
+ is_sharded = self.is_sharded_checkpoint(checkpoint)
+
+ if not is_sharded:
+ checkpoint = self.load_state_dict(checkpoint)
+ model.load_state_dict(checkpoint, strict=strict)
+ else:
+ # find the index file
+ checkpoint_path = Path(checkpoint)
+ index_file_path = self.get_sharded_checkpoint_index_file(checkpoint_path)
+
+ # iterate over the shard checkpoint files
+ # and load each
+ shard_files = self.get_checkpoint_shard_filenames(index_file_path)
+ for shard_file in shard_files:
+ shard_checkpoint = self.load_state_dict(shard_file)
+ model.load_state_dict(shard_checkpoint, strict=strict)
+
+ return model
+
+ def save_model(self,
+ model: nn.Module,
+ checkpoint: str,
+ prefix: str = None,
+ shard: bool = False,
+ size_per_shard: int = 1024):
+ checkpoint = Path(checkpoint)
+ if shard:
+ # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
+ raise NotImplementedError("Not implemented yet")
+ else:
+ self.save_checkpoint(model.state_dict(), checkpoint)
+
+ def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+ checkpoint = Path(checkpoint)
+ is_sharded = self.is_sharded_checkpoint(checkpoint)
+
+ if not is_sharded:
+ checkpoint = self.load_state_dict(checkpoint)
+ optimizer.load_state_dict(checkpoint)
+ else:
+ # TODO(FrankLeeeee): implement checkpoint loading from sharded checkpoint
+ # This is not an urgent feature, so we can leave it for later
+ # let's implement this when we test large-scale models
+ pass
+ return optimizer
+
+ def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
+ if shard:
+ # TODO(FrankLeeeee): implement checkpoint saving to sharded checkpoint
+ pass
+ else:
+ self.save_checkpoint(optimizer.state_dict(), checkpoint)
diff --git a/colossalai/cli/check/check_installation.py b/colossalai/cli/check/check_installation.py
index a12b24402794..44d7840700ef 100644
--- a/colossalai/cli/check/check_installation.py
+++ b/colossalai/cli/check/check_installation.py
@@ -7,30 +7,103 @@
import colossalai
+def to_click_output(val):
+ # installation check output to understandable symbols for readability
+ VAL_TO_SYMBOL = {True: u'\u2713', False: 'x', None: 'N/A'}
+
+ if val in VAL_TO_SYMBOL:
+ return VAL_TO_SYMBOL[val]
+ else:
+ return val
+
+
def check_installation():
- cuda_ext_installed = _check_cuda_extension_installed()
- cuda_version, torch_version, torch_cuda_version = _check_cuda_torch()
- colossalai_verison, torch_version_required, cuda_version_required = _parse_colossalai_version()
-
- cuda_compatibility = _get_compatibility_string([cuda_version, torch_cuda_version, cuda_version_required])
- torch_compatibility = _get_compatibility_string([torch_version, torch_version_required])
-
- click.echo(f'#### Installation Report ####\n')
- click.echo(f"Colossal-AI version: {colossalai_verison}")
- click.echo(f'----------------------------')
- click.echo(f"PyTorch Version: {torch_version}")
- click.echo(f"PyTorch Version required by Colossal-AI: {torch_version_required}")
- click.echo(f'PyTorch version match: {torch_compatibility}')
- click.echo(f'----------------------------')
- click.echo(f"System CUDA Version: {cuda_version}")
- click.echo(f"CUDA Version required by PyTorch: {torch_cuda_version}")
- click.echo(f"CUDA Version required by Colossal-AI: {cuda_version_required}")
- click.echo(f"CUDA Version Match: {cuda_compatibility}")
- click.echo(f'----------------------------')
- click.echo(f"CUDA Extension: {cuda_ext_installed}")
-
-
-def _get_compatibility_string(versions):
+ """
+ This function will check the installation of colossalai, specifically, the version compatibility of
+ colossalai, pytorch and cuda.
+
+ Example:
+ ```text
+ ```
+
+ Returns: A table of installation information.
+ """
+ found_aot_cuda_ext = _check_aot_built_cuda_extension_installed()
+ cuda_version = _check_cuda_version()
+ torch_version, torch_cuda_version = _check_torch_version()
+ colossalai_verison, prebuilt_torch_version_required, prebuilt_cuda_version_required = _parse_colossalai_version()
+
+ # if cuda_version is None, that means either
+ # CUDA_HOME is not found, thus cannot compare the version compatibility
+ if not cuda_version:
+ sys_torch_cuda_compatibility = None
+ else:
+ sys_torch_cuda_compatibility = _is_compatible([cuda_version, torch_cuda_version])
+
+ # if cuda_version or cuda_version_required is None, that means either
+ # CUDA_HOME is not found or AOT compilation is not enabled
+ # thus, there is no need to compare the version compatibility at all
+ if not cuda_version or not prebuilt_cuda_version_required:
+ sys_colossalai_cuda_compatibility = None
+ else:
+ sys_colossalai_cuda_compatibility = _is_compatible([cuda_version, prebuilt_cuda_version_required])
+
+ # if torch_version_required is None, that means AOT compilation is not enabled
+ # thus there is no need to compare the versions
+ if prebuilt_torch_version_required is None:
+ torch_compatibility = None
+ else:
+ torch_compatibility = _is_compatible([torch_version, prebuilt_torch_version_required])
+
+ click.echo(f'#### Installation Report ####')
+ click.echo(f'\n------------ Environment ------------')
+ click.echo(f"Colossal-AI version: {to_click_output(colossalai_verison)}")
+ click.echo(f"PyTorch version: {to_click_output(torch_version)}")
+ click.echo(f"System CUDA version: {to_click_output(cuda_version)}")
+ click.echo(f"CUDA version required by PyTorch: {to_click_output(torch_cuda_version)}")
+ click.echo("")
+ click.echo(f"Note:")
+ click.echo(f"1. The table above checks the versions of the libraries/tools in the current environment")
+ click.echo(f"2. If the System CUDA version is N/A, you can set the CUDA_HOME environment variable to locate it")
+ click.echo(
+ f"3. If the CUDA version required by PyTorch is N/A, you probably did not install a CUDA-compatible PyTorch. This value is give by torch.version.cuda and you can go to https://pytorch.org/get-started/locally/ to download the correct version."
+ )
+
+ click.echo(f'\n------------ CUDA Extensions AOT Compilation ------------')
+ click.echo(f"Found AOT CUDA Extension: {to_click_output(found_aot_cuda_ext)}")
+ click.echo(f"PyTorch version used for AOT compilation: {to_click_output(prebuilt_torch_version_required)}")
+ click.echo(f"CUDA version used for AOT compilation: {to_click_output(prebuilt_cuda_version_required)}")
+ click.echo("")
+ click.echo(f"Note:")
+ click.echo(
+ f"1. AOT (ahead-of-time) compilation of the CUDA kernels occurs during installation when the environment varialbe CUDA_EXT=1 is set"
+ )
+ click.echo(f"2. If AOT compilation is not enabled, stay calm as the CUDA kernels can still be built during runtime")
+
+ click.echo(f"\n------------ Compatibility ------------")
+ click.echo(f'PyTorch version match: {to_click_output(torch_compatibility)}')
+ click.echo(f"System and PyTorch CUDA version match: {to_click_output(sys_torch_cuda_compatibility)}")
+ click.echo(f"System and Colossal-AI CUDA version match: {to_click_output(sys_colossalai_cuda_compatibility)}")
+ click.echo(f"")
+ click.echo(f"Note:")
+ click.echo(f"1. The table above checks the version compatibility of the libraries/tools in the current environment")
+ click.echo(
+ f" - PyTorch version mistach: whether the PyTorch version in the current environment is compatible with the PyTorch version used for AOT compilation"
+ )
+ click.echo(
+ f" - System and PyTorch CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version required by PyTorch"
+ )
+ click.echo(
+ f" - System and Colossal-AI CUDA version match: whether the CUDA version in the current environment is compatible with the CUDA version used for AOT compilation"
+ )
+
+
+def _is_compatible(versions):
+ """
+ Compare the list of versions and return whether they are compatible.
+ """
+ if None in versions:
+ return False
# split version into [major, minor, patch]
versions = [version.split('.') for version in versions]
@@ -44,52 +117,98 @@ def _get_compatibility_string(versions):
equal = len(set(version_values)) == 1
if idx in [0, 1] and not equal:
- # if the major/minor versions do not match
- # return a cross
- return 'x'
+ return False
elif idx == 1:
- # if the minor versions match
- # return a tick
- return u'\u2713'
+ return True
else:
continue
def _parse_colossalai_version():
+ """
+ Get the Colossal-AI version information.
+
+ Returns:
+ colossalai_version: Colossal-AI version.
+ torch_version_for_aot_build: PyTorch version used for AOT compilation of CUDA kernels.
+ cuda_version_for_aot_build: CUDA version used for AOT compilation of CUDA kernels.
+ """
+ # colossalai version can be in two formats
+ # 1. X.X.X+torchX.XXcuXX.X (when colossalai is installed with CUDA extensions)
+ # 2. X.X.X (when colossalai is not installed with CUDA extensions)
+ # where X represents an integer.
colossalai_verison = colossalai.__version__.split('+')[0]
- torch_version_required = colossalai.__version__.split('torch')[1].split('cu')[0]
- cuda_version_required = colossalai.__version__.split('cu')[1]
- return colossalai_verison, torch_version_required, cuda_version_required
-
-def _check_cuda_extension_installed():
+ try:
+ torch_version_for_aot_build = colossalai.__version__.split('torch')[1].split('cu')[0]
+ cuda_version_for_aot_build = colossalai.__version__.split('cu')[1]
+ except:
+ torch_version_for_aot_build = None
+ cuda_version_for_aot_build = None
+ return colossalai_verison, torch_version_for_aot_build, cuda_version_for_aot_build
+
+
+def _check_aot_built_cuda_extension_installed():
+ """
+ According to `op_builder/README.md`, the CUDA extension can be built with either
+ AOT (ahead-of-time) or JIT (just-in-time) compilation.
+ AOT compilation will build CUDA extensions to `colossalai._C` during installation.
+ JIT (just-in-time) compilation will build CUDA extensions to `~/.cache/colossalai/torch_extensions` during runtime.
+ """
try:
import colossalai._C.fused_optim
- is_cuda_extension_installed = u'\u2713'
+ found_aot_cuda_ext = True
except ImportError:
- is_cuda_extension_installed = 'x'
- return is_cuda_extension_installed
+ found_aot_cuda_ext = False
+ return found_aot_cuda_ext
-def _check_cuda_torch():
- # get cuda version
- if CUDA_HOME is None:
- cuda_version = 'N/A (CUDA_HOME is not set)'
- else:
- raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
- output = raw_output.split()
- release_idx = output.index("release") + 1
- release = output[release_idx].split(".")
- bare_metal_major = release[0]
- bare_metal_minor = release[1][0]
- cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
+def _check_torch_version():
+ """
+ Get the PyTorch version information.
+ Returns:
+ torch_version: PyTorch version.
+ torch_cuda_version: CUDA version required by PyTorch.
+ """
# get torch version
+ # torch version can be of two formats
+ # - 1.13.1+cu113
+ # - 1.13.1.devxxx
torch_version = torch.__version__.split('+')[0]
+ torch_version = '.'.join(torch_version.split('.')[:3])
# get cuda version in pytorch build
- torch_cuda_major = torch.version.cuda.split(".")[0]
- torch_cuda_minor = torch.version.cuda.split(".")[1]
- torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
+ try:
+ torch_cuda_major = torch.version.cuda.split(".")[0]
+ torch_cuda_minor = torch.version.cuda.split(".")[1]
+ torch_cuda_version = f'{torch_cuda_major}.{torch_cuda_minor}'
+ except:
+ torch_cuda_version = None
+
+ return torch_version, torch_cuda_version
+
- return cuda_version, torch_version, torch_cuda_version
+def _check_cuda_version():
+ """
+ Get the CUDA version information.
+
+ Returns:
+ cuda_version: CUDA version found on the system.
+ """
+
+ # get cuda version
+ if CUDA_HOME is None:
+ cuda_version = CUDA_HOME
+ else:
+ try:
+ raw_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], universal_newlines=True)
+ output = raw_output.split()
+ release_idx = output.index("release") + 1
+ release = output[release_idx].split(".")
+ bare_metal_major = release[0]
+ bare_metal_minor = release[1][0]
+ cuda_version = f'{bare_metal_major}.{bare_metal_minor}'
+ except:
+ cuda_version = None
+ return cuda_version
diff --git a/colossalai/cli/cli.py b/colossalai/cli/cli.py
index 3e5b9ae6343f..a94e1150e49f 100644
--- a/colossalai/cli/cli.py
+++ b/colossalai/cli/cli.py
@@ -1,7 +1,8 @@
import click
-from .launcher import run
-from .check import check
+
from .benchmark import benchmark
+from .check import check
+from .launcher import run
class Arguments():
diff --git a/colossalai/cli/launcher/__init__.py b/colossalai/cli/launcher/__init__.py
index 4ada68b4b68f..8d9ec147d401 100644
--- a/colossalai/cli/launcher/__init__.py
+++ b/colossalai/cli/launcher/__init__.py
@@ -1,7 +1,9 @@
import click
-from .run import launch_multi_processes
+
from colossalai.context import Config
+from .run import launch_multi_processes
+
@click.command(help="Launch distributed training on a single node or multiple nodes",
context_settings=dict(ignore_unknown_options=True))
diff --git a/colossalai/cli/launcher/hostinfo.py b/colossalai/cli/launcher/hostinfo.py
index 2f0830c5880d..065cbc37101f 100644
--- a/colossalai/cli/launcher/hostinfo.py
+++ b/colossalai/cli/launcher/hostinfo.py
@@ -1,5 +1,5 @@
-from typing import List
import socket
+from typing import List
class HostInfo:
@@ -35,9 +35,14 @@ def is_host_localhost(hostname: str, port: str = None) -> None:
if port is None:
port = 22 # no port specified, lets just use the ssh port
- hostname = socket.getfqdn(hostname)
+
+ # socket.getfqdn("127.0.0.1") does not return localhost
+ # on some users' machines
+ # thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0
if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
return True
+
+ hostname = socket.getfqdn(hostname)
localhost = socket.gethostname()
localaddrs = socket.getaddrinfo(localhost, port)
targetaddrs = socket.getaddrinfo(hostname, port)
diff --git a/colossalai/cli/launcher/multinode_runner.py b/colossalai/cli/launcher/multinode_runner.py
index c45ad5e5a082..a51e1e371f13 100644
--- a/colossalai/cli/launcher/multinode_runner.py
+++ b/colossalai/cli/launcher/multinode_runner.py
@@ -1,8 +1,10 @@
-import fabric
-from .hostinfo import HostInfo, HostInfoList
from multiprocessing import Pipe, Process
from multiprocessing import connection as mp_connection
+
import click
+import fabric
+
+from .hostinfo import HostInfo, HostInfoList
def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Connection,
@@ -45,8 +47,10 @@ def run_on_host(hostinfo: HostInfo, workdir: str, recv_conn: mp_connection.Conne
# execute on the remote machine
fab_conn.run(cmds, hide=False)
send_conn.send('success')
- except:
- click.echo(f"Error: failed to run {cmds} on {hostinfo.hostname}")
+ except Exception as e:
+ click.echo(
+ f"Error: failed to run {cmds} on {hostinfo.hostname}, is localhost: {hostinfo.is_local_host}, exception: {e}"
+ )
send_conn.send('failure')
# shutdown
diff --git a/colossalai/cli/launcher/run.py b/colossalai/cli/launcher/run.py
index e078a57c15c9..6411b4302e95 100644
--- a/colossalai/cli/launcher/run.py
+++ b/colossalai/cli/launcher/run.py
@@ -1,12 +1,15 @@
-import click
-import sys
import os
+import sys
+from typing import List
+
+import click
import torch
+from packaging import version
+
from colossalai.context import Config
-from .multinode_runner import MultiNodeRunner
+
from .hostinfo import HostInfo, HostInfoList
-from typing import List
-from packaging import version
+from .multinode_runner import MultiNodeRunner
# Constants that define our syntax
NODE_SEP = ','
@@ -15,7 +18,7 @@
def fetch_hostfile(hostfile_path: str, ssh_port: int) -> HostInfoList:
"""
Parse the hostfile to obtain a list of hosts.
-
+
A hostfile should look like:
worker-0
worker-1
@@ -63,7 +66,7 @@ def parse_device_filter(device_pool: HostInfoList, include_str=None, exclude_str
device_pool (HostInfoList): a list of HostInfo objects
include_str (str): --include option passed by user, default None
exclude_str (str): --exclude option passed by user, default None
-
+
Returns:
filtered_hosts (HostInfoList): filtered hosts after inclusion/exclusion
'''
@@ -192,7 +195,7 @@ def launch_multi_processes(args: Config) -> None:
Launch multiple processes on a single node or multiple nodes.
The overall logic can be summarized as the pseudo code below:
-
+
if hostfile given:
hostinfo = parse_hostfile(hostfile)
hostinfo = include_or_exclude_hosts(hostinfo)
@@ -202,7 +205,7 @@ def launch_multi_processes(args: Config) -> None:
launch_on_multi_nodes(hostinfo)
else:
launch_on_current_node()
-
+
Args:
args (Config): the arguments taken from command line
@@ -276,6 +279,33 @@ def launch_multi_processes(args: Config) -> None:
extra_launch_args=args.extra_launch_args)
runner.send(hostinfo=hostinfo, cmd=cmd)
- runner.recv_from_all()
+ # start training
+ msg_from_node = runner.recv_from_all()
+ has_error = False
+
+ # print node status
+ click.echo("\n====== Training on All Nodes =====")
+ for hostname, msg in msg_from_node.items():
+ click.echo(f"{hostname}: {msg}")
+
+ # check if a process failed
+ if msg == "failure":
+ has_error = True
+
+ # stop all nodes
runner.stop_all()
- runner.recv_from_all()
+
+ # receive the stop status
+ msg_from_node = runner.recv_from_all()
+
+ # printe node status
+ click.echo("\n====== Stopping All Nodes =====")
+ for hostname, msg in msg_from_node.items():
+ click.echo(f"{hostname}: {msg}")
+
+ # give the process an exit code
+ # so that it behaves like a normal process
+ if has_error:
+ sys.exit(1)
+ else:
+ sys.exit(0)
diff --git a/colossalai/cluster/__init__.py b/colossalai/cluster/__init__.py
new file mode 100644
index 000000000000..2fbdfd3cc999
--- /dev/null
+++ b/colossalai/cluster/__init__.py
@@ -0,0 +1,5 @@
+from .device_mesh_manager import DeviceMeshManager
+from .dist_coordinator import DistCoordinator
+from .process_group_manager import ProcessGroupManager
+
+__all__ = ['DistCoordinator', 'ProcessGroupManager', 'DeviceMeshManager']
diff --git a/colossalai/cluster/device_mesh_manager.py b/colossalai/cluster/device_mesh_manager.py
new file mode 100644
index 000000000000..744799182e22
--- /dev/null
+++ b/colossalai/cluster/device_mesh_manager.py
@@ -0,0 +1,36 @@
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class DeviceMeshManager:
+ """
+ Device mesh manager is responsible for creating and managing device meshes.
+ """
+
+ def __init__(self):
+ self.device_mesh_store = dict()
+
+ def create_device_mesh(self, name, *args, **kwargs) -> DeviceMesh:
+ """
+ Create a device mesh and store it in the manager.
+
+ Args:
+ name (str): name of the device mesh
+ *args: args for DeviceMesh
+ **kwargs: kwargs for DeviceMesh
+ """
+ # TODO(Yuliang): replace *args, **kwargs with explicit arguments
+ if name not in self.device_mesh_store:
+ device_mesh = DeviceMesh(*args, **kwargs)
+ self.device_mesh_store[name] = device_mesh
+ return device_mesh
+ else:
+ raise ValueError(f'Device mesh {name} already exists.')
+
+ def get(self, name: str) -> DeviceMesh:
+ pass
+
+ def destroy(self):
+ pass
+
+ def destroy_all(self):
+ pass
diff --git a/colossalai/cluster/dist_coordinator.py b/colossalai/cluster/dist_coordinator.py
new file mode 100644
index 000000000000..6b48faf5b720
--- /dev/null
+++ b/colossalai/cluster/dist_coordinator.py
@@ -0,0 +1,158 @@
+import os
+from contextlib import contextmanager
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from colossalai.context.singleton_meta import SingletonMeta
+
+
+class DistCoordinator(metaclass=SingletonMeta):
+ """
+ This class is used to coordinate distributed training. It is a singleton class, which means that there is only one instance of this
+ class in the whole program.
+
+ There are some terms that are used in this class:
+ - rank: the rank of the current process
+ - world size: the total number of processes
+ - local rank: the rank of the current process on the current node
+ - master: the process with rank 0
+ - node master: the process with local rank 0 on the current node
+
+ Example:
+ >>> from colossalai.cluster.dist_coordinator import DistCoordinator
+ >>> coordinator = DistCoordinator()
+ >>>
+ >>> if coordinator.is_master():
+ >>> do_something()
+ >>>
+ >>> coordinator.print_on_master('hello world')
+
+ Attributes:
+ rank (int): the rank of the current process
+ world_size (int): the total number of processes
+ local_rank (int): the rank of the current process on the current node
+ """
+
+ def __init__(self):
+ assert dist.is_initialized(
+ ), 'Distributed is not initialized. Please call `torch.distributed.init_process_group` or `colossalai.launch` first.'
+ self._rank = dist.get_rank()
+ self._world_size = dist.get_world_size()
+ # this is often passed by launchers such as torchrun
+ self._local_rank = os.environ.get('LOCAL_RANK', -1)
+
+ @property
+ def rank(self) -> int:
+ return self._rank
+
+ @property
+ def world_size(self) -> int:
+ return self._world_size
+
+ @property
+ def local_rank(self) -> int:
+ return self._local_rank
+
+ def _assert_local_rank_set(self):
+ """
+ Assert that the local rank is set. This is often passed by launchers such as torchrun.
+ """
+ assert self.local_rank >= 0, 'The environment variable LOCAL_RANK is not set, thus the coordinator is not aware of the local rank of the current process.'
+
+ def is_master(self, process_group: ProcessGroup = None) -> bool:
+ """
+ Check if the current process is the master process (rank is 0). It can accept a sub process group to check the rank 0 with respect to the process.
+
+ Args:
+ process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+
+ Returns:
+ bool: True if the current process is the master process, False otherwise
+ """
+ rank = dist.get_rank(group=process_group)
+ return rank == 0
+
+ def is_node_master(self) -> bool:
+ """
+ Check if the current process is the master process on the current node (local rank is 0).
+
+ Returns:
+ bool: True if the current process is the master process on the current node, False otherwise
+ """
+ self._assert_local_rank_set()
+ return self.local_rank == 0
+
+ def is_last_process(self, process_group: ProcessGroup = None) -> bool:
+ """
+ Check if the current process is the last process (rank is world size - 1). It can accept a sub process group to check the last rank with respect to the process.
+
+ Args:
+ process_group (ProcessGroup, optional): process group to use for the last rank check. Defaults to None, which refers to the default process group.
+
+ Returns:
+ bool: True if the current process is the last process, False otherwise
+ """
+ rank = dist.get_rank(group=process_group)
+ world_size = dist.get_world_size(group=process_group)
+ return rank == world_size - 1
+
+ def print_on_master(self, msg: str, process_group: ProcessGroup = None):
+ """
+ Print message only from rank 0.
+
+ Args:
+ msg (str): message to print
+ process_group (ProcessGroup, optional): process group to use for the rank 0 check. Defaults to None, which refers to the default process group.
+ """
+ rank = dist.get_rank(group=process_group)
+ if rank == 0:
+ print(msg)
+
+ def print_on_node_master(self, msg: str):
+ """
+ Print message only from local rank 0. Local rank 0 refers to the 0th process running the current node.
+
+ Args:
+ msg (str): message to print
+ """
+ self._assert_local_rank_set()
+ if self.local_rank == 0:
+ print(msg)
+
+ @contextmanager
+ def priority_execution(self, executor_rank: int = 0, process_group: ProcessGroup = None):
+ """
+ This context manager is used to allow one process to execute while blocking all
+ other processes in the same process group. This is often useful when downloading is required
+ as we only want to download in one process to prevent file corruption.
+
+ Example:
+ >>> from colossalai.cluster import DistCoordinator
+ >>> dist_coordinator = DistCoordinator()
+ >>> with dist_coordinator.priority_execution():
+ >>> dataset = CIFAR10(root='./data', download=True)
+
+ Args:
+ executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
+ process_group (ProcessGroup, optional): process group to use for the executor rank check. Defaults to None, which refers to the default process group.
+ """
+ rank = dist.get_rank(group=process_group)
+ should_block = rank != executor_rank
+
+ if should_block:
+ dist.barrier(group=process_group)
+
+ yield
+
+ if not should_block:
+ dist.barrier(group=process_group)
+
+ def destroy(self, process_group: ProcessGroup = None):
+ """
+ Destroy the distributed process group.
+
+ Args:
+ process_group (ProcessGroup, optional): process group to destroy. Defaults to None, which refers to the default process group.
+ """
+ dist.destroy_process_group(process_group)
diff --git a/colossalai/cluster/process_group_manager.py b/colossalai/cluster/process_group_manager.py
new file mode 100644
index 000000000000..e52661846f3e
--- /dev/null
+++ b/colossalai/cluster/process_group_manager.py
@@ -0,0 +1,75 @@
+from typing import List
+
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+
+class ProcessGroupManager:
+ """
+ ProcessGroupManager is used to manage the process groups in the cluster.
+
+ There are some terms used in this class:
+ - pg: the short name for process group
+ - pg_name: the name of the process group
+ - pg_size: the world size of the process group
+ - rank: the rank of the current process in the process group
+ - world_size: the total number of processes in the process group
+ """
+
+ def __init__(self):
+ self.pg_store = dict()
+
+ def create_process_group(self, name: str, ranks: List[int], backend: str = 'nccl') -> ProcessGroup:
+ """
+ Get a process group by name. If the process group does not exist, it will be created.
+
+ Args:
+ name (str): name of the process group
+ ranks (List[int]): ranks of the process group
+ backend (str, optional): backend of the process group. Defaults to 'nccl'.
+
+ Returns:
+ ProcessGroup: the process group
+ """
+ if name not in self.pg_store:
+ pg = dist.new_group(ranks=ranks, backend=backend)
+ self.pg_store[name] = pg
+ return pg
+ else:
+ raise ValueError(f'Process group {name} already exists.')
+
+ def get(self, name: str) -> ProcessGroup:
+ """
+ Get a process group by name.
+
+ Args:
+ name (str): name of the process group
+
+ Returns:
+ ProcessGroup: the process group
+ """
+ if name in self.pg_store:
+ return self.pg_store[name]
+ else:
+ raise ValueError(f'Process group {name} does not exist.')
+
+ def destroy(self, name: str) -> None:
+ """
+ Destroy a process group by name.
+
+ Args:
+ name (str): name of the process group
+ """
+ if name in self.pg_store:
+ dist.destroy_process_group(self.pg_store[name])
+ del self.pg_store[name]
+ else:
+ raise ValueError(f'Process group {name} does not exist.')
+
+ def destroy_all(self) -> None:
+ """
+ Destroy all process groups.
+ """
+ for name in self.pg_store:
+ dist.destroy_process_group(self.pg_store[name])
+ self.pg_store.clear()
diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py
index 0879f5fd2659..1d7a883b1552 100644
--- a/colossalai/context/moe_context.py
+++ b/colossalai/context/moe_context.py
@@ -1,129 +1,129 @@
-import torch
-import torch.distributed as dist
-
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.singleton_meta import SingletonMeta
-from colossalai.tensor import ProcessGroup
-
-from typing import Tuple
-
-
-def _check_sanity():
- from colossalai.core import global_context as gpc
- if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
- raise NotImplementedError("Moe is not compatible with tensor or "
- "pipeline parallel at present.")
-
-
-class MoeParallelInfo:
- """Moe parallelism information, storing parallel sizes and groups.
- """
-
- def __init__(self, ep_size: int, dp_size: int):
- _check_sanity()
- self.ep_size = ep_size
- self.dp_size = dp_size
- self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
- self.ep_group = self.pg.tp_process_group()
- self.dp_group = self.pg.dp_process_group()
-
-
-class MoeContext(metaclass=SingletonMeta):
- """MoE parallel context manager. This class manages different
- parallel groups in MoE context and MoE loss in training.
- """
-
- def __init__(self):
- self.world_size = 1
- # Users may want to set maximum expert parallel size smaller than the world size
- # since very low bandwidth across nodes may constrain the performance of MoE
- # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
- self.max_ep_size = 1
- self.min_dp_size = 1
- self.aux_loss = None
- self.use_kernel_optim = True
-
- self.has_setup = False
- self._parallel_info_dict = dict()
-
- @property
- def parallel_info_dict(self):
- return self._parallel_info_dict
-
- @property
- def is_initialized(self):
- return self.has_setup
-
- def setup(self, seed: int, use_kernel_optim: bool = True):
- assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
- _check_sanity()
- assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
-
- self.world_size = dist.get_world_size()
-
- from colossalai.core import global_context as gpc
- self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
- assert self.world_size % self.max_ep_size == 0, \
- "Maximum epxert parallel size must be a factor of the number of GPUs"
- self.min_dp_size = self.world_size // self.max_ep_size
-
- # Enabling kernel optimization may raise error in some cases
- # Users can close kernel optimization manually
- self.use_kernel_optim = use_kernel_optim
-
- from .random import moe_set_seed
- moe_set_seed(seed)
- self.has_setup = True
-
- def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
- """Calculate the Data Parallel Group and Expert Parallel Group.
-
- Parameters
- ----------
- num_experts : int
- The number experts
-
- Returns
- -------
- int, MoeParallelInfo
- number of local experts, the MoeParallelInfo of the current ep_size
- """
-
- gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
- lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
-
- assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
- " is not a multiple of ep size or vice versa."
-
- # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
- # there are multiple experts in each GPU and each GPU has different experts
- # So it's data parallel size is 1
- # Otherwise, there is only one expert in each GPU
- # The data parallel size should be calculated
- dp_size = 1 if gt_flag else self.max_ep_size // num_experts
- ep_size = self.max_ep_size // dp_size
-
- # Calculate the number of experts for each GPU
- num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
-
- # Don't forget to multiply minimum data parallel size
- dp_size *= self.min_dp_size
- if not (ep_size in self.parallel_info_dict):
- self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
-
- return num_local_experts, self.parallel_info_dict[ep_size]
-
- def set_kernel_not_use(self):
- self.use_kernel_optim = False
-
- def reset_loss(self):
- self.aux_loss = 0
-
- def add_loss(self, loss):
- self.aux_loss += loss
-
- def get_loss(self):
- return self.aux_loss
-
-
-MOE_CONTEXT = MoeContext()
+from typing import Tuple
+
+import torch
+import torch.distributed as dist
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor import ProcessGroup
+
+
+def _check_sanity():
+ from colossalai.core import global_context as gpc
+ if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
+ raise NotImplementedError("Moe is not compatible with tensor or "
+ "pipeline parallel at present.")
+
+
+class MoeParallelInfo:
+ """Moe parallelism information, storing parallel sizes and groups.
+ """
+
+ def __init__(self, ep_size: int, dp_size: int):
+ _check_sanity()
+ self.ep_size = ep_size
+ self.dp_size = dp_size
+ self.pg = ProcessGroup(tp_degree=ep_size, dp_degree=dp_size)
+ self.ep_group = self.pg.tp_process_group()
+ self.dp_group = self.pg.dp_process_group()
+
+
+class MoeContext(metaclass=SingletonMeta):
+ """MoE parallel context manager. This class manages different
+ parallel groups in MoE context and MoE loss in training.
+ """
+
+ def __init__(self):
+ self.world_size = 1
+ # Users may want to set maximum expert parallel size smaller than the world size
+ # since very low bandwidth across nodes may constrain the performance of MoE
+ # When we have a maximum expert parallel size, we have a minimum data parallel size naturally
+ self.max_ep_size = 1
+ self.min_dp_size = 1
+ self.aux_loss = None
+ self.use_kernel_optim = True
+
+ self.has_setup = False
+ self._parallel_info_dict = dict()
+
+ @property
+ def parallel_info_dict(self):
+ return self._parallel_info_dict
+
+ @property
+ def is_initialized(self):
+ return self.has_setup
+
+ def setup(self, seed: int, use_kernel_optim: bool = True):
+ assert not self.is_initialized, "MoE distributed context shouldn't be set up again"
+ _check_sanity()
+ assert torch.cuda.is_available(), "MoE requires to enable CUDA first"
+
+ self.world_size = dist.get_world_size()
+
+ from colossalai.core import global_context as gpc
+ self.max_ep_size = gpc.config.get('max_ep_size', self.world_size)
+ assert self.world_size % self.max_ep_size == 0, \
+ "Maximum epxert parallel size must be a factor of the number of GPUs"
+ self.min_dp_size = self.world_size // self.max_ep_size
+
+ # Enabling kernel optimization may raise error in some cases
+ # Users can close kernel optimization manually
+ self.use_kernel_optim = use_kernel_optim
+
+ from .random import moe_set_seed
+ moe_set_seed(seed)
+ self.has_setup = True
+
+ def get_info(self, num_experts: int) -> Tuple[int, MoeParallelInfo]:
+ """Calculate the Data Parallel Group and Expert Parallel Group.
+
+ Parameters
+ ----------
+ num_experts : int
+ The number experts
+
+ Returns
+ -------
+ int, MoeParallelInfo
+ number of local experts, the MoeParallelInfo of the current ep_size
+ """
+
+ gt_flag = num_experts % self.max_ep_size == 0 # check whether num_experts is greater
+ lt_flag = self.max_ep_size % num_experts == 0 # check whether num_experts is less
+
+ assert gt_flag or lt_flag, "Automatic experts placement dose not not support expert number" \
+ " is not a multiple of ep size or vice versa."
+
+ # If the number of experts is greater than maximum expert parallel size. a.k.a ep_size,
+ # there are multiple experts in each GPU and each GPU has different experts
+ # So it's data parallel size is 1
+ # Otherwise, there is only one expert in each GPU
+ # The data parallel size should be calculated
+ dp_size = 1 if gt_flag else self.max_ep_size // num_experts
+ ep_size = self.max_ep_size // dp_size
+
+ # Calculate the number of experts for each GPU
+ num_local_experts = 1 if lt_flag else num_experts // self.max_ep_size
+
+ # Don't forget to multiply minimum data parallel size
+ dp_size *= self.min_dp_size
+ if not (ep_size in self.parallel_info_dict):
+ self.parallel_info_dict[ep_size] = MoeParallelInfo(ep_size, dp_size)
+
+ return num_local_experts, self.parallel_info_dict[ep_size]
+
+ def set_kernel_not_use(self):
+ self.use_kernel_optim = False
+
+ def reset_loss(self):
+ self.aux_loss = 0
+
+ def add_loss(self, loss):
+ self.aux_loss += loss
+
+ def get_loss(self):
+ return self.aux_loss
+
+
+MOE_CONTEXT = MoeContext()
diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py
index fe0ba553d6f3..7fbe3be5901f 100644
--- a/colossalai/context/process_group_initializer/initializer_2d.py
+++ b/colossalai/context/process_group_initializer/initializer_2d.py
@@ -2,10 +2,11 @@
import torch.distributed as dist
+from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+
from ..parallel_mode import ParallelMode
-from colossalai.global_variables import tensor_parallel_env as env
+from .process_group_initializer import ProcessGroupInitializer
def _check_summa_env_var(summa_dim):
diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py
index edd1a3706c68..0ddb52f63e22 100644
--- a/colossalai/context/process_group_initializer/initializer_pipeline.py
+++ b/colossalai/context/process_group_initializer/initializer_pipeline.py
@@ -4,8 +4,9 @@
from torch import distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
+
from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py
index 682fe4bb7633..eaacb14d2282 100644
--- a/colossalai/context/process_group_initializer/initializer_sequence.py
+++ b/colossalai/context/process_group_initializer/initializer_sequence.py
@@ -3,9 +3,10 @@
import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
+
+from ..parallel_mode import ParallelMode
from .initializer_tensor import Initializer_Tensor
from .process_group_initializer import ProcessGroupInitializer
-from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
diff --git a/colossalai/device/alpha_beta_profiler.py b/colossalai/device/alpha_beta_profiler.py
index 324acacb8b4a..af2b10928c6f 100644
--- a/colossalai/device/alpha_beta_profiler.py
+++ b/colossalai/device/alpha_beta_profiler.py
@@ -21,7 +21,7 @@ class AlphaBetaProfiler:
# multi-process with multi-gpu in mpi style.
>>> physical_devices = [0, 1, 4, 5]
>>> ab_profiler = AlphaBetaProfiler(physical_devices)
- >>> ab_dict = profiler.profile_ab()
+ >>> ab_dict = profiler.alpha_beta_dict
>>> print(ab_dict)
{(0, 1): (1.9641406834125518e-05, 4.74049549614719e-12), (0, 4): (1.9506998360157013e-05, 6.97421973297474e-11), (0, 5): (2.293858677148819e-05, 7.129930361393644e-11),
(1, 4): (1.9010603427886962e-05, 7.077968863788975e-11), (1, 5): (1.9807778298854827e-05, 6.928845708992215e-11), (4, 5): (1.8681809306144713e-05, 4.7522367291330524e-12),
@@ -31,13 +31,16 @@ class AlphaBetaProfiler:
def __init__(self,
physical_devices: List[int],
+ alpha_beta_dict: Dict[Tuple[int, int], Tuple[float, float]] = None,
ctype: str = 'a',
warmup: int = 5,
repeat: int = 25,
- latency_iters: int = 5):
+ latency_iters: int = 5,
+ homogeneous_tolerance: float = 0.1):
'''
Args:
physical_devices: A list of device id, each element inside it is the global rank of that device.
+ alpha_beta_dict: A dict which maps a process group to alpha-beta value pairs.
ctype: 'a' for all-reduce, 'b' for broadcast.
warmup: Number of warmup iterations.
repeat: Number of iterations to measure.
@@ -49,8 +52,13 @@ def __init__(self,
self.warmup = warmup
self.repeat = repeat
self.latency_iters = latency_iters
+ self.homogeneous_tolerance = homogeneous_tolerance
self.process_group_dict = None
self._init_profiling()
+ if alpha_beta_dict is None:
+ self.alpha_beta_dict = self.profile_ab()
+ else:
+ self.alpha_beta_dict = alpha_beta_dict
def _init_profiling(self):
# Create process group list based on its global rank
@@ -139,7 +147,7 @@ def profile_latency(self, process_group, pg_handler):
return latency
- def profile_bandwidth(self, process_group, pg_handler, maxbytes):
+ def profile_bandwidth(self, process_group, pg_handler, maxbytes=(1 * GB)):
'''
This function is used to profile the bandwidth of the given process group.
@@ -159,6 +167,7 @@ def profile_ab(self):
'''
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = {}
rank = dist.get_rank()
+ global_pg_handler = dist.new_group(self.physical_devices)
def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
assert rank in process_group
@@ -197,3 +206,183 @@ def get_max_nbytes(process_group: Tuple[int], pg_handler: dist.ProcessGroup):
alpha_beta_dict.update(symmetry_ab_dict)
return alpha_beta_dict
+
+ def search_best_logical_mesh(self):
+ '''
+ This method is used to search the best logical mesh for the given device list.
+
+ The best logical mesh is searched in following steps:
+ 1. detect homogeneous device groups, we assume that the devices in the alpha_beta_dict
+ are homogeneous if the beta value is close enough.
+ 2. Find the best homogeneous device group contains all the physical devices. The best homogeneous
+ device group means the lowest beta value in the groups which contains all the physical devices.
+ And the reason we require the group contains all the physical devices is that the devices not in
+ the group will decrease the bandwidth of the group.
+ 3. If the best homogeneous device group is found, we will construct the largest ring for each device
+ based on the best homogeneous device group, and the best logical mesh will be the union of all the
+ rings. Otherwise, the best logical mesh will be the balanced logical mesh, such as shape (2, 2) for
+ 4 devices.
+
+ Returns:
+ best_logical_mesh: The best logical mesh for the given device list.
+
+ Usage:
+ >>> physical_devices = [0, 1, 2, 3]
+ >>> ab_profiler = AlphaBetaProfiler(physical_devices)
+ >>> best_logical_mesh = profiler.search_best_logical_mesh()
+ >>> print(best_logical_mesh)
+ [[0, 1], [2, 3]]
+ '''
+
+ def _power_of_two(integer):
+ return integer & (integer - 1) == 0
+
+ def _detect_homogeneous_device(alpha_beta_dict):
+ '''
+ This function is used to detect whether the devices in the alpha_beta_dict are homogeneous.
+
+ Note: we assume that the devices in the alpha_beta_dict are homogeneous if the beta value
+ of the devices are in range of [(1 - self.homogeneous_tolerance), (1 + self.homogeneous_tolerance)]
+ * base_beta.
+ '''
+ homogeneous_device_dict: Dict[float, List[Tuple[int]]] = {}
+ for process_group, (_, beta) in alpha_beta_dict.items():
+ if homogeneous_device_dict is None:
+ homogeneous_device_dict[beta] = []
+ homogeneous_device_dict[beta].append(process_group)
+
+ match_beta = None
+ for beta_value in homogeneous_device_dict.keys():
+ if beta <= beta_value * (1 + self.homogeneous_tolerance) and beta >= beta_value * (
+ 1 - self.homogeneous_tolerance):
+ match_beta = beta_value
+ break
+
+ if match_beta is not None:
+ homogeneous_device_dict[match_beta].append(process_group)
+ else:
+ homogeneous_device_dict[beta] = []
+ homogeneous_device_dict[beta].append(process_group)
+
+ return homogeneous_device_dict
+
+ def _check_contain_all_devices(homogeneous_group: List[Tuple[int]]):
+ '''
+ This function is used to check whether the homogeneous_group contains all physical devices.
+ '''
+ flatten_mesh = []
+ for process_group in homogeneous_group:
+ flatten_mesh.extend(process_group)
+ non_duplicated_flatten_mesh = set(flatten_mesh)
+ return len(non_duplicated_flatten_mesh) == len(self.physical_devices)
+
+ def _construct_largest_ring(homogeneous_group: List[Tuple[int]]):
+ '''
+ This function is used to construct the largest ring in the homogeneous_group for each rank.
+ '''
+ # Construct the ring
+ ring = []
+ ranks_in_ring = []
+ for rank in self.physical_devices:
+ if rank in ranks_in_ring:
+ continue
+ stable_status = False
+ ring_for_rank = []
+ ring_for_rank.append(rank)
+ check_rank_list = [rank]
+ rank_to_check_list = []
+
+ while not stable_status:
+ stable_status = True
+ check_rank_list.extend(rank_to_check_list)
+ rank_to_check_list = []
+ for i in range(len(check_rank_list)):
+ check_rank = check_rank_list.pop()
+ for process_group in homogeneous_group:
+ if check_rank in process_group:
+ rank_to_append = process_group[0] if process_group[1] == check_rank else process_group[1]
+ if rank_to_append not in ring_for_rank:
+ stable_status = False
+ rank_to_check_list.append(rank_to_append)
+ ring_for_rank.append(rank_to_append)
+
+ ring.append(ring_for_rank)
+ ranks_in_ring.extend(ring_for_rank)
+
+ return ring
+
+ assert _power_of_two(self.world_size)
+ power_of_two = int(math.log2(self.world_size))
+ median = power_of_two // 2
+ balanced_logical_mesh_shape = (2**median, 2**(power_of_two - median))
+ row_size, column_size = balanced_logical_mesh_shape[0], balanced_logical_mesh_shape[1]
+ balanced_logical_mesh = []
+ for row_index in range(row_size):
+ balanced_logical_mesh.append([])
+ for column_index in range(column_size):
+ balanced_logical_mesh[row_index].append(self.physical_devices[row_index * column_size + column_index])
+
+ homogeneous_device_dict = _detect_homogeneous_device(self.alpha_beta_dict)
+ beta_list = [b for b in homogeneous_device_dict.keys()]
+ beta_list.sort()
+ beta_list.reverse()
+ homogeneous_types = len(beta_list)
+ best_logical_mesh = None
+ if homogeneous_types >= 2:
+ for _ in range(homogeneous_types - 1):
+ lowest_beta = beta_list.pop()
+ best_homogeneous_group = homogeneous_device_dict[lowest_beta]
+ # if the best homogeneous group contains all physical devices,
+ # we will build the logical device mesh based on it. Otherwise,
+ # we will check next level homogeneous group.
+ if _check_contain_all_devices(best_homogeneous_group):
+ # We choose the largest ring for each rank to maximum the best bus utilization.
+ best_logical_mesh = _construct_largest_ring(best_homogeneous_group)
+ break
+
+ if homogeneous_types == 1 or best_logical_mesh is None:
+ # in this case, we use balanced logical mesh as the best
+ # logical mesh.
+ best_logical_mesh = balanced_logical_mesh
+
+ return best_logical_mesh
+
+ def extract_alpha_beta_for_device_mesh(self):
+ '''
+ Extract the mesh_alpha list and mesh_beta list based on the
+ best logical mesh, which will be used to initialize the device mesh.
+
+ Usage:
+ >>> physical_devices = [0, 1, 2, 3]
+ >>> ab_profiler = AlphaBetaProfiler(physical_devices)
+ >>> mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
+ >>> print(mesh_alpha)
+ [2.5917552411556242e-05, 0.00010312341153621673]
+ >>> print(mesh_beta)
+ [5.875573704655635e-11, 4.7361584445959614e-12]
+ '''
+ best_logical_mesh = self.search_best_logical_mesh()
+
+ first_axis = [row[0] for row in best_logical_mesh]
+ second_axis = best_logical_mesh[0]
+
+ # init process group for both axes
+ first_axis_process_group = dist.new_group(first_axis)
+ second_axis_process_group = dist.new_group(second_axis)
+
+ # extract alpha and beta for both axes
+ def _extract_alpha_beta(pg, pg_handler):
+ latency = self.profile_latency(pg, pg_handler)
+ bandwidth = self.profile_bandwidth(pg, pg_handler)
+ broadcast_object = [latency, bandwidth]
+ dist.broadcast_object_list(broadcast_object, src=pg[0])
+ return broadcast_object
+
+ first_latency, first_bandwidth = _extract_alpha_beta(first_axis, first_axis_process_group)
+ second_latency, second_bandwidth = _extract_alpha_beta(second_axis, second_axis_process_group)
+ mesh_alpha = [first_latency, second_latency]
+ # The beta values have been enlarged by 1e10 times temporarilly because the computation cost
+ # is still estimated in the unit of TFLOPs instead of time. We will remove this factor in future.
+ mesh_beta = [1e10 / first_bandwidth, 1e10 / second_bandwidth]
+
+ return mesh_alpha, mesh_beta
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index 7596a100bf93..2a5f747fbc23 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -1,21 +1,24 @@
+"""This code is adapted from Alpa
+ https://github.com/alpa-projects/alpa/
+ with some changes. """
+
import operator
from functools import reduce
+from typing import List, Tuple
import torch
import torch.distributed as dist
+# modified from alpa LogicalDeviceMesh(https://github.com/alpa-projects/alpa/blob/main/alpa/shard_parallel/auto_sharding.py)
class DeviceMesh:
- """A logical view of a physical mesh. The logical view is used in the
- search process.
- A physical mesh can have multiple logical views. (e.g., a 2x8 physical mesh
- can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
- own latency and bandwidth. We use alpha-beta model to model the
- communication cost.
+ """A logical view of a physical cluster. For example, we could view a physical cluster
+ with 16 devices as a device mesh with shape (2, 2, 4) or (4, 4).
Arguments:
physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
- mesh_shape (torch.Size): shape of logical view.
+ logical_mesh_id (torch.Tensor): logical view of the devices in global rank.
+ mesh_shape (torch.Size, optional): shape of logical view.
mesh_alpha (List[float], optional): coefficients used for computing
communication cost (default: None)
mesh_beta (List[float], optional): coefficients used for computing
@@ -28,15 +31,21 @@ class DeviceMesh:
"""
def __init__(self,
- physical_mesh_id,
- mesh_shape,
- mesh_alpha=None,
- mesh_beta=None,
- init_process_group=False,
- need_flatten=True):
+ physical_mesh_id: torch.Tensor,
+ mesh_shape: torch.Size = None,
+ logical_mesh_id: torch.Tensor = None,
+ mesh_alpha: List[float] = None,
+ mesh_beta: List[float] = None,
+ init_process_group: bool = False,
+ need_flatten: bool = True):
self.physical_mesh_id = physical_mesh_id
- self.mesh_shape = mesh_shape
- self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+ if logical_mesh_id is None:
+ self.mesh_shape = mesh_shape
+ self._logical_mesh_id = self.physical_mesh_id.reshape(self.mesh_shape)
+ else:
+ self._logical_mesh_id = logical_mesh_id
+ self.mesh_shape = self._logical_mesh_id.shape
+
# map global rank into logical rank
self.convert_map = {}
self._global_rank_to_logical_rank_map(self._logical_mesh_id, [])
@@ -54,8 +63,8 @@ def __init__(self,
if self.need_flatten and self._logical_mesh_id.dim() > 1:
self.flatten_device_mesh = self.flatten()
# Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten())
- self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
- self.mesh_beta)
+ # self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha,
+ # self.mesh_beta)
@property
def shape(self):
@@ -90,7 +99,7 @@ def flatten(self):
return DeviceMesh(self.physical_mesh_id,
tuple(flatten_mesh_shape),
mesh_alpha=[max(self.mesh_alpha)] * (flatten_mesh_shape_size - 1),
- mesh_beta=[min(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
+ mesh_beta=[max(self.mesh_beta)] * (flatten_mesh_shape_size - 1),
init_process_group=self.init_process_group,
need_flatten=False)
diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 146a29669227..59d8e1058652 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -1,16 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
-from typing import List, Iterable
+from typing import Iterable, List, Optional, Type
+
+from torch import Tensor
from torch.nn import Module
from torch.nn.modules.loss import _Loss
-from colossalai.logging import get_dist_logger
-from torch import Tensor
-from colossalai.gemini.ophooks import register_ophooks_recursively, BaseOpHook
-from colossalai.engine.schedule import BaseSchedule, NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule
-from typing import Optional, Type
from colossalai.engine.gradient_handler import BaseGradientHandler
+from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
+from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
from colossalai.logging import get_dist_logger
@@ -93,7 +93,7 @@ def __init__(self,
if self.uses_pipeline:
self._schedule.pre_processing(self)
- #register hook if any
+ # register hook if any
if len(self._ophook_list) > 0:
register_ophooks_recursively(self._model, self._ophook_list)
diff --git a/colossalai/engine/gradient_handler/utils.py b/colossalai/engine/gradient_handler/utils.py
index e92044b47279..fca5f2ec9da9 100644
--- a/colossalai/engine/gradient_handler/utils.py
+++ b/colossalai/engine/gradient_handler/utils.py
@@ -1,29 +1,30 @@
-import torch.distributed as dist
-import torch.nn as nn
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from typing import Iterable
-
-
-def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
- # get communication world size
- comm_size = dist.get_world_size(group)
- # bucketize and all-reduce
- buckets = {}
- # Pack the buckets.
- for param in param_list:
- if param.requires_grad and param.grad is not None:
- tp = param.data.type()
- if tp not in buckets:
- buckets[tp] = []
- buckets[tp].append(param)
-
- # For each bucket, all-reduce and copy all-reduced grads.
- for tp in buckets:
- bucket = buckets[tp]
- grads = [param.grad.data for param in bucket]
- coalesced = _flatten_dense_tensors(grads)
- coalesced /= comm_size
-
- dist.all_reduce(coalesced, group=group)
- for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
- buf.copy_(synced)
+from typing import Iterable
+
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def bucket_allreduce(param_list: Iterable[nn.Parameter], group=None):
+ # get communication world size
+ comm_size = dist.get_world_size(group)
+ # bucketize and all-reduce
+ buckets = {}
+ # Pack the buckets.
+ for param in param_list:
+ if param.requires_grad and param.grad is not None:
+ tp = param.data.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(param)
+
+ # For each bucket, all-reduce and copy all-reduced grads.
+ for tp in buckets:
+ bucket = buckets[tp]
+ grads = [param.grad.data for param in bucket]
+ coalesced = _flatten_dense_tensors(grads)
+ coalesced /= comm_size
+
+ dist.all_reduce(coalesced, group=group)
+ for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+ buf.copy_(synced)
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 97571fa024ba..712ae8242409 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -4,8 +4,9 @@
import inspect
from typing import Callable, List, Tuple, Union
-import colossalai.communication as comm
import torch.cuda
+
+import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule):
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
-
+
Example:
-
+
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output
@@ -157,6 +158,7 @@ def load_micro_batch(self):
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):
@@ -229,7 +231,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data, crite
return data, label
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
- """Forward step for passed-in model. If it is the first stage, the input tensor
+ """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
@@ -266,7 +268,7 @@ def _forward_step(self, engine, input_obj, return_tensors, return_output_label=T
return output_obj
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
- """Backward step through the passed-in output tensor. If it is the last stage, the
+ """Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
@@ -511,7 +513,7 @@ def _forward_step(self,
return_tensors,
return_output_label=True,
accum_loss=None):
- """Forward step for passed-in model. If it is the first stage, the input tensor
+ """Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py
index 8c0201c71e08..153214447223 100644
--- a/colossalai/fx/_meta_registrations.py
+++ b/colossalai/fx/_meta_registrations.py
@@ -164,18 +164,9 @@ def pick_memory_format():
@register_meta(aten._convolution.default)
-def meta_conv_1(
- input_tensor: torch.Tensor,
- weight: torch.Tensor,
- bias: torch.Tensor,
- stride: List[int],
- padding: List[int],
- dilation: List[int],
- is_transposed: bool,
- output_padding: List[int],
- groups: int,
- *extra_args
-):
+def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+ padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+ *extra_args):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@@ -233,11 +224,8 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
- out_shape = (
- [mini_batch, seq_length, out_size * num_directions]
- if batch_first
- else [seq_length, mini_batch, out_size * num_directions]
- )
+ out_shape = ([mini_batch, seq_length, out_size *
+ num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
return dX, dgamma, dbeta
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
+@register_meta(aten.native_group_norm_backward.default)
+def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
+ dX = torch.empty_like(input)
+ dgamma = torch.empty_like(gamma)
+ dbeta = torch.empty_like(gamma)
+ return dX, dgamma, dbeta
+
+
# ================================== Misc ==========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)
diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index fbafd326c6d4..ebb9975f27db 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -1,24 +1,34 @@
import os
import warnings
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Set, Type, Union
+
import torch
import torch.nn as nn
from torch.nn.modules.module import _addindent
-from typing import Type, Dict, List, Any, Union, Optional, Set
-from pathlib import Path
+
try:
- from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _WrappedCall, _exec_with_source, _forward_from_src
- from torch.fx.graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
+ from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
+ from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
+
+ from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
COLOGM = True
except:
- from torch.fx.graph_module import GraphModule
from torch.fx.graph import Graph
+ from torch.fx.graph_module import GraphModule
COLOGM = False
if COLOGM:
class ColoGraphModule(GraphModule):
- def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
+ def __init__(self,
+ root: Union[torch.nn.Module, Dict[str, Any]],
+ graph: Graph,
+ class_name: str = 'GraphModule',
+ ckpt_codegen: bool = True):
+ if ckpt_codegen:
+ graph.set_codegen(ActivationCheckpointCodeGen())
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
diff --git a/colossalai/fx/passes/adding_split_node_pass.py b/colossalai/fx/passes/adding_split_node_pass.py
index 373d20c51041..2c7b842b530c 100644
--- a/colossalai/fx/passes/adding_split_node_pass.py
+++ b/colossalai/fx/passes/adding_split_node_pass.py
@@ -1,4 +1,6 @@
+import numpy as np
import torch
+import tqdm
from torch.fx import symbolic_trace
from torch.fx.node import Node
@@ -9,6 +11,199 @@ def pipe_split():
pass
+def block_split():
+ pass
+
+
+# Construct blocks with the condition that (block_flops / total_flops) >= limit.
+def construct_blocks(gm: torch.fx.GraphModule, limit=0.01):
+ total_fwd_flop = 0
+ total_bwd_flop = 0
+ for node in gm.graph.nodes:
+ total_fwd_flop += node.fwd_flop
+ total_bwd_flop += node.bwd_flop
+
+ total_flop = total_fwd_flop + total_bwd_flop
+ per_block_flop = total_flop * limit
+ accumulate_fwd_flop = 0
+ accumulate_bwd_flop = 0
+ block_nodes = []
+ for node in gm.graph.nodes:
+ if 'block_split' in node.name:
+ continue
+ accumulate_fwd_flop += node.fwd_flop
+ accumulate_bwd_flop += node.bwd_flop
+ if accumulate_fwd_flop + accumulate_bwd_flop >= per_block_flop:
+ with gm.graph.inserting_after(node):
+ block_node = gm.graph.create_node('call_function', block_split)
+ setattr(block_node, 'fwd_flop', accumulate_fwd_flop)
+ setattr(block_node, 'bwd_flop', accumulate_bwd_flop)
+ accumulate_fwd_flop = 0
+ accumulate_bwd_flop = 0
+ block_nodes.append(block_node)
+
+ return block_nodes
+
+
+def remove_blocks(gm: torch.fx.GraphModule):
+ for node in gm.graph.nodes:
+ if (node.op, node.target) == ('call_function', block_split):
+ gm.graph.erase_node(node)
+
+
+def get_compute_costs(node_list):
+ num_nodes = len(node_list)
+ all_compute_cost = np.full((num_nodes, num_nodes), np.inf, dtype=np.float64)
+
+ for start in tqdm.tqdm(range(num_nodes), desc='start pos', position=0):
+ for end in tqdm.tqdm(range(start, num_nodes), desc='end pos', position=1, leave=False):
+ selected_flops = [(node_list[i].fwd_flop + node_list[i].bwd_flop) for i in range(start, end + 1)]
+ all_compute_cost[start, end] = sum(selected_flops)
+
+ return all_compute_cost
+
+
+def do_dp_split_gpipe_impl(num_nodes, num_stages, num_microbatches, compute_costs, max_compute_cost):
+ """The core implementation of the DP algorithm."""
+ # Adapted from Alpa DP Formulation.
+ # For f, node ID start from 0
+ # f[number of stages,
+ # node id that is currently being considered]
+
+ # record time cost(assess by fwd+bwd flop now)
+ f = np.full((num_stages + 1, num_nodes + 1), np.inf, dtype=np.float32)
+
+ # record max stage compute cost among all stages in this partition.
+ f_stage_max = np.full((num_stages + 1, num_nodes + 1), 0.0, dtype=np.float32)
+ # record start node index for next stage in this partition
+ f_argmin = np.full((num_stages + 1, num_nodes + 1), -1, dtype=np.int32)
+ f[0, num_nodes] = 0
+ for s in tqdm.tqdm(range(1, num_stages + 1), desc='stage', position=2, leave=False): # pylint: disable=too-many-nested-blocks
+ for i in tqdm.tqdm(range(num_nodes - 1, -1, -1), desc='start node', position=3, leave=False):
+ for k in tqdm.tqdm(range(num_nodes, i, -1), desc='mid node', position=4, leave=False):
+ stage_cost = compute_costs[i, k - 1]
+ new_cost = f[s - 1, k] + stage_cost
+ if (stage_cost <= max_compute_cost and new_cost < f[s, i]):
+ f[s, i] = new_cost
+ f_stage_max[s, i] = max(f_stage_max[s - 1, k], stage_cost)
+ f_argmin[s, i] = k
+
+ best_total_cost = f[num_stages, 0]
+ if np.isinf(best_total_cost):
+ return np.inf, None
+
+ total_cost = f[num_stages, 0] + (num_microbatches - 1) * f_stage_max[num_stages, 0]
+
+ current_s = num_stages
+ current_node = 0
+
+ res = []
+ while current_s > 0 and current_node < num_nodes:
+ next_start_node = f_argmin[current_s, current_node]
+ res.append((current_node, next_start_node))
+ current_s -= 1
+ current_node = next_start_node
+
+ return total_cost, res
+
+
+def do_dp_split_gpipe(node_list, compute_costs, num_stages: int, num_microbatches: int):
+ # Ignore the memory cost profiling in Alpa's design for convenience.
+ max_compute_costs = np.sort(np.unique(compute_costs))
+ best_cost = np.inf
+ best_solution = None
+ last_max_compute_cost = 0.0
+ gap = 1e6 # temporary magic number, unit: flops
+
+ for max_compute_cost in tqdm.tqdm(max_compute_costs):
+ # Pruning to reduce search space.
+ if max_compute_cost * num_microbatches >= best_cost:
+ break
+ if max_compute_cost - last_max_compute_cost < gap:
+ continue
+
+ cost, solution = do_dp_split_gpipe_impl(len(node_list), num_stages, num_microbatches, compute_costs,
+ max_compute_cost)
+
+ if cost < best_cost:
+ best_cost = cost
+ best_solution = solution
+ last_max_compute_cost = max_compute_cost
+ return best_cost, best_solution
+
+
+# Auto DP partition based on Alpa.
+# Adapted to Gpipe Scheduler
+# split_mode:
+# 'node': fx_node
+# 'block': many fx_nodes construct a block
+def gpipe_dp_split_pass(gm: torch.fx.GraphModule, pp_size: int, num_microbatches: int, mode='block', block_limit=0.01):
+ assert mode in ['node', 'block']
+
+ # nodes or blocks will be used in partition.
+ node_list = []
+ if mode == 'node':
+ for node in gm.graph.nodes:
+ node_list.append(node)
+ elif mode == 'block':
+ node_list = construct_blocks(gm, limit=block_limit)
+ else:
+ pass
+
+ compute_costs = get_compute_costs(node_list)
+
+ best_cost, best_solution = do_dp_split_gpipe(node_list, compute_costs, pp_size, num_microbatches)
+
+ for (_, next_start_node) in best_solution:
+ if pp_size <= 1:
+ break
+ node = node_list[next_start_node]
+ with gm.graph.inserting_before(node):
+ split_node = gm.graph.create_node('call_function', pipe_split)
+ pp_size -= 1
+
+ # remove block node if possible
+ if mode == 'block':
+ remove_blocks(gm)
+
+ gm.recompile()
+ return gm
+
+
+def avgcompute_split_pass(gm: torch.fx.GraphModule, pp_size: int):
+ """
+ In avgcompute_split_pass, we split module by the fwd flops.
+ """
+ mod_graph = gm.graph
+ # To use avgcompute_split_pass, we need run meta_info_prop interpreter first.
+ # If nodes don't have meta info, this pass will fall back to normal balanced split pass.
+ check_node = list(mod_graph.nodes)[0]
+ if 'tensor_meta' not in check_node.meta:
+ return balanced_split_pass(gm, pp_size)
+
+ total_fwd_flop = 0
+ for node in mod_graph.nodes:
+ total_fwd_flop += node.fwd_flop
+
+ partition_flop = total_fwd_flop // pp_size
+ accumulate_fwd_flop = 0
+ for node in mod_graph.nodes:
+ if pp_size <= 1:
+ break
+ if 'pipe_split' in node.name:
+ continue
+ accumulate_fwd_flop += node.fwd_flop
+ if accumulate_fwd_flop >= partition_flop:
+ total_fwd_flop = total_fwd_flop - accumulate_fwd_flop
+ accumulate_fwd_flop = 0
+ pp_size -= 1
+ partition_flop = total_fwd_flop // pp_size
+ with mod_graph.inserting_after(node):
+ split_node = mod_graph.create_node('call_function', pipe_split)
+ gm.recompile()
+ return gm
+
+
def avgnode_split_pass(gm: torch.fx.GraphModule, pp_size: int):
"""
In avgnode_split_pass, simpliy split graph by node number.
@@ -104,8 +299,10 @@ def balanced_split_pass_v2(gm: torch.fx.GraphModule, pp_size: int):
continue
accumulate_node_size += node.node_size
if accumulate_node_size >= partition_size:
+ total_element_size = total_element_size - accumulate_node_size
accumulate_node_size = 0
pp_size -= 1
+ partition_size = total_element_size // pp_size
with mod_graph.inserting_after(node):
split_node = mod_graph.create_node('call_function', pipe_split)
gm.recompile()
diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py
deleted file mode 100644
index 9ccf135d0911..000000000000
--- a/colossalai/fx/passes/algorithms/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .ckpt_solver_chen import chen_greedy
-from .linearize import linearize
-from .ckpt_solver_rotor import solver_rotor
-from .ckpt_solver_pofo import solver_pofo
diff --git a/colossalai/fx/passes/algorithms/build_c_ext.py b/colossalai/fx/passes/algorithms/build_c_ext.py
deleted file mode 100644
index cb360cb20340..000000000000
--- a/colossalai/fx/passes/algorithms/build_c_ext.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from setuptools import setup, Extension
-import os
-
-this_dir = os.path.dirname(os.path.abspath(__file__))
-ext_modules = [Extension(
- 'dynamic_programs_C_version',
- sources=[os.path.join(this_dir, 'dynamic_programs.c')],
-)]
-
-setup(
- name='rotor c extension',
- version='0.1',
- description='rotor c extension for faster dp computing',
- ext_modules=ext_modules,
-)
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
deleted file mode 100644
index 52000ebe5364..000000000000
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import math
-from typing import List, Set, Tuple
-
-import torch
-from torch.fx import GraphModule, Node
-
-from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
-
-__all__ = ['chen_greedy']
-CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
-
-
-def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
- """
- In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
- """
-
- def is_sink():
- """
- If we can free all memories when executing a certain node, it is a sink.
- """
- return not sum((v for k, v in deps.items()))
-
- deps = {}
- ckpt_nodes = []
- for n in gm.graph.nodes:
- for n_par in n._input_nodes:
- deps[n_par] -= 1 # free memory and dependencies
-
- # We can only put act_ckpt on these nodes
- if n.op in CKPT_OP and is_sink():
- ckpt_nodes.append(n)
- deps[n] = len(n.users) # add dependencies for future executions
- return ckpt_nodes
-
-
-def chen_greedy(gm: GraphModule) -> GraphModule:
- """
- This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
- Note that this algorithm targets at memory optimization only, using techniques in appendix A.
-
- Usage:
- model = resnet18()
- input_sample = torch.rand(4, 3, 224, 224)
- gm = symbolic_trace(model)
- MetaInfoProp(gm).run(input_sample)
- gm = chen_greedy(gm)
-
- Args:
- gm (GraphModule): The module to add checkpoints
- """
-
- def grid_search(num_grids: int = 6) -> Set:
- """
- Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
- Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
- """
- _, b_approx = run_chen_greedy(0)
- b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
- b_opt = math.inf
- for b in range(b_min, b_max, (b_max - b_min) // num_grids):
- ckpt_intv, b_approx = run_chen_greedy(b)
- if b_approx < b_opt:
- b_opt = b_approx
- ckpt_opt = ckpt_intv
- return ckpt_opt
-
- def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
- """
- This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
- """
- ckpt_nodes = _all_potential_ckpt_nodes(gm)
- ckpt_intv = []
- temp = 0
- x = 0
- y = 0
- prev_idx = 2
- for (idx, n) in enumerate(gm.graph.nodes):
- n: Node
- temp += calculate_fwd_in(n) + calculate_fwd_tmp(n)
- y = max(y, temp)
- if temp > b and n in ckpt_nodes:
- x += calculate_fwd_in(n)
- temp = 0
- ckpt_intv.append((prev_idx, idx + 1))
- prev_idx = idx + 1
- return ckpt_intv, math.floor(math.sqrt(x * y))
-
- gm.graph.lint() # make sure nodes are in topological order
- ckpt = grid_search(num_grids=6)
- node_list = list(gm.graph.nodes)
- for i, seg in enumerate(ckpt):
- for idx in range(*seg):
- n = node_list[idx]
- if n.op in CKPT_OP:
- setattr(n, 'activation_checkpoint', i)
- gm.recompile()
- return gm
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py b/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
deleted file mode 100644
index 69e4e9f2cce8..000000000000
--- a/colossalai/fx/passes/algorithms/ckpt_solver_pofo.py
+++ /dev/null
@@ -1,537 +0,0 @@
-import copy
-import math
-from typing import List, Tuple
-
-import torch
-from colossalai.fx import is_compatible_with_meta
-from colossalai.fx.codegen.activation_checkpoint_codegen import \
- _find_nested_ckpt_regions
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms.ckpt_solver_rotor import (_compute_table, _construct_chain, _rec)
-from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.profiler import parameter_size
-from torch.fx import GraphModule, Node
-
-from .linearize import linearize
-from .operation import (Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Offload, Prefetch,
- Sequence)
-
-INF = float("inf")
-
-
-def _normalize_flops(chain: Chain, flops) -> Chain:
- """
- Normalize flops
- """
- for i in range(chain.length):
- chain.fweight[i] /= flops
- chain.bweight[i] /= flops
-
- return chain
-
-
-class PofoTable:
- """PofoTable
- The PofoTable contains the necessary components to store intermediate results
- of dynamic programming and the operations alone the way.
- """
-
- def __init__(self, chain_length: int, mem_slots: int):
- """Init pofo table
- The pofo table contains two tables, opt and what, indicating values and
- operations.
-
- Args:
- chain_length (int): chain length
- mem_slots (int): number of memory slots
- """
-
- self.length = chain_length
- self.mem_slots = mem_slots
-
- # initializing tables
- # the first bool indicates whether the input has bar
- # opt table is for value, opt[True/False][i][A][(df, db)] = OCx(i, A, df, db)
- # what table is for decision, what[True/False][i][A][(df, db)] = (is_enable, is_offload, index)
- # where is_enable indicates whether we enable the gradient, is_offload indicates whether we
- # offload the input, index indicates the end of F_\empty sequence if is_enable = False
- self.opt = {
- False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
- True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
- }
- self.what = {
- False: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)],
- True: [[{} for _ in range(mem_slots + 1)] for _ in range(self.length + 1)]
- }
-
- def _get_value(self, state, table, default):
- i, act_size, df, db, input_has_bar = state
- if act_size + df > self.mem_slots or act_size + db > self.mem_slots:
- return default
-
- try:
- return table[input_has_bar][i][act_size][(df, db)]
- except KeyError:
- print(f"state not found {state}")
-
- def get_opt(self, state):
- return self._get_value(state, self.opt, INF)
-
- def get_what(self, state):
- return self._get_value(state, self.what, INF)
-
- def set_value(self, state, opt, what):
- i, act_size, df, db, input_has_bar = state
- self.opt[input_has_bar][i][act_size][(df, db)] = opt
- self.what[input_has_bar][i][act_size][(df, db)] = what
-
-
-class PofoSolver:
- """PofoSolver that executes algorithm mentioned in https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
- The new pofo solver is based on paper Efficient Combination of Rematerialization and Offloading for Training DNNs
- and it's code given in the supplemental. Currently we doesn't use the whole set up in the original paper and reuse
- rotor solver for the backward sequence as suggested in supplemental. The solver now is able to find strategy with offload.
- """
-
- def __init__(self, chain: Chain, max_memory: int, bandwidth, mem_slots: int) -> None:
- self.chain = chain
- self.length = chain.length
- self.max_memory = max_memory
- self.mem_slots = mem_slots
- self.mem_unit = max_memory / mem_slots
- self.bandwidth = bandwidth
-
- self.disc_chain = copy.deepcopy(self.chain)
- self.disc_chain._discretize(self.mem_unit)
-
- self.rotor_table = _compute_table(self.disc_chain, mem_slots)
- self._compute_pofo_table()
-
- def _discretize(self, *values) -> Tuple:
- return tuple(math.ceil(value / self.mem_unit) for value in values)
-
- def _undiscretize(self, *discrete_values) -> Tuple:
- if len(discrete_values) == 1:
- return discrete_values[0] * self.mem_unit
- else:
- return tuple(d * self.mem_unit for d in discrete_values)
-
- def _mmax_all(self, idx: int):
- """
- Calculate the maximum memory usage of Fi_all
- """
-
- return self.chain.cbweight[idx + 1] + self.chain.fwd_mem_tmp[idx]
-
- def _mmax_b(self, idx: int):
- """
- Calculate the maximum memory usage of Bi
- """
-
- return self.chain.cbweight[idx +
- 1] + self.chain.cweight[idx +
- 1] + self.chain.cweight[idx] + self.chain.bwd_mem_tmp[idx]
-
- def _mmax_ng(self, i: int, j: int):
- """
- Calculate the maximum memory usage of CF_i, F_i+1\empty, ... F_j\empty
- """
-
- res = self.chain.cweight[j + 1] + self.chain.fwd_mem_tmp[j]
- if j > i:
- res += self.chain.cweight[j]
- return res
-
- def _rotor_estimated_bwd(self, i, j, m, delta):
- compute = self.rotor_table[0][math.floor((m - self.chain.cweight[i]) / self.mem_unit)][i][j]
- comm = delta / self.bandwidth
- return (max(compute, comm) + compute + comm) / 2
-
- def _rotor_estimated_bwd_sequence(self, i, j, m, delta):
- return _rec(self.disc_chain, i, j, math.floor((m - self.chain.cweight[i]) / self.mem_unit), self.rotor_table)
-
- def _common_values_enable(self, state: Tuple):
-
- idx, act_size, df, db, input_has_bar = state
- input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
- mf = act_size + df + input_size
- mb = act_size + db + input_size
- mem_avail = self.max_memory - act_size - input_size
- f_usage = self._mmax_all(idx)
- b_usage = self._mmax_b(idx)
-
- # infeasible
- if f_usage > mem_avail or b_usage > mem_avail:
- return None
-
- # calculate idle time
- eps_f_beta = max(0, f_usage - self.max_memory + mf)
- eps_b_beta = max(0, b_usage - self.max_memory + mb)
- idle_time = (eps_f_beta + eps_b_beta) / self.bandwidth
-
- # calculate offload and prefetch data
- offload_data = self.chain.fweight[idx] * self.bandwidth + eps_f_beta
- prefetch_data = self.chain.bweight[idx] * self.bandwidth + eps_b_beta
-
- # total_time
- total_time = self.chain.fweight[idx] + self.chain.bweight[idx] + idle_time
-
- return (offload_data, prefetch_data, total_time, idle_time)
-
- def _common_values_nograd(self, state: Tuple, j: int, iterative: bool = False):
-
- i, act_size, df, db, input_has_bar = state
-
- # compute new epsilon_tmp and sum_fwds
- if iterative:
- self.epsilon_tmp = max(self.epsilon_tmp, self._mmax_ng(i, j) - self.bandwidth * self.sum_fwds)
- self.sum_fwds += self.chain.fweight[j]
- else:
- self.epsilon_tmp = max(
- self._mmax_ng(i, k) - self.bandwidth * sum(self.chain.fweight[i:k]) for k in range(i, j + 1))
- self.sum_fwds = sum(self.chain.fweight[i:j + 1])
-
- input_size = self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i]
- mf = act_size + df + input_size
- mem_avail = self.max_memory - act_size - input_size
-
- # if infeasible
- if max(self._mmax_ng(i, k) for k in range(i, self.length)) > mem_avail:
- return None
-
- eps_f_beta = max(0, self.epsilon_tmp - self.max_memory + mf)
- offload_data = self.sum_fwds * self.bandwidth + eps_f_beta
-
- # TODO: Implement the precise backward recompute sequence mentioned in the paper
- # currently we will use an approximate way to get the backward time
- time_backward = self._rotor_estimated_bwd(i, j, mem_avail, db)
-
- prefetch_data = time_backward * self.bandwidth
- idle_time = eps_f_beta / self.bandwidth
- total_time = self.sum_fwds + idle_time + time_backward
-
- return (offload_data, prefetch_data, total_time, idle_time)
-
- def _new_values(self, state: Tuple, do_offload: bool, common_values: Tuple) -> Tuple:
- """Generate new values for next state
-
- Args:
- state (Tuple): undiscretized states
- do_offload (bool): bool type indicates whether we need to do offload
- common_values (Tuple): common values (offload_data, prefetch_data, total_time, idle_time)
-
- Returns:
- Tuple: (new_act_size, new_df, new_db)
- """
- idx, act_size, df, db, input_has_bar = state
- offload_data, prefetch_data, *_ = common_values
- input_size = self.chain.cbweight[idx] if input_has_bar else self.chain.cweight[idx]
- if do_offload:
- new_act_size = act_size
- new_df = max(0, df + input_size - offload_data)
- new_db = max(0, db - prefetch_data) + input_size
- else:
- new_act_size = act_size + input_size
- new_df = max(0, df - offload_data)
- new_db = max(0, db - prefetch_data)
-
- return (new_act_size, new_df, new_db)
-
- def _compute_pofo_table(self):
- self.table = PofoTable(self.length, self.mem_slots)
-
- # initializing the loss
- for act_size in range(self.mem_slots + 1):
- for df in range(self.mem_slots - act_size + 1):
- for db in range(self.mem_slots - act_size + 1):
- # undiscretize for idle time calculation
- origin_values = self._undiscretize(act_size, df, db)
-
- for input_has_bar in (False, True):
- disc_state = (self.length, act_size, df, db, input_has_bar)
- state = (self.length, *origin_values, input_has_bar)
- common_values = self._common_values_enable(state)
-
- # if no feasible choice
- if common_values is None:
- self.table.set_value(disc_state, INF, None)
- continue
-
- # if there is feasible choice
- new_act_size, new_df, new_db = self._new_values(state, False, common_values)
- eps_g = (new_df + new_db) / self.bandwidth
- total_time = common_values[2] + eps_g
- self.table.set_value(disc_state, total_time, (True, False))
-
- # main loop
- for i in reversed(range(self.length)):
- for act_size in range(self.mem_slots + 1):
- for df in range(self.mem_slots - act_size + 1):
- for db in range(self.mem_slots - act_size + 1):
- # undiscretize for idle time calculation
- origin_values = self._undiscretize(act_size, df, db)
-
- for input_has_bar in (False, True):
- best_result = INF
- best_choice = None
- disc_state = (i, act_size, df, db, input_has_bar)
- state = (i, *origin_values, input_has_bar)
-
- # case 1: start with F_all
- vals_enable = self._common_values_enable(state)
- if vals_enable is not None:
- for do_offload in (True, False):
- new_state = self._new_values(state, do_offload, vals_enable)
- new_state = (i + 1, *self._discretize(*new_state), True)
- total_time = vals_enable[2]
- results_all = self.table.get_opt(new_state) + total_time
- if results_all < best_result:
- best_result = results_all
- best_choice = (True, do_offload)
-
- # case 2: start with F_ck
- self.sum_fwds = 0
- self.epsilon_tmp = 0
- for j in range(i, self.length):
- vals_nograd = self._common_values_nograd(state, j, True)
-
- # if infeasible
- if vals_nograd is None:
- continue
-
- for do_offload in (True, False):
- new_state = self._new_values(state, do_offload, vals_nograd)
- new_state = (j + 1, *self._discretize(*new_state), False)
- total_time = vals_nograd[2]
- result_nograd = total_time + self.table.get_opt(new_state)
- if result_nograd < best_result:
- best_result = result_nograd
- best_choice = (False, do_offload, j)
-
- self.table.set_value(disc_state, best_result, best_choice)
-
- def pofo_rec(self, disc_state):
- i, act_size, df, db, input_has_bar = disc_state
- result = Sequence(Function("pofo", *disc_state))
- what = self.table.get_what(disc_state)
- state = self._undiscretize(act_size, df, db)
- state = (i, *state, input_has_bar)
- i, act_size, df, db, input_has_bar = state
-
- if what is None:
- return None
-
- # if loss
- if i == self.length:
- result.insert(Loss())
- return result
-
- if what[0]:
- do_offload = what[1]
- values = self._common_values_enable(state)
- new_state = self._discretize(*self._new_values(state, do_offload, values))
- new_state = (i + 1, *new_state, True)
- if do_offload:
- result.insert(Offload(i, input_has_bar))
- result.insert(ForwardEnable(i))
- result.insert_sequence(self.pofo_rec(new_state))
- if do_offload:
- result.insert(Prefetch(i, input_has_bar))
- result.insert(Backward(i))
-
- else:
- _, do_offload, j = what
- values = self._common_values_nograd(state, j)
- new_state = self._discretize(*self._new_values(state, do_offload, values))
- new_state = (j + 1, *new_state, False)
- if do_offload:
- result.insert(Offload(i, input_has_bar))
- result.insert(ForwardCheck(i))
- for k in range(i + 1, j + 1):
- result.insert(ForwardNograd(k))
- result.insert_sequence(self.pofo_rec(new_state))
- if do_offload:
- result.insert(Prefetch(i, input_has_bar))
- m = self.max_memory - act_size - (self.chain.cbweight[i] if input_has_bar else self.chain.cweight[i])
-
- #TODO: Implement the precise backward recompute sequence mentioned in the paper
- result.insert_sequence(self._rotor_estimated_bwd_sequence(i, j, m, db))
-
- return result
-
-
-def _annotate_from_pofo_sequence(sequence: Sequence, node_list: List[List[Node]]):
- op_list = sequence.list_operations()
- loss_op = next(op for op in op_list if isinstance(op, Loss))
- fwd_list = op_list[:op_list.index(loss_op)]
- bwd_list = op_list[op_list.index(loss_op) + 1:]
- ckpt_idx = 0
- in_ckpt = False
- ckpt_region = []
-
- # forward annotation
- for op in fwd_list:
- if in_ckpt:
- if isinstance(op, ForwardNograd):
- ckpt_region.append(op.index)
-
- elif isinstance(op, ForwardEnable):
- in_ckpt = False
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- setattr(n, "activation_checkpoint", [ckpt_idx])
-
- ckpt_idx += 1
- ckpt_region = []
-
- elif isinstance(op, ForwardCheck):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- setattr(n, "activation_checkpoint", [ckpt_idx])
-
- ckpt_idx += 1
- ckpt_region = [op.index]
-
- else:
- if isinstance(op, ForwardCheck):
- in_ckpt = True
- ckpt_region.append(op.index)
-
- # annotate the backward if there is any nested activation checkpoint
- in_recompute = False
- for op in bwd_list:
- if in_recompute:
- if isinstance(op, ForwardNograd):
- ckpt_region.append(op.index)
-
- elif isinstance(op, ForwardEnable):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- ckpt_idx += 1
- ckpt_region = []
-
- elif isinstance(op, ForwardCheck):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- ckpt_idx += 1
- ckpt_region = [op.index]
-
- elif isinstance(op, Backward):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- in_recompute = False
-
- else:
- if not isinstance(op, Backward):
- in_recompute = True
- ckpt_idx = 0
- ckpt_region = []
- if isinstance(op, ForwardCheck):
- ckpt_region.append(op.index)
-
- # postprocess, make sure every activation checkpoint label in the
- # same activation checkpoint region (level = 0) has the same length
- op_list = []
- for node in node_list:
- op_list += node
- ckpt_regions = _find_nested_ckpt_regions(op_list)
- for (start_idx, end_idx) in ckpt_regions:
- nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
- for idx in range(start_idx, end_idx + 1):
- op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
-
- # annotate the offload
- offload_idx = 0
- for idx, op in enumerate(fwd_list):
- if isinstance(op, Offload):
- # corner case: offload input
- if op.index == 0:
- if isinstance(fwd_list[idx + 1], ForwardCheck):
- for n in node_list[op.index]:
- setattr(n, "activation_offload", True)
- else:
- for n in node_list[op.index]:
- setattr(n, "activation_offload", (offload_idx, True, False))
- offload_idx += 1
-
- else:
- if op.has_bar:
- # annotate previous node
- if hasattr(node_list[op.index - 1][0], "activation_offload"):
- for n in node_list[op.index - 1]:
- n.activation_offload[-1] = True
- else:
- for n in node_list[op.index - 1]:
- setattr(n, "activation_offload", [offload_idx, False, True])
-
- offload_idx += 1
-
- # annotate this node
- if isinstance(fwd_list[idx + 1], ForwardCheck):
- for n in node_list[op.index]:
- setattr(n, "activation_offload", True)
- else:
- for n in node_list[op.index]:
- setattr(n, "activation_offload", [offload_idx, True, False])
-
- offload_idx += 1
-
-
-def solver_pofo(gm: ColoGraphModule,
- data,
- bandwidth,
- flops,
- mem_limit: int,
- mem_slots: int = 50,
- cnode: List[str] = None,
- eps: float = 0.0) -> ColoGraphModule:
- """Solver that combine offload and activation checkpoint
- Reference: https://proceedings.neurips.cc/paper/2021/hash/c8461bf13fca8a2b9912ab2eb1668e4b-Abstract.html
-
- Args:
- gm (ColoGraphModule): ColoGraphModule derived from tracer
- data: input of the model
- bandwidth: offload bandwidth, unit Byte/s
- flops: FLOPS of device, unit FLOPs/s
- mem_limit (int): memory limit, unit Byte
- mem_slots (int, optional): number of memory slots. Defaults to 500.
- cnode (List[str], optional): common node for linearize. Defaults to None.
- eps (float, optional): epsilon for memory decay. Defaults to 0.02.
-
- Returns:
- ColoGraphModule: annotated graph module
- """
-
- node_list = linearize(gm, cnode)
- mem_limit -= parameter_size(gm)
-
- # prepare data
- if is_compatible_with_meta():
- from colossalai.fx.profiler import MetaTensor
- data = MetaTensor(data, fake_device=next(gm.parameters()).device)
- MetaInfoProp(gm).run(data)
- chain: Chain = _construct_chain(node_list, data)
- chain = _normalize_flops(chain, flops)
- # currently we view loss as an op without expense
- chain.cbweight.append(0)
- chain.cweight.append(0)
- chain.fwd_mem_tmp.append(0)
- chain.bwd_mem_tmp.append(0)
- chain.fweight.append(0)
- chain.bweight.append(0)
-
- solver = PofoSolver(chain, mem_limit, bandwidth, mem_slots)
- first_state = (0, 0, 0, 0, False)
- sequence = solver.pofo_rec(first_state)
- if sequence == None:
- raise ValueError(f"Cannot solve sequence with {mem_limit} Bytes memory")
-
- _annotate_from_pofo_sequence(sequence, node_list)
- setattr(gm, "__sequence__", sequence)
- return gm
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py b/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
deleted file mode 100644
index 5b8d0da9ffe6..000000000000
--- a/colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+++ /dev/null
@@ -1,436 +0,0 @@
-import math
-import sys
-from typing import List, Tuple
-
-from torch.fx import Node
-
-from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
-from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
-from colossalai.logging import get_dist_logger
-
-from .linearize import linearize
-from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
-
-# global vairable to indicate whether the solver is failed
-SOLVER_FAILED = False
-
-
-# this is the python compute table code from rotor
-# https://gitlab.inria.fr/hiepacs/rotor
-# paper link: https://hal.inria.fr/hal-02352969
-def _compute_table(chain: Chain, mmax) -> Tuple:
- """Returns the optimal table: a tuple containing:
- Opt[m][lmin][lmax] with lmin = 0...chain.length
- and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
- what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
- (False, j) if the optimal choice is a leaf checkpoint of length j
- The computation uses dynamic programming"""
-
- fw = chain.fweight + [0] ## forward time
- bw = chain.bweight ## backward time, not used
- cw = chain.cweight + [0] ## size of x (and of y)
- cbw = chain.cbweight + [0] ## size of xbar
- fwd_mem_tmp = chain.fwd_mem_tmp + [0]
- bwd_mem_tmp = chain.bwd_mem_tmp + [0]
-
- # Build table
- opt = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
- what = [[{} for _ in range(chain.length + 1)] for _ in range(mmax + 1)]
- # Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
-
- # Initialize borders of the tables for lmax-lmin = 0
- for m in range(mmax + 1):
- for i in range(chain.length + 1):
- #lmax-lmin = 0
- limit = max(cw[i + 1] + cbw[i + 1] + fwd_mem_tmp[i], cw[i + 1] + cbw[i + 1] + bwd_mem_tmp[i])
- if m >= limit: ## Equation (1)
- opt[m][i][i] = fw[i] + bw[i]
- else:
- opt[m][i][i] = float("inf")
-
- # Compute everything
- for m in range(mmax + 1):
- for d in range(1, chain.length + 1):
- for i in range(chain.length + 1 - d):
- # for idx in range(i+1, chain.length + 1):
- idx = i + d
- mmin = cw[idx + 1] + cw[i + 1] + fwd_mem_tmp[i]
- if idx > i + 1:
- mmin = max(mmin, cw[idx + 1] + max(cw[j] + cw[j + 1] + fwd_mem_tmp[j] for j in range(i + 1, idx)))
- if m < mmin:
- opt[m][i][idx] = float("inf")
- else:
- leaf_checkpoints = [(j, sum(fw[i:j]) + opt[m - cw[j]][j][idx] + opt[m][i][j - 1])
- for j in range(i + 1, idx + 1)
- if m >= cw[j]]
- if leaf_checkpoints:
- best_leaf = min(leaf_checkpoints, key=lambda t: t[1])
- else:
- best_leaf = None
- if m >= cbw[i + 1]:
- chain_checkpoint = opt[m][i][i] + opt[m - cbw[i + 1]][i + 1][idx]
- else:
- chain_checkpoint = float("inf")
- if best_leaf and best_leaf[1] <= chain_checkpoint:
- opt[m][i][idx] = best_leaf[1]
- what[m][i][idx] = (False, best_leaf[0])
- else:
- opt[m][i][idx] = chain_checkpoint
- what[m][i][idx] = (True,)
- return (opt, what)
-
-
-def _rec(chain: Chain, lmin, lmax, cmem, opt_table):
- """ chain : the class describing the AC graph
- lmin : index of the first forward to execute
- lmax : upper bound index of the last forward to execute (not included)
- cmem : number of available memory slots
- Return the optimal sequence of makespan Opt_hete[cmem][lmin][lmax-lmin]"""
- if cmem <= 0:
- raise ValueError("Can not process a chain with negative memory {cmem}".format(cmem=cmem))
- opt, what = opt_table
- sequence = Sequence(Function("Persistent", lmax - lmin, cmem))
- if opt[cmem][lmin][lmax] == float("inf"):
- # using logger to annonce that the solver is failed
- logger = get_dist_logger()
- logger.info("Can not process this chain from index {lmin} to {lmax} with memory {cmem}".format(lmin=lmin,
- lmax=lmax,
- cmem=cmem))
-
- # set global indicater SOLVER_FAILED to True
- global SOLVER_FAILED
- SOLVER_FAILED = True
- return sequence
-
- if lmin == lmax:
- if lmin == chain.length:
- sequence.insert(Loss())
- else:
- sequence.insert(ForwardEnable(lmin))
- sequence.insert(Backward(lmin))
- return sequence
-
- if what[cmem][lmin][lmax][0]:
- sequence.insert(ForwardEnable(lmin))
- sequence.insert_sequence(_rec(chain, lmin + 1, lmax, cmem - chain.cbweight[lmin + 1], opt_table))
- sequence.insert(Backward(lmin))
- else:
- j = what[cmem][lmin][lmax][1]
- sequence.insert(ForwardCheck(lmin))
- for k in range(lmin + 1, j):
- sequence.insert(ForwardNograd(k))
- sequence.insert_sequence(_rec(chain, j, lmax, cmem - chain.cweight[j], opt_table))
- sequence.insert_sequence(_rec(chain, lmin, j - 1, cmem, opt_table))
- return sequence
-
-
-def _fwd_xbar(node: List[Node]) -> int:
- """Get the forward xbar of a node
-
- Args:
- node (List[Node]): List of torch.fx Node,
- indicates a node in linearized graph
-
- Returns:
- int: xbar size, unit Byte
- """
-
- xbar = 0
- for n in node:
- xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
- return xbar
-
-
-def _fwd_time(node: List[Node]) -> int:
- """Get the foward time of a node
-
- Args:
- node (List[Node]): List of torch.fx Node,
- indicates a node in linearized graph
-
- Returns:
- int: foward time, extimated by flops count
- """
-
- fwd_time = 0
- for n in node:
- # minimum flop count is needed
- fwd_time += max(n.meta['fwd_flop'], 1)
- return fwd_time
-
-
-def _bwd_time(node: List[Node]) -> int:
- """Get the backward time of a node
-
- Args:
- node (List[Node]): List of torch.fx Node,
- indicates a node in linearized graph
-
- Returns:
- int: backward time, extimated by flops count
- """
-
- bwd_time = 0
- for n in node:
- # minimum flop count is needed
- bwd_time += max(n.meta['bwd_flop'], 1)
- return bwd_time
-
-
-def _get_fwd_mem_tmp(node: List[Node]) -> int:
- """Get the forward temp memory of a node
- This could be done by subtracting the saved activation from all output of a node
-
- Args:
- node (List[Node]): List of torch.fx Node,
- indicates a node in linearized graph
-
- Returns:
- int: forward temp memory, unit Byte
- """
- n = node[-1]
- return activation_size(n.meta['fwd_out']) - calculate_fwd_out(n)
-
-
-def _get_bwd_mem_tmp(node: List[Node]) -> int:
- """Get the backward temp memory of a node
-
- Args:
- node (List[Node]): List of torch.fx Node,
- indicates a node in linearized graph
-
- Returns:
- int: backward temp memory, unit Byte
- """
-
- def _get_deps_size():
- deps_size = 0
- for k, v in deps.items():
- k: Node
- if v > 0:
- deps_size += k.meta['bwd_mem_out']
- if v == float('-inf'):
- deps_size -= calculate_fwd_tmp(k) + calculate_fwd_out(k)
-
- return deps_size
-
- bwd_mem_tmp = 0
- deps = {}
-
- for n in reversed(node):
- deps[n] = len(n.all_input_nodes)
- bwd_mem_tmp = max(bwd_mem_tmp, _get_deps_size() + n.meta['bwd_mem_tmp'])
-
- for child in n.users:
- if child in deps:
- deps[child] -= 1
- if deps[child] <= 0:
- deps[child] = float('-inf') # free
-
- return bwd_mem_tmp
-
-
-def _construct_chain(node_list: List[List[Node]], input) -> Chain:
-
- fwd_time = []
- bwd_time = []
- xbar_sizes = [activation_size(input)]
- x_sizes = [activation_size(input)]
- tmp_fwd = []
- tmp_bwd = []
-
- for idx, node in enumerate(node_list):
- fwd_time.append(_fwd_time(node))
- bwd_time.append(_bwd_time(node))
- x_sizes.append(calculate_fwd_out(node[-1]))
- xbar_sizes.append(max(x_sizes[-1], _fwd_xbar(node)))
- tmp_fwd.append(_get_fwd_mem_tmp(node))
- tmp_bwd.append(_get_bwd_mem_tmp(node))
-
- bwd_time.append(0)
-
- # currently we view loss backward temp as zero
- tmp_bwd.append(0)
-
- return Chain(fwd_time, bwd_time, x_sizes, xbar_sizes, tmp_fwd, tmp_bwd)
-
-
-def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
- op_list = sequence.list_operations()
- loss_op = next(op for op in op_list if isinstance(op, Loss))
- fwd_list = op_list[:op_list.index(loss_op)]
- bwd_list = op_list[op_list.index(loss_op) + 1:]
- ckpt_idx = 0
- in_ckpt = False
- ckpt_region = []
-
- # forward annotation
- for idx, op in enumerate(fwd_list, 0):
- if in_ckpt:
- if isinstance(op, ForwardNograd):
- ckpt_region.append(idx)
-
- elif isinstance(op, ForwardEnable):
- in_ckpt = False
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- setattr(n, "activation_checkpoint", [ckpt_idx])
-
- ckpt_idx += 1
- ckpt_region = []
-
- elif isinstance(op, ForwardCheck):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- setattr(n, "activation_checkpoint", [ckpt_idx])
-
- ckpt_idx += 1
- ckpt_region = [idx]
-
- else:
- if isinstance(op, ForwardCheck):
- in_ckpt = True
- ckpt_region.append(idx)
-
- # annotate the backward if there is any nested activation checkpoint
- in_recompute = False
- for op in bwd_list:
- if in_recompute:
- if isinstance(op, ForwardNograd):
- ckpt_region.append(op.index)
-
- elif isinstance(op, ForwardEnable):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- ckpt_idx += 1
- ckpt_region = []
-
- elif isinstance(op, ForwardCheck):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- ckpt_idx += 1
- ckpt_region = [op.index]
-
- elif isinstance(op, Backward):
- for node_idx in ckpt_region:
- for n in node_list[node_idx]:
- n.activation_checkpoint.append(ckpt_idx)
-
- in_recompute = False
-
- else:
- if not isinstance(op, Backward):
- in_recompute = True
- ckpt_idx = 0
- ckpt_region = []
- if isinstance(op, ForwardCheck):
- ckpt_region.append(op.index)
-
- # postprocess, make sure every activation checkpoint label in the
- # same activation checkpoint region (level = 0) has the same length
- op_list = []
- for node in node_list:
- op_list += node
- ckpt_regions = _find_nested_ckpt_regions(op_list)
- for (start_idx, end_idx) in ckpt_regions:
- nested_length = max(len(op_list[idx].activation_checkpoint) for idx in range(start_idx, end_idx + 1))
- for idx in range(start_idx, end_idx + 1):
- op_list[idx].activation_checkpoint += [None] * (nested_length - len(op_list[idx].activation_checkpoint))
-
-
-def solver_rotor(gm: ColoGraphModule,
- data,
- mem_limit: int,
- mem_slots: int = 500,
- cnode: List[str] = None,
- eps: float = 0.0,
- force_python: bool = False) -> ColoGraphModule:
- """solver that automatically find activation checkpoint in rotor's manner
-
- Args:
- gm (ColoGraphModule): ColoGraphModule generated by tracing model and MetaInfoProp.
- data (torch.Tensor): input data.
- mem_limit (int): memory budget in Byte.
- mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
- cnode (List[Node], optional): common node list for linearize. Defaults to None.
- eps (float): epsilon for memory decay. Defaults to 0.0
- force_python (bool): force to use python version of dynamic programs
-
- Returns:
- ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
- """
-
- # try to import C version solver if force_python is not set
- logger = get_dist_logger()
- if not force_python:
- try:
- from .dynamic_programs_C_version import persistent_compute_table
- CVERSION = True
-
- # build module if module not found
- except ModuleNotFoundError:
- import os
- import subprocess
- logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
- this_dir = os.path.dirname(os.path.abspath(__file__))
- result = subprocess.Popen(
- [
- f"{sys.executable}", f"{os.path.join(this_dir, 'build_c_ext.py')}", "build_ext",
- f"--build-lib={this_dir}"
- ],
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
- if result.wait() == 0:
- logger.info("dynamic_programs_C_version has been built!", ranks=[0])
- from .dynamic_programs_C_version import persistent_compute_table
- CVERSION = True
- else:
- logger.info("dynamic_programs_C_version built failed! Using python version!", ranks=[0])
- CVERSION = False
- else:
- CVERSION = False
-
- # check if metainfoprop is done
- if any(len(node.meta) == 0 for node in gm.graph.nodes):
- raise RuntimeError(
- "Nodes meta information hasn't been prepared! Please run MetaInfoProp before calling solver!")
-
- # linearize the graph
- node_list = linearize(gm, cnode)
-
- # construct chain
- mem_unit = mem_limit * (1.0 - eps) // mem_slots
- chain: Chain = _construct_chain(node_list, data)
- chain._discretize(mem_unit)
-
- # use C version if possible
- if CVERSION and not force_python:
- logger.info("Using C version rotor solver!", ranks=[0])
- opt_table = persistent_compute_table(chain, mem_slots)
- else:
- opt_table = _compute_table(chain, mem_slots)
- logger.info("Using python version rotor solver!", ranks=[0])
-
- # found sequence
- sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
-
- # if solver failed, we don't need to annotate the graph
- if not SOLVER_FAILED:
- _annotate_from_sequence(sequence, node_list)
-
- # set __sequence__ attribute to GraphModule
- if SOLVER_FAILED:
- setattr(gm, "__sequence__", None)
- else:
- setattr(gm, "__sequence__", sequence)
-
- # set __opttable__ attribute to GraphModule
- setattr(gm, "__opttable__", opt_table[0])
- gm.recompile()
- return gm
diff --git a/colossalai/fx/passes/algorithms/dynamic_programs.c b/colossalai/fx/passes/algorithms/dynamic_programs.c
deleted file mode 100644
index 3efad58400fa..000000000000
--- a/colossalai/fx/passes/algorithms/dynamic_programs.c
+++ /dev/null
@@ -1,516 +0,0 @@
-#define PY_SSIZE_T_CLEAN
-#include
-
-long* PySequenceToLongArray(PyObject* pylist) {
- if (!(pylist && PySequence_Check(pylist))) return NULL;
- Py_ssize_t len = PySequence_Size(pylist);
- long* result = (long*)calloc(len + 1, sizeof(long));
- for (Py_ssize_t i = 0; i < len; ++i) {
- PyObject* item = PySequence_GetItem(pylist, i);
- result[i] = PyLong_AsLong(item);
- Py_DECREF(item);
- }
- result[len] = 0;
- return result;
-}
-
-double* PySequenceToDoubleArray(PyObject* pylist) {
- if (!(pylist && PySequence_Check(pylist))) return NULL;
- Py_ssize_t len = PySequence_Size(pylist);
- double* result = (double*)calloc(len + 1, sizeof(double));
- for (Py_ssize_t i = 0; i < len; ++i) {
- PyObject* item = PySequence_GetItem(pylist, i);
- result[i] = PyFloat_AsDouble(item);
- Py_DECREF(item);
- }
- result[len] = 0;
- return result;
-}
-
-long* getLongArray(PyObject* container, const char* attributeName) {
- PyObject* sequence = PyObject_GetAttrString(container, attributeName);
- long* result = PySequenceToLongArray(sequence);
- Py_DECREF(sequence);
- return result;
-}
-
-double* getDoubleArray(PyObject* container, const char* attributeName) {
- PyObject* sequence = PyObject_GetAttrString(container, attributeName);
- double* result = PySequenceToDoubleArray(sequence);
- Py_DECREF(sequence);
- return result;
-}
-
-static PyObject* persistent_compute_table(PyObject* self, PyObject* args) {
- PyObject* chain_param;
- int mmax;
-
- if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
-
- double* fw = getDoubleArray(chain_param, "fweight");
- if (!fw) return NULL;
-
- double* bw = getDoubleArray(chain_param, "bweight");
- if (!bw) return NULL;
-
- long* cw = getLongArray(chain_param, "cweight");
- if (!cw) return NULL;
-
- long* cbw = getLongArray(chain_param, "cbweight");
- if (!cbw) return NULL;
-
- long* fwd_tmp = getLongArray(chain_param, "fwd_mem_tmp");
- if (!cbw) return NULL;
-
- long* bwd_tmp = getLongArray(chain_param, "bwd_mem_tmp");
- if (!cbw) return NULL;
-
- PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
- if (!chain_length_param) return NULL;
- long chain_length = PyLong_AsLong(chain_length_param);
- Py_DECREF(chain_length_param);
-
- // TODO: Can be optimized by only allocating memory for l >= i
- // TODO: float / int instead of double / long ?
-#define OPT(m, i, l) \
- opt[(m) * (chain_length + 1) * (chain_length + 1) + \
- (i) * (chain_length + 1) + (l)]
- double* opt = (double*)calloc(
- (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
-
-#define WHAT(m, i, l) \
- what[(m) * (chain_length + 1) * (chain_length + 1) + \
- (i) * (chain_length + 1) + (l)]
- long* what = (long*)calloc(
- (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(long));
-
- for (long m = 0; m <= mmax; ++m)
- for (long i = 0; i <= chain_length; ++i)
- // TODO: Can be optimized to remove the IF by reordering loops
- if ((m >= cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
- (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
- OPT(m, i, i) = fw[i] + bw[i];
- else
- OPT(m, i, i) = INFINITY;
-
- for (long m = 0; m <= mmax; ++m)
- for (long d = 1; d <= chain_length; ++d) {
- for (long i = 0; i <= chain_length - d; ++i) {
- long idx = i + d;
- long mmin = cw[idx + 1] + cw[i + 1] + fwd_tmp[i];
- if (idx > i + 1) {
- long maxCostFWD = 0;
- for (long j = i + 1; j < idx; j++) {
- maxCostFWD = fmaxl(maxCostFWD, cw[j] + cw[j + 1] + fwd_tmp[j]);
- }
- mmin = fmaxl(mmin, cw[idx + 1] + maxCostFWD);
- }
- if ((m >= mmin)) {
- long bestLeaf = -1;
- double sumFw = 0;
- double bestLeafCost = INFINITY;
- /// sumFw + OPT(m-cw[i+1], i+1, l) + OPT(m, i, i); // Value for j =
- /// i+1
- for (long j = i + 1; j <= idx; ++j) {
- sumFw += fw[j - 1];
- if (m >= cw[j]) {
- double cost = sumFw + OPT(m - cw[j], j, idx) + OPT(m, i, j - 1);
- if (cost < bestLeafCost) {
- bestLeafCost = cost;
- bestLeaf = j;
- }
- }
- }
- double chainCost = INFINITY;
- if (m >= cbw[i + 1])
- chainCost = OPT(m, i, i) + OPT(m - cbw[i + 1], i + 1, idx);
- if (bestLeafCost <= chainCost) {
- OPT(m, i, idx) = bestLeafCost;
- WHAT(m, i, idx) = bestLeaf;
- } else {
- OPT(m, i, idx) = chainCost;
- WHAT(m, i, idx) = -1;
- }
- } else
- OPT(m, i, idx) = INFINITY;
- }
- }
-
- free(fw);
- free(bw);
- free(cw);
- free(cbw);
- free(fwd_tmp);
- free(bwd_tmp);
-
- PyObject* res_opt = PyList_New(mmax + 1);
- PyObject* res_what = PyList_New(mmax + 1);
-
- // Convert the result into Python world
- for (long m = 0; m <= mmax; ++m) {
- PyObject* res_opt_m = PyList_New(chain_length + 1);
- PyList_SET_ITEM(res_opt, m, res_opt_m);
- PyObject* res_what_m = PyList_New(chain_length + 1);
- PyList_SET_ITEM(res_what, m, res_what_m);
- for (long i = 0; i <= chain_length; ++i) {
- PyObject* res_opt_m_i = PyDict_New();
- PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
- PyObject* res_what_m_i = PyDict_New();
- PyList_SET_ITEM(res_what_m, i, res_what_m_i);
- for (long l = i; l <= chain_length; ++l) {
- PyObject* res_l = PyLong_FromLong(l);
- PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
- PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
- Py_DECREF(res_opt_m_i_l);
- PyObject* res_what_m_i_l;
- long what_m_i_l = WHAT(m, i, l);
- if (what_m_i_l < 0)
- res_what_m_i_l = Py_BuildValue("(O)", Py_True);
- else
- res_what_m_i_l = Py_BuildValue("(Ol)", Py_False, what_m_i_l);
- PyDict_SetItem(res_what_m_i, res_l, res_what_m_i_l);
- Py_DECREF(res_what_m_i_l);
- Py_DECREF(res_l);
- }
- }
- }
-
- free(opt);
- free(what);
-
- PyObject* result = PyTuple_Pack(2, res_opt, res_what);
- Py_DECREF(res_opt);
- Py_DECREF(res_what);
- return result;
-}
-
-// long i = L - s, j = t - s, k = l - t
-inline long floating_index_in_array(long m_factor, long m, long i, long j,
- long k) {
- return m * m_factor + (i * (i + 1) * (2 * i + 4)) / 12 + (i + 1) * j -
- (j * (j - 1)) / 2 + k;
-}
-
-typedef struct {
- long sp;
- long r;
- long tp;
-} index_t;
-
-static PyObject* floating_compute_table(PyObject* self, PyObject* args) {
- PyObject* chain_param;
- int mmax;
-
- if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
-
- double* fw = getDoubleArray(chain_param, "fweigth");
- if (!fw) return NULL;
-
- double* bw = getDoubleArray(chain_param, "bweigth");
- if (!bw) return NULL;
-
- long* cw = getLongArray(chain_param, "cweigth");
- if (!cw) return NULL;
-
- long* cbw = getLongArray(chain_param, "cbweigth");
- if (!cbw) return NULL;
-
- long* fwd_tmp = getLongArray(chain_param, "fwd_tmp");
- if (!fwd_tmp) return NULL;
-
- long* bwd_tmp = getLongArray(chain_param, "bwd_tmp");
- if (!bwd_tmp) return NULL;
-
- PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
- if (!chain_length_param) return NULL;
- long chain_length = PyLong_AsLong(chain_length_param);
- Py_DECREF(chain_length_param);
-
- const long m_factor =
- (chain_length + 1) * (chain_length + 2) * (2 * chain_length + 6) / 12;
-
- // Defined for 0 <= s <= t <= l <= chain_length, for all m
-#undef OPT
-#define OPT(m, s, t, l) \
- opt[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
- (l) - (t))]
- double* opt = (double*)calloc((mmax + 1) * m_factor, sizeof(double));
-
-#undef WHAT
-#define WHAT(m, s, t, l) \
- what[floating_index_in_array(m_factor, (m), chain_length - (s), (t) - (s), \
- (l) - (t))]
- index_t* what = (index_t*)calloc((mmax + 1) * m_factor, sizeof(index_t));
-
- double* partialSumsFW = (double*)calloc(chain_length + 1, sizeof(double));
- double total = 0;
- for (long i = 0; i < chain_length; ++i) {
- partialSumsFW[i] = total;
- total += fw[i];
- }
- partialSumsFW[chain_length] = total;
-
- for (long m = 0; m <= mmax; ++m)
- for (long i = 0; i <= chain_length; ++i) {
- // TODO: Can be optimized to remove the IF by reordering loops
- if ((m >= cw[i] + cw[i + 1] + cbw[i + 1] + bwd_tmp[i]) &&
- (m >= cw[i + 1] + cbw[i + 1] + fwd_tmp[i]))
- OPT(m, i, i, i) = fw[i] + bw[i];
- else
- OPT(m, i, i, i) = INFINITY;
- }
-
- for (long m = 0; m <= mmax; ++m)
- for (long d = 1; d <= chain_length; ++d) { // d = l - s
- for (long s = 0; s <= chain_length - d; ++s) {
- long l = s + d;
- long memNullFirst = cw[l + 1] + cw[s + 1] + fwd_tmp[s];
- long memNullSecond = 0;
- for (long j = s + 1; j < l; ++j) {
- long val = cw[j] + cw[j + 1] + fwd_tmp[j];
- if (val > memNullSecond) memNullSecond = val;
- }
- for (long t = s; t <= l; ++t) {
- double chainCost = INFINITY;
- if ((s == t) && (m >= cw[l + 1] + cbw[s + 1] + fwd_tmp[s]) &&
- (m >= cw[s] + cw[s + 1] + cbw[s + 1] + bwd_tmp[s])) {
- chainCost = OPT(m, s, s, s) + OPT(m - cbw[s + 1], s + 1, s + 1, l);
- }
- double bestLeafCost = INFINITY;
- index_t bestLeaf = {.sp = -1, .r = -1, .tp = -1};
- if (m >= memNullFirst && m >= cw[l + 1] + memNullSecond) {
- for (long r = s; r <= t; ++r)
- if (cw[s] <= cw[r])
- for (long tp = t + 1; tp <= l; ++tp)
- for (long sp = r + 1; sp <= tp; ++sp) {
- long mp = m - cw[r] + cw[s];
- assert(mp >= 0);
- if (mp >= cw[sp]) {
- double value = partialSumsFW[sp] - partialSumsFW[s] +
- OPT(mp - cw[sp], sp, tp, l) +
- OPT(mp, r, t, tp - 1);
- if (value < bestLeafCost) {
- bestLeafCost = value;
- bestLeaf.sp = sp;
- bestLeaf.r = r;
- bestLeaf.tp = tp;
- }
- }
- }
- }
- if (bestLeaf.sp >= 0 && bestLeafCost <= chainCost) {
- OPT(m, s, t, l) = bestLeafCost;
- WHAT(m, s, t, l).sp = bestLeaf.sp;
- WHAT(m, s, t, l).r = bestLeaf.r;
- WHAT(m, s, t, l).tp = bestLeaf.tp;
- } else {
- OPT(m, s, t, l) = chainCost;
- WHAT(m, s, t, l).sp = -1;
- }
- }
- }
- }
-
- free(fw);
- free(bw);
- free(cw);
- free(cbw);
- free(fwd_tmp);
- free(bwd_tmp);
-
- PyObject* res_opt = PyList_New(mmax + 1);
- PyObject* res_what = PyList_New(mmax + 1);
-
- // Convert the result into Python world
- PyObject* true_tuple = Py_BuildValue("(O)", Py_True);
- for (long m = 0; m <= mmax; ++m) {
- PyObject* res_opt_m = PyDict_New();
- PyList_SET_ITEM(res_opt, m, res_opt_m);
- PyObject* res_what_m = PyDict_New();
- PyList_SET_ITEM(res_what, m, res_what_m);
- for (long s = 0; s <= chain_length; ++s)
- for (long t = s; t <= chain_length; ++t)
- for (long l = t; l <= chain_length; ++l) {
- PyObject* key = Py_BuildValue("(lll)", s, t, l);
- PyObject* value_opt = PyFloat_FromDouble(OPT(m, s, t, l));
- PyDict_SetItem(res_opt_m, key, value_opt);
- PyObject* value_what = true_tuple;
- index_t* idx_what = &WHAT(m, s, t, l);
- if (idx_what->sp >= 0)
- value_what = Py_BuildValue("(O(lll))", Py_False, idx_what->sp,
- idx_what->r, idx_what->tp);
- PyDict_SetItem(res_what_m, key, value_what);
- if (value_what != true_tuple) Py_DECREF(value_what);
- Py_DECREF(key);
- Py_DECREF(value_opt);
- }
- }
-
- Py_DECREF(true_tuple);
-
- free(opt);
- free(what);
-
- PyObject* result = PyTuple_Pack(2, res_opt, res_what);
- Py_DECREF(res_opt);
- Py_DECREF(res_what);
- return result;
-}
-
-static PyObject* griewank_heterogeneous_compute_table(PyObject* self,
- PyObject* args) {
- PyObject* chain_param;
- int mmax;
-
- if (!PyArg_ParseTuple(args, "Oi", &chain_param, &mmax)) return NULL;
-
- double* fw = getDoubleArray(chain_param, "fweigth");
- if (!fw) return NULL;
-
- double* bw = getDoubleArray(chain_param, "bweigth");
- if (!bw) return NULL;
-
- long* cw = getLongArray(chain_param, "cweigth");
- if (!cw) return NULL;
-
- long* cbw = getLongArray(chain_param, "cbweigth");
- if (!cbw) return NULL;
-
- PyObject* chain_length_param = PyObject_GetAttrString(chain_param, "length");
- if (!chain_length_param) return NULL;
- long chain_length = PyLong_AsLong(chain_length_param);
- Py_DECREF(chain_length_param);
-
- // TODO: Can be optimized by only allocating memory for l >= i
- // TODO: float / int instead of double / long ?
-#undef OPT
-#define OPT(m, i, l) \
- opt[(m) * (chain_length + 1) * (chain_length + 1) + \
- (i) * (chain_length + 1) + (l)]
- double* opt = (double*)calloc(
- (mmax + 1) * (chain_length + 1) * (chain_length + 1), sizeof(double));
-
- // Compute partial sums
- double* sumfw = (double*)calloc(chain_length, sizeof(double));
- double* sumbw = (double*)calloc(chain_length + 1, sizeof(double));
- double* sumsumfw = (double*)calloc(chain_length, sizeof(double));
-
- double total = 0;
- for (long i = 0; i < chain_length; ++i) {
- total += fw[i];
- sumfw[i] = total;
- }
-
- total = 0;
- for (long i = 0; i < chain_length + 1; ++i) {
- total += bw[i];
- sumbw[i] = total;
- }
-
- total = 0;
- for (long i = 0; i < chain_length; ++i) {
- total += sumfw[i];
- sumsumfw[i] = total;
- }
-
- for (long m = 0; m <= mmax; ++m)
- for (long i = 0; i <= chain_length; ++i) {
- // TODO: Can be optimized to remove the IF by reordering loops
- if ((m >= cbw[i]) && (m >= cw[i] + cbw[i + 1]))
- OPT(m, i, i) = bw[i];
- else
- OPT(m, i, i) = INFINITY;
-
- if (i < chain_length) {
- long maxC = fmaxl(cw[i], cw[i + 1]);
- long maxCB = fmaxl(cbw[i + 1], cbw[i + 2] + maxC);
- if ((m >= cbw[i]) && (m >= cw[i] + maxCB))
- OPT(m, i, i + 1) = fw[i] + bw[i] + bw[i + 1];
- else
- OPT(m, i, i + 1) = INFINITY;
- }
- }
-
- for (long m = 0; m <= mmax; ++m)
- for (long i = 0; i + 2 <= chain_length; ++i) {
- long mminCst = fmaxl(cbw[i], cbw[i + 1] + cw[i]);
- long maxCW_il = fmax(fmax(cw[i], cw[i + 1]), cw[i + 2]);
- long maxCostFWD = cw[i] + cbw[i + 2] + maxCW_il;
- for (long l = i + 2; l <= chain_length; ++l) {
- maxCW_il = fmax(maxCW_il, cw[l + 1]);
- maxCostFWD = fmaxl(maxCostFWD, cw[i] + cw[l + 1] + maxCW_il);
- long mmin = fmaxl(mminCst, maxCostFWD);
- if ((m >= mmin)) {
- double noCheckpointCost = sumbw[l] - (i > 0 ? sumbw[i - 1] : 0);
- noCheckpointCost +=
- sumsumfw[l - 1] -
- (i > 0 ? sumsumfw[i - 1] + (l - i) * sumfw[i - 1] : 0);
-
- double valueCost = INFINITY;
- if (m >= cw[i]) {
- double sumFwds = 0;
- for (long j = i + 1; j < l; ++j) {
- sumFwds += fw[j - 1];
- valueCost = fmin(
- valueCost, sumFwds + OPT(m - cw[i], j, l) + OPT(m, i, j - 1));
- }
- }
- OPT(m, i, l) = fmin(noCheckpointCost, valueCost);
- } else
- OPT(m, i, l) = INFINITY;
- }
- }
-
- free(sumfw);
- free(sumbw);
- free(sumsumfw);
- free(fw);
- free(bw);
- free(cw);
- free(cbw);
-
- PyObject* res_opt = PyList_New(mmax + 1);
-
- // Convert the result into Python world
- for (long m = 0; m <= mmax; ++m) {
- PyObject* res_opt_m = PyList_New(chain_length + 1);
- PyList_SET_ITEM(res_opt, m, res_opt_m);
- for (long i = 0; i <= chain_length; ++i) {
- PyObject* res_opt_m_i = PyDict_New();
- PyList_SET_ITEM(res_opt_m, i, res_opt_m_i);
- for (long l = i; l <= chain_length; ++l) {
- PyObject* res_l = PyLong_FromLong(l - i);
- PyObject* res_opt_m_i_l = PyFloat_FromDouble(OPT(m, i, l));
- PyDict_SetItem(res_opt_m_i, res_l, res_opt_m_i_l);
- Py_DECREF(res_opt_m_i_l);
- Py_DECREF(res_l);
- }
- }
- }
-
- free(opt);
-
- return res_opt;
-}
-
-static PyMethodDef dynamic_programs_methods[] = {
- {"persistent_compute_table", persistent_compute_table, METH_VARARGS,
- "Compute the optimal table with the persistent algorithm."},
- {"floating_compute_table", floating_compute_table, METH_VARARGS,
- "Compute the optimal table with the floating algorithm."},
- {"griewank_heterogeneous_compute_table",
- griewank_heterogeneous_compute_table, METH_VARARGS,
- "Compute the optimal table for the Griewank Heterogeneous Model."},
- {NULL, NULL, 0, NULL} /* Sentinel */
-};
-
-static struct PyModuleDef dynamic_programs_module = {
- PyModuleDef_HEAD_INIT, "dynamic_programs_C_version", /* name of module */
- NULL, /* module documentation, may be NULL */
- -1, /* size of per-interpreter state of the module,
- or -1 if the module keeps state in global variables. */
- dynamic_programs_methods};
-
-PyMODINIT_FUNC PyInit_dynamic_programs_C_version(void) {
- return PyModule_Create(&dynamic_programs_module);
-}
diff --git a/colossalai/fx/passes/algorithms/linearize.py b/colossalai/fx/passes/algorithms/linearize.py
deleted file mode 100644
index 1a49364f5a7c..000000000000
--- a/colossalai/fx/passes/algorithms/linearize.py
+++ /dev/null
@@ -1,94 +0,0 @@
-from typing import List, Any
-from torch.fx import GraphModule, Node
-from colossalai.fx.profiler import is_inplace
-
-# Common nodes are type of nodes that could be seen as attributes and remain
-# unchanged throughout the whole model, it will be used several times by
-# different blocks of model, so that it is hard for us to linearize the graph
-# when we encounter those kinds of nodes. We let users to annotate some of the
-# input as common node, such as attention mask, and the followings are some of
-# the ops that could actually be seen as common nodes. With our common node prop,
-# we could find some of the "real" common nodes (e.g. the real attention mask
-# used in BERT and GPT), the rule is simple, for node who's parents are all common
-# nodes or it's op belongs to the following operations, we view this node as a
-# newly born common node.
-# List of target name that could be seen as common node
-COPS = ["getattr", "getitem", "size"]
-
-
-def _is_cop(target: Any) -> bool:
- """Check if an op could be seen as common node
-
- Args:
- target (Any): node target
-
- Returns:
- bool
- """
-
- if isinstance(target, str):
- return target in COPS
- else:
- return target.__name__ in COPS
-
-
-def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
- """Linearizing the graph
-
- Args:
- gm (GraphModule): GraphModule derived by tracing
- cnode (List[str], optional): common node List, should be the subset of input. Default to None.
-
- Returns:
- List[List[Node]]: List of list, each inside list of Node presents
- the actual 'node' in linearized manner.
-
- Remarks:
- We merge the inplace ops into the previous node.
- """
-
- def _is_sink() -> bool:
- """Check if we can free all dependencies
-
- Returns:
- bool
- """
-
- return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
-
- # make sure that item in cnode is valid
- if cnode:
- for name in cnode:
- try:
- assert next(node for node in gm.graph.nodes if node.name == name).op == "placeholder", \
- f"common node {name} is not an input of the model"
- except StopIteration:
- raise ValueError(f"common node name {name} not in graph")
-
- else:
- cnode = []
-
- deps = {}
- linearized_nodes = []
- region = []
-
- for n in gm.graph.nodes:
- if n.op != "placeholder" and n.op != "output":
- for n_par in n._input_nodes:
- if n_par.op != "placeholder" and n_par.name not in cnode:
- deps[n_par] -= 1
- region.append(n)
-
- # if the node could free all dependencies in graph
- # we could begin a new node
- if _is_sink():
- linearized_nodes.append(region)
- region = []
-
- # propagate common node attr if possible
- if len(n._input_nodes) == len([node for node in n._input_nodes if node.name in cnode]) or _is_cop(n.target):
- cnode.append(n.name)
- else:
- deps[n] = len([user for user in n.users if user.op != "output"])
-
- return linearized_nodes
diff --git a/colossalai/fx/passes/algorithms/operation.py b/colossalai/fx/passes/algorithms/operation.py
deleted file mode 100644
index 8bfa3452ba64..000000000000
--- a/colossalai/fx/passes/algorithms/operation.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import math
-
-
-def _discretize(mem_unit, values):
- return [math.ceil(value / mem_unit) for value in values]
-
-
-class Chain:
-
- def __init__(self, fw, bw, cw, cbw, ftmp, btmp, check=True):
- self.fweight = fw
- self.bweight = bw
- self.cweight = cw
- self.cbweight = cbw
- self.fwd_mem_tmp = ftmp
- self.bwd_mem_tmp = btmp
- self.length = len(fw)
- if check and not self.check_lengths():
- raise AttributeError("In Chain, input lists do not have consistent lengths")
-
- def check_lengths(self):
- return ((len(self.fweight) == self.length) and (len(self.bweight) == self.length + 1)
- and (len(self.cweight) == self.length + 1) and (len(self.fwd_mem_tmp) == self.length)
- and (len(self.bwd_mem_tmp) == self.length + 1) and (len(self.cbweight) == self.length + 1))
-
- def __repr__(self):
- chain_list = []
- for i in range(self.length):
- chain_list.append((self.fweight[i], self.bweight[i], self.cweight[i], self.cbweight[i], self.fwd_mem_tmp[i],
- self.bwd_mem_tmp[i]))
- i = self.length
- chain_list.append((None, self.bweight[i], self.cweight[i], self.cbweight[i], None, self.bwd_mem_tmp[i]))
- return chain_list.__repr__()
-
- def _discretize(self, mem_unit):
- self.cweight = _discretize(mem_unit, self.cweight)
- self.cbweight = _discretize(mem_unit, self.cbweight)
- self.fwd_mem_tmp = _discretize(mem_unit, self.fwd_mem_tmp)
- self.bwd_mem_tmp = _discretize(mem_unit, self.bwd_mem_tmp)
-
-
-class Operation:
-
- def shift(self, value):
- if type(self.index) is tuple:
- self.index = tuple(x + value for x in self.index)
- else:
- self.index += value
-
-
-class Offload(Operation):
-
- def __init__(self, index, has_bar=False) -> None:
- super().__init__()
- self.index = index
- self.name = "Off"
- self.has_bar = has_bar
- if self.has_bar:
- self.name += "wBar"
-
- def __repr__(self):
- return f"{self.name}_{self.index}"
-
-
-class Prefetch(Operation):
-
- def __init__(self, index, has_bar=False) -> None:
- super().__init__()
- self.index = index
- self.name = "Pre"
- self.has_bar = has_bar
- if self.has_bar:
- self.name += "wBar"
-
- def __repr__(self):
- return f"{self.name}_{self.index}"
-
-
-class Forward(Operation):
-
- def __init__(self, index):
- self.index = index
- self.name = "F"
-
- def __repr__(self):
- return "{n}_{i}".format(n=self.name, i=self.index)
-
- def cost(self, chain: Chain):
- if chain is not None:
- return chain.fweight[self.index]
- else:
- return 1
-
-
-class ForwardEnable(Forward):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "Fe"
-
-
-class ForwardNograd(Forward):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "Fn"
-
-
-class ForwardCheck(Forward):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "CF"
-
-
-class Forwards(Operation):
-
- def __init__(self, start, end):
- self.index = (start, end)
-
- def __repr__(self):
- return "F_{i}->{j}".format(i=self.index[0], j=self.index[1])
-
- def cost(self, chain: Chain):
- if chain is not None:
- return sum(chain.fweight[self.index[0]:self.index[1] + 1])
- else:
- return (self.index[1] - self.index[0] + 1)
-
-
-def isForward(op):
- return type(op) is Forward or type(op) is Forwards
-
-
-class Backward(Operation):
-
- def __init__(self, index):
- self.index = index
-
- def __repr__(self):
- return "B_{i}".format(i=self.index)
-
- def cost(self, chain: Chain):
- if chain is not None:
- return chain.bweight[self.index]
- else:
- return 1
-
-
-class Loss(Operation):
-
- def __init__(self):
- pass
-
- def __repr__(self):
- return "L"
-
- def cost(self, chain):
- return 0
-
-
-class MemoryAccess(Operation):
-
- def __init__(self, index):
- self.index = index
-
- def __repr__(self):
- return "{n}_{i}".format(n=self.name, i=self.index)
-
- def cost(self, chain: Chain):
- return 0
-
-
-class WriteMemory(MemoryAccess):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "WM"
-
-
-class ReadMemory(MemoryAccess):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "RM"
-
-
-class DiscardMemory(MemoryAccess):
-
- def __init__(self, index):
- super().__init__(index)
- self.name = "DM"
-
-
-class Function:
-
- def __init__(self, name, *args):
- self.name = name
- self.args = args
- self.str_args = ','.join(str(v) for v in self.args)
-
- def __repr__(self):
- return "{n}({args})".format(n=self.name, args=self.str_args)
-
-
-class Sequence:
-
- def __init__(self, function):
- self.sequence = [] #List of Operation and Sequence
- self.function = function #Description the function (name and parameters)
-
- def __repr__(self):
- return repr(self.list_operations())
-
- def list_operations(self):
- op_list = []
- for x in self.sequence:
- if isinstance(x, Operation):
- op_list.append(x)
- else:
- assert isinstance(x, Sequence)
- op_list += x.list_operations()
- return op_list
-
- def insert(self, operation):
- self.sequence.append(operation)
-
- def remove(self, operation_index):
- del self.sequence[operation_index]
-
- def insert_sequence(self, sequence):
- self.sequence.append(sequence)
-
- def shift(self, value):
- for x in self.sequence:
- x.shift(value)
- return self
-
- def remove_useless_write(self):
- if self.sequence:
- if isinstance(self.sequence[0], WriteMemory):
- self.remove(0)
- return self
-
- def get_makespan(self, chain):
- return sum(op.cost(chain) for op in self.list_operations())
-
- def without_suffix(self):
- ops = self.list_operations()
- end_of_first_phase = [i for i in range(len(ops)) if type(ops[i]) is Loss][0]
- try:
- last_idx = max(i for i in range(end_of_first_phase) if not type(ops[i]) is ForwardEnable)
- except ValueError:
- last_idx = -1
- if last_idx == end_of_first_phase - 1:
- return (self, None)
- chain_length = ops[end_of_first_phase -
- 1].index ## Some assumption here about the sequence (finishes with Forward_L
- start_of_fwd_enable_chain = ops[last_idx + 1].index ## And starts with B_L), but should be fine in practice
- result = Sequence(Function("Strip", self.function.name, *self.function.args, start_of_fwd_enable_chain))
- for i in range(last_idx + 1):
- result.insert(ops[i])
- result.insert(Loss())
- for i in range(chain_length, start_of_fwd_enable_chain - 1, -1):
- position = end_of_first_phase + 1 + (chain_length - i)
- assert type(ops[position]) is Backward
- assert ops[position].index == i
- for i in range(end_of_first_phase + 1 + 1 + chain_length - start_of_fwd_enable_chain, len(ops)):
- result.insert(ops[i])
- return (result, start_of_fwd_enable_chain)
diff --git a/colossalai/fx/passes/concrete_info_prop.py b/colossalai/fx/passes/concrete_info_prop.py
index ab38e8cb14e9..81ac64205528 100644
--- a/colossalai/fx/passes/concrete_info_prop.py
+++ b/colossalai/fx/passes/concrete_info_prop.py
@@ -226,7 +226,7 @@ def propagate(self, *args):
Returns:
Any: The value returned from executing the Module
"""
- return super().run(*args)
+ return self.run(*args)
def summary(self, unit: str = 'MB') -> str:
"""
diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
index 5137494ada6f..2b4a8749cfd7 100644
--- a/colossalai/fx/passes/meta_info_prop.py
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -112,7 +112,9 @@ def extract_tensor_meta(obj):
n.meta['tensor_meta'] = tensor_meta
n.meta = {**n.meta, **asdict(meta_info)} # extend MetaInfo to `n.meta`
# TODO: the attribute node_size should be removed in the future
- setattr(n, 'node_size', activation_size(n.meta.get('fwd_in', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
+ setattr(n, 'node_size', activation_size(n.meta.get('fwd_out', 0)) + activation_size(n.meta.get('fwd_tmp', 0)))
+ setattr(n, 'fwd_flop', n.meta.get('fwd_flop', 0))
+ setattr(n, 'bwd_flop', n.meta.get('bwd_flop', 0))
n.meta['type'] = type(result)
# retain the autograd graph
@@ -286,13 +288,16 @@ def mem_repr(mem: int) -> str:
def flops_repr(flop: int) -> str:
return f"{flop:,} FLOPs"
+ accumulate_size = 0
for node in self.module.graph.nodes:
node: Node
+ accumulate_size += calculate_fwd_out(node) + calculate_fwd_tmp(node)
node_summaries.append([
node.op,
str(node),
flops_repr(node.meta['fwd_flop']),
flops_repr(node.meta['bwd_flop']),
+ mem_repr(accumulate_size),
mem_repr(calculate_fwd_in(node)),
mem_repr(calculate_fwd_out(node)),
mem_repr(calculate_fwd_tmp(node)),
@@ -307,6 +312,7 @@ def flops_repr(flop: int) -> str:
'Op',
'Forward FLOPs',
'Backward FLOPs',
+ 'Accumulated Memory',
'FWD_IN',
'FWD_OUT',
'FWD_TMP',
diff --git a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
index 2cf50133d3bd..8d1c8a8c6877 100644
--- a/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
+++ b/colossalai/fx/profiler/experimental/profiler_function/arithmetic.py
@@ -1,7 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+
+# Licensed under the MIT License.
import operator
from functools import reduce
from typing import Any, Optional, Tuple, Union
+
import torch
+
from ..registry import meta_profiler_function
diff --git a/colossalai/fx/profiler/experimental/profiler_module/convolution.py b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
index 3193489fee5e..a4c15b91e611 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/convolution.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/convolution.py
@@ -1,8 +1,13 @@
+# Copyright (c) Microsoft Corporation.
+
+# Licensed under the MIT License.
+import math
import operator
from functools import reduce
-import math
from typing import Tuple
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/experimental/profiler_module/normalization.py b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
index e9939da7b1c4..49e5e6fa5384 100644
--- a/colossalai/fx/profiler/experimental/profiler_module/normalization.py
+++ b/colossalai/fx/profiler/experimental/profiler_module/normalization.py
@@ -1,5 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+
+# Licensed under the MIT License.
from typing import Tuple, Union
+
import torch
+
from ..registry import meta_profiler_module
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 1c39dc247750..407a6bed5200 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -20,7 +20,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
- assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+
+ # There are three cases: 1) gemm, 2) gemv, 3) dot
+ if all(len(shape) == 2 for shape in input_shapes):
+ # gemm
+ assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+ elif all(len(shape) == 1 for shape in input_shapes):
+ # dot
+ assert input_shapes[0][0] == input_shapes[1][0], input_shapes
+
+ # expand shape
+ input_shapes[0] = torch.Size([1, input_shapes[0][0]])
+ input_shapes[1] = torch.Size([input_shapes[1][0], 1])
+ else:
+ # gemv
+ if len(input_shapes[0]) == 1:
+ assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
+ input_shapes.reverse()
+ else:
+ assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
+
+ # expand the shape of the vector to [batch size, 1]
+ input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops
@@ -70,6 +91,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
return flops
+def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+ """
+ Count flops for the baddbmm(batch add and batch matmul) operation.
+ """
+ # Inputs = [input, batch1, batch2]
+ # out = input + batch1 x batch2
+ assert len(inputs) == 3, len(inputs)
+ n, c, t = inputs[1].shape
+ d = inputs[2].shape[-1]
+ flops = n * c * t * d
+ return flops
+
+
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
@@ -191,11 +225,14 @@ def zero_flop_jit(*args):
if version.parse(torch.__version__) >= version.parse('1.12.0'):
flop_mapping = {
- # gemm
+ # gemm, gemv and dot
aten.mm.default: matmul_flop_jit,
+ aten.mv.default: matmul_flop_jit,
+ aten.dot.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
+ aten.baddbmm.default: baddbmm_flop_jit,
# convolution
aten.convolution.default: conv_flop_jit,
@@ -209,6 +246,8 @@ def zero_flop_jit(*args):
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+ aten.native_group_norm.default: norm_flop_counter(2, 0),
+ aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
@@ -230,6 +269,8 @@ def zero_flop_jit(*args):
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
aten.embedding.default: elementwise_flop_counter(1, 0),
+ aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
+ aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
}
elementwise_flop_aten = [
@@ -249,6 +290,11 @@ def zero_flop_jit(*args):
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
+ aten.sub.Tensor,
+ aten.sub_.Tensor,
+ aten.exp.default,
+ aten.sin.default,
+ aten.cos.default,
# activation op
aten.hardswish.default,
@@ -301,6 +347,7 @@ def zero_flop_jit(*args):
aten.squeeze.dim,
aten.slice.Tensor,
aten.slice_backward.default,
+ aten.stack.default,
aten.split.Tensor,
aten.permute.default,
aten.t.default,
@@ -313,7 +360,9 @@ def zero_flop_jit(*args):
aten.where.self,
aten.zero_.default,
aten.zeros_like.default,
- ]
+ aten.fill_.Scalar,
+ aten.stack.default
+ ] # yapf: disable
for op in zero_flop_aten:
flop_mapping[op] = zero_flop_jit
diff --git a/colossalai/fx/profiler/tensor.py b/colossalai/fx/profiler/tensor.py
index 43165305f010..2ee5e5c47750 100644
--- a/colossalai/fx/profiler/tensor.py
+++ b/colossalai/fx/profiler/tensor.py
@@ -1,6 +1,4 @@
import uuid
-from copy import deepcopy
-from typing import Optional
import torch
from torch.types import _bool, _device, _dtype
@@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):
_tensor: torch.Tensor
- __slots__ = ['_tensor']
-
@staticmethod
def __new__(cls, elem, fake_device=None):
# Avoid multiple wrapping
@@ -47,7 +43,7 @@ def __new__(cls, elem, fake_device=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
- device=fake_device if fake_device is not None else elem.device,
+ device=fake_device or (elem.device if elem.device.type != 'meta' else torch.device('cpu')),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
@@ -59,8 +55,8 @@ def __new__(cls, elem, fake_device=None):
def __repr__(self):
if self.grad_fn:
- return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
- return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
+ return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
+ return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@@ -76,13 +72,13 @@ def unwrap(x):
x = x.to(torch.device('meta'))
return x
+ args = tree_map(unwrap, args)
+ kwargs = tree_map(unwrap, kwargs)
+
if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')
- args = tree_map(unwrap, args)
- kwargs = tree_map(unwrap, kwargs)
-
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
@@ -118,23 +114,24 @@ def to(self, *args, **kwargs) -> torch.Tensor:
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
- device = None
- for arg in args:
- if isinstance(arg, str) or isinstance(arg, _device):
- device = arg
- if 'device' in kwargs:
- device = kwargs['device']
- result = super().to(*args, **kwargs)
- if device is not None:
- result = MetaTensor(result, fake_device=device)
- return result
+ fake_device = None
+
+ def replace(x):
+ nonlocal fake_device
+ if isinstance(x, str) or isinstance(x, _device):
+ fake_device = x
+ return 'meta'
+ return x
+
+ elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
+ return MetaTensor(elem, fake_device=fake_device)
def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)
- def cuda(self, *args, **kwargs):
- if self.device.type == 'cuda':
- return self.to(*args, **kwargs)
- return self.to(*args, device='cuda', **kwargs)
+ def cuda(self, device=None, non_blocking=False):
+ if device is not None:
+ return self.to(device=device, non_blocking=non_blocking)
+ return self.to(device='cuda:0', non_blocking=non_blocking)
diff --git a/colossalai/fx/tracer/_symbolic_trace.py b/colossalai/fx/tracer/_symbolic_trace.py
index bff2f6a10fa6..5c04eeace0ad 100644
--- a/colossalai/fx/tracer/_symbolic_trace.py
+++ b/colossalai/fx/tracer/_symbolic_trace.py
@@ -13,6 +13,7 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
+ trace_act_ckpt=False,
) -> ColoGraphModule:
"""
Symbolic tracing API
@@ -49,6 +50,6 @@ def symbolic_trace(
This API is still under development and can incur some bugs. Feel free to report any bugs to the Colossal-AI team.
"""
- graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root, concrete_args=concrete_args, meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
diff --git a/colossalai/fx/tracer/experimental.py b/colossalai/fx/tracer/experimental.py
index 6fee5f5d061d..88b65b6188fa 100644
--- a/colossalai/fx/tracer/experimental.py
+++ b/colossalai/fx/tracer/experimental.py
@@ -1,7 +1,7 @@
import enum
import functools
-import operator
import inspect
+import operator
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -286,7 +286,6 @@ def _check_arg_name_valid(names):
self.graph.lint()
return self.graph
-
@contextmanager
def trace_activation_checkpoint(self, enabled: bool):
if enabled:
@@ -316,7 +315,6 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:
# recover the checkpoint function upon exit
torch.utils.checkpoint.CheckpointFunction = orig_ckpt_func
-
def _post_check(self, non_concrete_arg_names: Set[str]):
# This is necessary because concrete args are added as input to the traced module since
# https://github.com/pytorch/pytorch/pull/55888.
@@ -385,18 +383,23 @@ def symbolic_trace(
root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None,
meta_args: Optional[Dict[str, Any]] = None,
+ trace_act_ckpt=False,
) -> ColoGraphModule:
if is_compatible_with_meta():
if meta_args is not None:
root.to(default_device())
wrap_fn = lambda x: MetaTensor(x, fake_device=default_device()) if isinstance(x, torch.Tensor) else x
- graph = ColoTracer().trace(root, concrete_args=concrete_args, meta_args=tree_map(wrap_fn, meta_args))
+ graph = ColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+ concrete_args=concrete_args,
+ meta_args=tree_map(wrap_fn, meta_args))
root.cpu()
else:
graph = Tracer().trace(root, concrete_args=concrete_args)
else:
from .tracer import ColoTracer as OrigColoTracer
- graph = OrigColoTracer().trace(root, concrete_args=concrete_args, meta_args=meta_args)
+ graph = OrigColoTracer(trace_act_ckpt=trace_act_ckpt).trace(root,
+ concrete_args=concrete_args,
+ meta_args=meta_args)
name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
return ColoGraphModule(root, graph, name)
@@ -471,11 +474,11 @@ def meta_prop_pass(gm: ColoGraphModule,
node._meta_data = _meta_data_computing(meta_args, concrete_args, root, node.op, node.target, node.args,
node.kwargs)
+
def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwargs):
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
if kind == 'placeholder':
- meta_out = meta_args[target] if target in meta_args else concrete_args.get(
- _truncate_suffix(target), None)
+ meta_out = meta_args[target] if target in meta_args else concrete_args.get(_truncate_suffix(target), None)
elif kind == 'get_attr':
attr_itr = root
atoms = target.split(".")
@@ -490,7 +493,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
else:
if target not in _TensorPropertyMethod:
meta_out = getattr(unwrap_fn(args[0]), target)(*tree_map(unwrap_fn, args[1:]),
- **tree_map(unwrap_fn, kwargs))
+ **tree_map(unwrap_fn, kwargs))
elif kind == 'call_module':
mod = root.get_submodule(target)
meta_out = mod.forward(*tree_map(unwrap_fn, args), **tree_map(unwrap_fn, kwargs))
@@ -498,6 +501,7 @@ def _meta_data_computing(meta_args, concrete_args, root, kind, target, args, kwa
meta_out = None
return meta_out
+
def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
if kind == "placeholder" and target in meta_args and meta_args[target].is_meta:
meta_out = meta_args[target]
@@ -568,7 +572,7 @@ def _meta_data_computing_v0(meta_args, root, kind, target, args, kwargs):
return meta_out
-def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]]=None):
+def bias_addition_pass(gm: ColoGraphModule, root_model: torch.nn.Module, meta_args: Optional[Dict[str, Any]] = None):
result_graph = Graph()
value_remap = {}
unwrap_fn = lambda n: n._meta_data if isinstance(n, Node) else n
@@ -601,20 +605,24 @@ def wrap_fn(n):
if target == torch.nn.functional.linear:
if 'bias' in kwargs and kwargs['bias'] is not None:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+ handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+ function_to_substitute)
else:
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+ handle = bias_addition_function.get(target)(tracer, target, args_proxy, kwargs_proxy,
+ function_to_substitute)
elif bias_addition_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
function_to_substitute = func_to_func_dict[target]
- handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+ handle = bias_addition_function.get(target.__name__)(tracer, target, args_proxy, kwargs_proxy,
+ function_to_substitute)
elif kind == "call_method":
method = getattr(args_metas[0].__class__, target)
if bias_addition_method.has(method):
function_to_substitute = method_to_func_dict[method]
- handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+ handle = bias_addition_method.get(method)(tracer, target, args_proxy, kwargs_proxy,
+ function_to_substitute)
elif kind == "call_module":
# if not hasattr(self, "orig_forward"):
@@ -623,20 +631,20 @@ def wrap_fn(n):
mod_type = type(mod)
if bias_addition_module.has(mod_type) and mod.bias is not None:
function_to_substitute = module_to_func_dict[mod_type]
- handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy, function_to_substitute)
+ handle = bias_addition_module.get(mod_type)(tracer, target, args_proxy, kwargs_proxy,
+ function_to_substitute)
if handle is not None:
handle.generate()
for node_inserted in tracer.graph.nodes:
- value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n : value_remap[n])
+ value_remap[node_inserted] = result_graph.node_copy(node_inserted, lambda n: value_remap[n])
last_node = value_remap[node_inserted]
value_remap[orig_node] = last_node
else:
- value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n : value_remap[n])
+ value_remap[orig_node] = result_graph.node_copy(orig_node, lambda n: value_remap[n])
del tracer
gm.graph = result_graph
gm.recompile()
meta_prop_pass(gm, root_model, meta_args)
-
diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py
index 07fb6c48b2d7..30ac4d354647 100644
--- a/colossalai/gemini/chunk/manager.py
+++ b/colossalai/gemini/chunk/manager.py
@@ -72,6 +72,9 @@ def register_tensor(self,
if tensor.numel() > chunk_size:
chunk_size = tensor.numel()
+ dp_size = tensor.process_group.dp_world_size()
+ chunk_size = chunk_size + (-chunk_size % dp_size)
+
chunk = Chunk(
chunk_size=chunk_size,
process_group=tensor.process_group,
@@ -140,6 +143,14 @@ def reduce_chunk(self, chunk: Chunk) -> bool:
self.__add_memory_usage(chunk.memory_usage)
return True
+ def fake_release_chunk(self, chunk: Chunk) -> None:
+ """Release gathered chunk in a fake mode.
+ This function is used for keep-gathered chunk in the inference mode.
+ """
+ assert chunk.keep_gathered
+ assert chunk.tensor_state_cnter[TensorState.HOLD] == chunk.num_tensors
+ self.__sub_accessed_chunk(chunk)
+
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data: torch.Tensor) -> None:
"""
Copy data to the chunk.
diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py
index 312d77f1826c..fe9650721d74 100644
--- a/colossalai/gemini/chunk/search_utils.py
+++ b/colossalai/gemini/chunk/search_utils.py
@@ -2,22 +2,26 @@
from typing import Dict, List, Optional, Tuple
import numpy as np
+import torch.distributed as dist
import torch.nn as nn
from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
from colossalai.tensor import ColoParameter
-
-
-def in_ddp(param: nn.Parameter) -> bool:
- return not getattr(param, '_ddp_to_ignore', False)
+from colossalai.utils import is_ddp_ignored
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
"""
- params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
- params_size_arr = np.array(params_size)
+ agg_size_list = []
+ for key in size_dict:
+ agg_size_list.extend(size_dict[key])
+
+ if len(agg_size_list) == 0:
+ return
+
+ params_size_arr = np.array(agg_size_list)
std = np.std(params_size_arr)
mean = np.mean(params_size_arr)
@@ -41,7 +45,15 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
-def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int, List[ColoParameter]]:
+def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
+ if strict_ddp_flag:
+ return local_param.numel_global()
+ else:
+ return local_param.numel()
+
+
+def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
+ strict_ddp_flag: bool = False) -> Dict[int, List[ColoParameter]]:
"""classify_params_by_dp_degree
Classify the parameters by their dp degree
@@ -56,10 +68,13 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator) -> Dict[int
params_dict: Dict[int, List[ColoParameter]] = dict()
for param in param_order.generate():
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
- if not in_ddp(param):
+ if is_ddp_ignored(param):
continue
- param_key = param.process_group.dp_world_size()
+ if strict_ddp_flag:
+ param_key = dist.get_world_size()
+ else:
+ param_key = param.process_group.dp_world_size()
if param_key not in params_dict:
params_dict[param_key] = []
@@ -74,14 +89,18 @@ def search_chunk_configuration(
search_interval_byte: int, # hidden size is the best value for the interval
min_chunk_size_mb: float = 32,
filter_exlarge_params: bool = True,
- memstas: Optional[MemStats] = None) -> Tuple[Dict, int]:
+ strict_ddp_flag: bool = False,
+ memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.
search_interval_byte (int): searching interval in byte.
+ min_chunk_size_mb (float, optional): the minimum size of a distributed chunk.
filter_exlarge_params (bool, optional): filter extreme large parameters. Defaults to True.
+ strict_ddp_flag (bool, optional): whether to enable the strict ddp mode.
+ all parameters keep replicated in this mode.
Returns:
Tuple[Dict, int]: chunk config (a dict of dp_degree -> chunk init args) and its memory chunk waste in byte.
@@ -99,17 +118,21 @@ def search_chunk_configuration(
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
assert search_range_byte >= 0
- params_dict = classify_params_by_dp_degree(param_order)
+ params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag)
+ size_lcm = np.lcm.reduce(list(params_dict.keys()))
config_dict: Dict[int, Dict] = dict()
+ total_param_size = 0
size_dict: Dict[int, List[int]] = dict()
for dp_degree in params_dict:
params_list = params_dict[dp_degree]
- size_list = [p.numel() for p in params_list]
+ size_list = [_tensor_numel(p, strict_ddp_flag) for p in params_list]
+ group_acc_size = sum(size_list)
+ total_param_size += group_acc_size
+
# let small parameters keep gathered in CUDA all the time
- total_size = sum(size_list)
- if total_size < min_chunk_size_byte:
- config_dict[dp_degree] = dict(chunk_size=total_size, keep_gathered=True)
+ if group_acc_size < min_chunk_size_byte:
+ config_dict[dp_degree] = dict(chunk_size=group_acc_size, keep_gathered=True)
else:
size_dict[dp_degree] = size_list
@@ -132,9 +155,11 @@ def search_chunk_configuration(
min_chunk_waste = temp_waste
best_chunk_size = chunk_size
+ # the chunk size needs to be divided by each groups sizes
+ best_chunk_size = best_chunk_size + (-best_chunk_size % size_lcm)
for dp_degree in params_dict:
if dp_degree in config_dict:
continue
config_dict[dp_degree] = dict(chunk_size=best_chunk_size, keep_gathered=False)
- return config_dict, min_chunk_waste
+ return config_dict, total_param_size, min_chunk_waste
diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/gemini/chunk/utils.py
index e9a9f84e7a93..83512b8e0ee5 100644
--- a/colossalai/gemini/chunk/utils.py
+++ b/colossalai/gemini/chunk/utils.py
@@ -6,51 +6,42 @@
import torch.nn as nn
from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
-from colossalai.gemini.memory_tracer import MemStats
+from colossalai.gemini.chunk.search_utils import search_chunk_configuration
+from colossalai.utils import is_ddp_ignored
+
+
+def safe_div(a, b):
+ if a == 0:
+ return 0
+ return a / b
def init_chunk_manager(model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
- search_range_mb: Optional[float] = None,
- min_chunk_size_mb: Optional[float] = None,
- filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
-
- kwargs_dict = dict()
-
+ **kwargs) -> ChunkManager:
if hidden_dim:
search_interval_byte = hidden_dim
else:
- search_interval_byte = 1024 # 1kb
- kwargs_dict["search_interval_byte"] = search_interval_byte
-
- if search_range_mb:
- kwargs_dict["search_range_mb"] = search_range_mb
-
- if min_chunk_size_mb:
- kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
-
- if filter_exlarge_params:
- kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
-
- params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
- total_size = sum(params_sizes) / 1024**2
+ search_interval_byte = 1024 # defaults to 1kb
+ kwargs["search_interval_byte"] = search_interval_byte
dist.barrier()
begin = time()
- config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
+ config_dict, total_size, wasted_size = search_chunk_configuration(model, **kwargs)
dist.barrier()
end = time()
span_s = end - begin
- wasted_size /= 1024**2
+ mb_size = 1024**2
+ total_size /= mb_size
+ wasted_size /= mb_size
if dist.get_rank() == 0:
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
- "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
+ "total wasted percentage is {:.2f}%".format(100 * safe_div(wasted_size, total_size + wasted_size)),
sep='',
flush=True)
dist.barrier()
diff --git a/colossalai/gemini/gemini_context.py b/colossalai/gemini/gemini_context.py
index 98c8a914e5ca..9a7da6b80fba 100644
--- a/colossalai/gemini/gemini_context.py
+++ b/colossalai/gemini/gemini_context.py
@@ -1,48 +1,48 @@
-from enum import EnumMeta
-
-
-class GeminiMemoryManager(object):
-
- def __init__(self, states_cls: EnumMeta):
- super().__init__()
- self.states_cls = states_cls
- self._cnter = 0 # the counter of instances
-
- self.total_mem = dict()
- self.state_mem = dict()
- self.state_mem['cpu'] = dict()
- self.state_mem['cuda'] = dict()
-
- self.reset()
-
- @property
- def total_number(self):
- return self._cnter
-
- def reset(self):
- self._cnter = 0 # the counter of instances
-
- self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
- self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
-
- # memory conditions for all states
- for state in self.states_cls:
- self.state_mem['cpu'][state] = 0
- self.state_mem['cuda'][state] = 0
-
- def register_new_instance(self):
- self._cnter += 1
-
- def delete_instance(self):
- self._cnter -= 1
-
- def print_info(self):
- print(f"Total number: {self.total_number}",
- f"Total CPU memory occupation: {self.total_mem['cpu']}",
- f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
- sep='\n')
-
- for state in self.states_cls:
- print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
- f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
- sep='\n')
+from enum import EnumMeta
+
+
+class GeminiMemoryManager(object):
+
+ def __init__(self, states_cls: EnumMeta):
+ super().__init__()
+ self.states_cls = states_cls
+ self._cnter = 0 # the counter of instances
+
+ self.total_mem = dict()
+ self.state_mem = dict()
+ self.state_mem['cpu'] = dict()
+ self.state_mem['cuda'] = dict()
+
+ self.reset()
+
+ @property
+ def total_number(self):
+ return self._cnter
+
+ def reset(self):
+ self._cnter = 0 # the counter of instances
+
+ self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
+ self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
+
+ # memory conditions for all states
+ for state in self.states_cls:
+ self.state_mem['cpu'][state] = 0
+ self.state_mem['cuda'][state] = 0
+
+ def register_new_instance(self):
+ self._cnter += 1
+
+ def delete_instance(self):
+ self._cnter -= 1
+
+ def print_info(self):
+ print(f"Total number: {self.total_number}",
+ f"Total CPU memory occupation: {self.total_mem['cpu']}",
+ f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
+ sep='\n')
+
+ for state in self.states_cls:
+ print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
+ f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
+ sep='\n')
diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 08961b95832a..72a5e4a7f19b 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -50,6 +50,21 @@ def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats:
self._warmup = True
self._comp_cuda_demand_time = 0
+ def reset_attributes(self):
+ self._compute_idx = -1
+ self._h2d_volume = 0
+ self._d2h_volume = 0
+ self._layout_time = 0
+ self._evict_time = 0
+ self._comp_cuda_demand_time = 0
+
+ @property
+ def need_warmup(self) -> bool:
+ return self.policy_name in ('auto', 'const')
+
+ def is_warmup(self):
+ return self._warmup
+
def memstats(self):
"""memstats
@@ -73,12 +88,7 @@ def post_iter(self):
if self._mem_stats_collector and self._warmup:
self._mem_stats_collector.finish_collection()
self._warmup = False
- self._compute_idx = -1
- self._h2d_volume = 0
- self._d2h_volume = 0
- self._layout_time = 0
- self._evict_time = 0
- self._comp_cuda_demand_time = 0
+ self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
""" Adjust the layout of stateful tensors according to the information provided
diff --git a/colossalai/gemini/ophooks/utils.py b/colossalai/gemini/ophooks/utils.py
index fe08405c82bf..84e8298c1d51 100644
--- a/colossalai/gemini/ophooks/utils.py
+++ b/colossalai/gemini/ophooks/utils.py
@@ -1,7 +1,7 @@
-import torch
-from typing import List, Callable, Optional
-
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from abc import ABC, abstractmethod
+from typing import Callable, List, Optional
+
import torch
diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index e907efddee69..f3719dcb47b3 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -15,26 +15,25 @@
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
-from colossalai.core import global_context as gpc
-from colossalai.context.moe_context import MOE_CONTEXT
-
-from colossalai.logging import get_dist_logger
-
-from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule, InterleavedPipelineSchedule, get_tensor_shape
-from colossalai.engine import Engine
-from colossalai.gemini.ophooks import BaseOpHook
-
-from colossalai.utils import (get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param)
-from colossalai.utils.moe import sync_moe_model_param
-
from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
+from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.engine import Engine
from colossalai.engine.gradient_accumulation import accumulate_gradient
-
+from colossalai.engine.schedule import (
+ InterleavedPipelineSchedule,
+ NonPipelineSchedule,
+ PipelineSchedule,
+ get_tensor_shape,
+)
+from colossalai.gemini.ophooks import BaseOpHook
+from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
-
+from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
+from colossalai.utils.moe import sync_moe_model_param
from colossalai.zero import convert_to_zero_v2
from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
@@ -301,9 +300,9 @@ def initialize(model: nn.Module,
model = model().to(get_current_device())
# optimizer maybe a optimizer_cls
- logger.warning("Initializing an non ZeRO model with optimizer class")
if isinstance(optimizer, Callable):
optimizer = optimizer(model.parameters())
+ logger.warning("Initializing an non ZeRO model with optimizer class")
if not use_zero:
if is_using_sequence():
diff --git a/colossalai/kernel/cuda_native/__init__.py b/colossalai/kernel/cuda_native/__init__.py
index 8f857ff5d9f1..1d5a6ce495bd 100644
--- a/colossalai/kernel/cuda_native/__init__.py
+++ b/colossalai/kernel/cuda_native/__init__.py
@@ -1,3 +1,5 @@
from .layer_norm import MixedFusedLayerNorm as LayerNorm
from .multihead_attention import MultiHeadAttention
-from .scaled_softmax import FusedScaleMaskSoftmax
+from .scaled_softmax import FusedScaleMaskSoftmax, ScaledUpperTriangMaskedSoftmax
+
+__all__ = ['LayerNorm', 'MultiHeadAttention', 'FusedScaleMaskSoftmax', 'ScaledUpperTriangMaskedSoftmax']
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
index 68be1f6d7a22..09f34763f9b2 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
+++ b/colossalai/kernel/cuda_native/csrc/kernels/cublas_wrappers.cu
@@ -1,6 +1,7 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
+ Licensed under the MIT License.
*/
#include "cublas_wrappers.h"
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h
index 7ebb9ce48ed3..90255152b2c8 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h
+++ b/colossalai/kernel/cuda_native/csrc/kernels/include/cublas_wrappers.h
@@ -1,6 +1,7 @@
/* Copyright 2021 The LightSeq Team
Copyright Microsoft DeepSpeed
This file is adapted from Microsoft DeepSpeed
+ Licensed under the MIT License.
*/
#pragma once
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h
index ec963259f738..8186da1eed5f 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h
+++ b/colossalai/kernel/cuda_native/csrc/kernels/include/feed_forward.h
@@ -1,68 +1,69 @@
-#pragma once
-
-/* Copyright 2021 The LightSeq Team
- Copyright Microsoft DeepSpeed
- This file is adapted from Microsoft DeepSpeed
-*/
-#include
-#include
-#include
-
-#include
-
-#include "cublas_wrappers.h"
-#include "kernels.h"
-
-template
-class FeedForward {
- public:
- struct Config {
- int outputSize;
- int inputSize;
- std::array gemm_algos;
- Config(int outputs, int inputs)
- : outputSize(outputs),
- inputSize(inputs),
- gemm_algos(std::array({99, 99, 99})) {}
- };
-
- FeedForward(Config config) : config_(config) {}
-
- ~FeedForward() {}
-
- void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
- cublasHandle_t &_cublasHandle) {
- float alpha = T(1.);
- float beta = T(0.);
-
- cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
- bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
- out, cublasGemmAlgo_t(config_.gemm_algos[0]));
- }
- void Backward(int bsz, const T *out_grad, const T *input_ptr,
- const T *weights, T *weights_grad, T *bias_grad,
- cublasHandle_t &_cublasHandle, cudaStream_t &stream,
- T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
- bool compute_bias = true) {
- float alpha = (T)1.0, beta = (T)0.0;
- cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
- config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
- weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
-
- cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
- bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
- inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
- if (compute_bias) {
- launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz,
- config_.outputSize, stream);
- }
- }
-
- void reset_size(int outputSize, int inputSize) {
- config_.outputSize = outputSize;
- config_.inputSize = inputSize;
- }
-
- private:
- Config config_;
-};
+#pragma once
+
+/* Copyright 2021 The LightSeq Team
+ Copyright Microsoft DeepSpeed
+ This file is adapted from Microsoft DeepSpeed
+ Licensed under the MIT License.
+*/
+#include
+#include
+#include
+
+#include
+
+#include "cublas_wrappers.h"
+#include "kernels.h"
+
+template
+class FeedForward {
+ public:
+ struct Config {
+ int outputSize;
+ int inputSize;
+ std::array gemm_algos;
+ Config(int outputs, int inputs)
+ : outputSize(outputs),
+ inputSize(inputs),
+ gemm_algos(std::array({99, 99, 99})) {}
+ };
+
+ FeedForward(Config config) : config_(config) {}
+
+ ~FeedForward() {}
+
+ void Forward(int bsz, const T *input_ptr, const T *weights, T *out,
+ cublasHandle_t &_cublasHandle) {
+ float alpha = T(1.);
+ float beta = T(0.);
+
+ cublas_gemm_ex(_cublasHandle, CUBLAS_OP_T, CUBLAS_OP_N, config_.outputSize,
+ bsz, config_.inputSize, &alpha, &beta, weights, input_ptr,
+ out, cublasGemmAlgo_t(config_.gemm_algos[0]));
+ }
+ void Backward(int bsz, const T *out_grad, const T *input_ptr,
+ const T *weights, T *weights_grad, T *bias_grad,
+ cublasHandle_t &_cublasHandle, cudaStream_t &stream,
+ T *inp_grad_out = nullptr, T *out_grad_trans_out = nullptr,
+ bool compute_bias = true) {
+ float alpha = (T)1.0, beta = (T)0.0;
+ cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_T, config_.inputSize,
+ config_.outputSize, bsz, &alpha, &beta, input_ptr, out_grad,
+ weights_grad, cublasGemmAlgo_t(config_.gemm_algos[1]));
+
+ cublas_gemm_ex(_cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N, config_.inputSize,
+ bsz, config_.outputSize, &alpha, &beta, weights, out_grad,
+ inp_grad_out, cublasGemmAlgo_t(config_.gemm_algos[2]));
+ if (compute_bias) {
+ launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz,
+ config_.outputSize, stream);
+ }
+ }
+
+ void reset_size(int outputSize, int inputSize) {
+ config_.outputSize = outputSize;
+ config_.inputSize = inputSize;
+ }
+
+ private:
+ Config config_;
+};
diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
index 3120660b98be..d386650e8235 100644
--- a/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
+++ b/colossalai/kernel/cuda_native/csrc/kernels/include/strided_batch_gemm.h
@@ -1,99 +1,100 @@
-/* Copyright 2021 The LightSeq Team
- Copyright Microsoft DeepSpeed
- This file is adapted from Microsoft DeepSpeed
-*/
-#pragma once
-
-#include
-#include
-#include
-
-#include
-
-#include "cublas_wrappers.h"
-
-template
-class StridedBatchGemm {
- public:
- struct Config {
- int m;
- int n;
- int k;
- float alpha;
- float beta;
- cublasOperation_t op_A;
- cublasOperation_t op_B;
- std::array gemm_algos;
-
- Config(float param_alpha, float param_beta, cublasOperation_t opA,
- cublasOperation_t opB)
- : alpha(param_alpha),
- beta(param_beta),
- op_A(opA),
- op_B(opB),
- gemm_algos(std::array({99, 99, 99})) {}
- void SetConfig(int mm, int nn, int kk) {
- m = mm;
- n = nn;
- k = kk;
- }
- };
-
- StridedBatchGemm(const Config &config) : _config(config) {}
-
- virtual ~StridedBatchGemm() {}
-
- void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
- cublasHandle_t handle) {
- int stride_a = _config.m * _config.k;
- int stride_b = _config.n * _config.k;
- int stride_c = _config.m * _config.n;
-
- cublas_strided_batched_gemm(
- handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
- _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
- stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
- }
-
- void Backward(int bsz, const T *d_output, const T *_buffer_a,
- const T *_buffer_b, cublasHandle_t handle,
- T *inpGradA = nullptr, T *inpGradB = nullptr) {
- int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
- int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
-
- int stride_a = mb * _config.n;
- int stride_b = _config.n * kb;
- int stride_c = _config.m * _config.k;
-
- // B need to transpose.
- cublasOperation_t op_b =
- (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
-
- // Calculate d_A.
- cublas_strided_batched_gemm(
- handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
- (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
- (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
- CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
- cublasGemmAlgo_t(_config.gemm_algos[1]));
-
- // A need to transpose.
- cublasOperation_t op_a =
- (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
-
- stride_a = _config.m * _config.k;
- stride_b = _config.m * _config.n;
- stride_c = _config.n * _config.k;
-
- // Calculate d_B.
- cublas_strided_batched_gemm(
- handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
- _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
- stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
- }
-
- inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
-
- private:
- Config _config;
-};
+/* Copyright 2021 The LightSeq Team
+ Copyright Microsoft DeepSpeed
+ This file is adapted from Microsoft DeepSpeed
+ Licensed under the MIT License.
+*/
+#pragma once
+
+#include
+#include
+#include
+
+#include
+
+#include "cublas_wrappers.h"
+
+template
+class StridedBatchGemm {
+ public:
+ struct Config {
+ int m;
+ int n;
+ int k;
+ float alpha;
+ float beta;
+ cublasOperation_t op_A;
+ cublasOperation_t op_B;
+ std::array gemm_algos;
+
+ Config(float param_alpha, float param_beta, cublasOperation_t opA,
+ cublasOperation_t opB)
+ : alpha(param_alpha),
+ beta(param_beta),
+ op_A(opA),
+ op_B(opB),
+ gemm_algos(std::array({99, 99, 99})) {}
+ void SetConfig(int mm, int nn, int kk) {
+ m = mm;
+ n = nn;
+ k = kk;
+ }
+ };
+
+ StridedBatchGemm(const Config &config) : _config(config) {}
+
+ virtual ~StridedBatchGemm() {}
+
+ void Forward(int bsz, T *output, const T *_buffer_a, const T *_buffer_b,
+ cublasHandle_t handle) {
+ int stride_a = _config.m * _config.k;
+ int stride_b = _config.n * _config.k;
+ int stride_c = _config.m * _config.n;
+
+ cublas_strided_batched_gemm(
+ handle, _config.m, _config.n, _config.k, &_config.alpha, &_config.beta,
+ _buffer_a, _buffer_b, output, _config.op_A, _config.op_B, stride_a,
+ stride_b, stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[0]));
+ }
+
+ void Backward(int bsz, const T *d_output, const T *_buffer_a,
+ const T *_buffer_b, cublasHandle_t handle,
+ T *inpGradA = nullptr, T *inpGradB = nullptr) {
+ int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
+ int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
+
+ int stride_a = mb * _config.n;
+ int stride_b = _config.n * kb;
+ int stride_c = _config.m * _config.k;
+
+ // B need to transpose.
+ cublasOperation_t op_b =
+ (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
+
+ // Calculate d_A.
+ cublas_strided_batched_gemm(
+ handle, mb, kb, _config.n, &_config.alpha, &_config.beta,
+ (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
+ (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), inpGradA,
+ CUBLAS_OP_N, op_b, stride_a, stride_b, stride_c, bsz,
+ cublasGemmAlgo_t(_config.gemm_algos[1]));
+
+ // A need to transpose.
+ cublasOperation_t op_a =
+ (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
+
+ stride_a = _config.m * _config.k;
+ stride_b = _config.m * _config.n;
+ stride_c = _config.n * _config.k;
+
+ // Calculate d_B.
+ cublas_strided_batched_gemm(
+ handle, _config.k, _config.n, _config.m, &_config.alpha, &_config.beta,
+ _buffer_a, d_output, inpGradB, op_a, CUBLAS_OP_N, stride_a, stride_b,
+ stride_c, bsz, cublasGemmAlgo_t(_config.gemm_algos[2]));
+ }
+
+ inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
+
+ private:
+ Config _config;
+};
diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
index afd34bb96352..9cc3ae1eac10 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
@@ -1,5 +1,10 @@
// modified from
// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_adam.cu
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+ Licensed under the MIT License.
+*/
#include
#include
#include
diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh
index 9ce41191133e..ec55dd320b40 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_apply.cuh
@@ -1,12 +1,18 @@
-// modified from https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
+// modified from
+// https://github.com/NVIDIA/apex/blob/master/csrc/multi_tensor_apply.cuh
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+ Licensed under the MIT License.
+*/
#include
#include
#include
#include
+#include
#include
-#include "compat.h"
-#include
+#include "compat.h"
// #include
@@ -17,117 +23,108 @@ constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template
-struct TensorListMetadata
-{
- void *addresses[n][depth_to_max_tensors[n - 1]];
- int sizes[depth_to_max_tensors[n - 1]];
- unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
- int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
- int start_tensor_this_launch;
+struct TensorListMetadata {
+ void *addresses[n][depth_to_max_tensors[n - 1]];
+ int sizes[depth_to_max_tensors[n - 1]];
+ unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
+ int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a
+ // full int.
+ int start_tensor_this_launch;
};
template
-__global__ void multi_tensor_apply_kernel(
- int chunk_size,
- volatile int *noop_flag,
- T tl,
- U callable,
- ArgTypes... args)
-{
- // Hand the chunk information to the user-supplied functor to process however it likes.
- callable(chunk_size, noop_flag, tl, args...);
+__global__ void multi_tensor_apply_kernel(int chunk_size,
+ volatile int *noop_flag, T tl,
+ U callable, ArgTypes... args) {
+ // Hand the chunk information to the user-supplied functor to process however
+ // it likes.
+ callable(chunk_size, noop_flag, tl, args...);
}
template
void multi_tensor_apply(
- int block_size,
- int chunk_size,
- const at::Tensor &noop_flag,
- const std::vector> &tensor_lists,
- T callable,
- ArgTypes... args)
-{
- TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
- int len0 = tensor_lists[0].size();
- TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
- auto ref_device = tensor_lists[0][0].device();
- TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
- for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
- {
- TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
- for (int t = 0; t < tensor_lists[l].size(); t++)
- {
- // TODO: Print which tensor fails.
- bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+ int block_size, int chunk_size, const at::Tensor &noop_flag,
+ const std::vector> &tensor_lists, T callable,
+ ArgTypes... args) {
+ TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
+ int len0 = tensor_lists[0].size();
+ TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
+ auto ref_device = tensor_lists[0][0].device();
+ TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+ for (int l = 0; l < tensor_lists.size();
+ l++) // No range-based for because I need indices
+ {
+ TORCH_CHECK(tensor_lists[l].size() == len0,
+ "Size mismatch among tensor lists");
+ for (int t = 0; t < tensor_lists[l].size(); t++) {
+ // TODO: Print which tensor fails.
+ bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
- contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
+ contiguous_memory =
+ (contiguous_memory ||
+ tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
- TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
- TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
- TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
- }
+ TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
+ TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
+ "A tensor was not on the same device as the first tensor");
+ TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
+ "Size mismatch");
}
-
- int ntensors = tensor_lists[0].size();
-
- TensorListMetadata tl;
-
- const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
- auto stream = at::cuda::getCurrentCUDAStream();
-
- tl.start_tensor_this_launch = 0;
- int loc_block_info = 0;
- int loc_tensor_info = 0;
- for (int t = 0; t < ntensors; t++)
- {
- tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
- for (int d = 0; d < depth; d++)
- tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
- loc_tensor_info++;
-
- int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
-
- for (int chunk = 0; chunk < chunks_this_tensor; chunk++)
- {
- // std::cout << chunks_this_tensor << std::endl;
- tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
- tl.block_to_chunk[loc_block_info] = chunk;
- loc_block_info++;
-
- bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
- chunk == chunks_this_tensor - 1);
- bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
- bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
- if (tensors_full || blocks_full || last_chunk)
- {
- // using accscalar_t = acc_type;
- multi_tensor_apply_kernel<<>>(
- chunk_size,
- noop_flag.DATA_PTR(),
- tl,
- callable,
- args...);
-
- AT_CUDA_CHECK(cudaGetLastError());
-
- // Reset. The control flow possibilities here make my brain hurt.
- loc_block_info = 0;
- if (chunk == chunks_this_tensor - 1)
- {
- // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
- loc_tensor_info = 0;
- tl.start_tensor_this_launch = t + 1;
- }
- else
- {
- // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
- tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
- for (int d = 0; d < depth; d++)
- tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
- loc_tensor_info = 1;
- tl.start_tensor_this_launch = t;
- }
- }
+ }
+
+ int ntensors = tensor_lists[0].size();
+
+ TensorListMetadata tl;
+
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
+ auto stream = at::cuda::getCurrentCUDAStream();
+
+ tl.start_tensor_this_launch = 0;
+ int loc_block_info = 0;
+ int loc_tensor_info = 0;
+ for (int t = 0; t < ntensors; t++) {
+ tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+ for (int d = 0; d < depth; d++)
+ tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+ loc_tensor_info++;
+
+ int chunks_this_tensor =
+ (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
+
+ for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
+ // std::cout << chunks_this_tensor << std::endl;
+ tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+ tl.block_to_chunk[loc_block_info] = chunk;
+ loc_block_info++;
+
+ bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
+ chunk == chunks_this_tensor - 1);
+ bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
+ bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
+ if (tensors_full || blocks_full || last_chunk) {
+ // using accscalar_t = acc_type;
+ multi_tensor_apply_kernel<<>>(
+ chunk_size, noop_flag.DATA_PTR(), tl, callable, args...);
+
+ AT_CUDA_CHECK(cudaGetLastError());
+
+ // Reset. The control flow possibilities here make my brain hurt.
+ loc_block_info = 0;
+ if (chunk == chunks_this_tensor - 1) {
+ // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3
+ // << std::endl;
+ loc_tensor_info = 0;
+ tl.start_tensor_this_launch = t + 1;
+ } else {
+ // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3
+ // << std::endl;
+ tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
+ for (int d = 0; d < depth; d++)
+ tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
+ loc_tensor_info = 1;
+ tl.start_tensor_this_launch = t;
}
+ }
}
-}
\ No newline at end of file
+ }
+}
diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h
index cf83414af37f..2f180a7783ec 100644
--- a/colossalai/kernel/cuda_native/csrc/type_shim.h
+++ b/colossalai/kernel/cuda_native/csrc/type_shim.h
@@ -1,76 +1,69 @@
+/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
+/* Copyright 2020 The Microsoft DeepSpeed Team
+ Copyright NVIDIA/apex
+ This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+ Licensed under the MIT License.
+*/
#include
-#include "compat.h"
-
-
-#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
- switch(TYPE) \
- { \
- case at::ScalarType::Half: \
- { \
- using scalar_t = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
+#include "compat.h"
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
+ switch (TYPE) { \
+ case at::ScalarType::Half: { \
+ using scalar_t = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::BFloat16: { \
+ using scalar_t = at::BFloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ }
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
- switch(TYPEIN) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_in = float; \
- switch(TYPEOUT) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_out = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_out = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t_out = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
- } \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_in = at::Half; \
- using scalar_t_out = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::BFloat16: \
- { \
- using scalar_t_in = at::BFloat16; \
- using scalar_t_out = at::BFloat16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
- }
+ switch (TYPEIN) { \
+ case at::ScalarType::Float: { \
+ using scalar_t_in = float; \
+ switch (TYPEOUT) { \
+ case at::ScalarType::Float: { \
+ using scalar_t_out = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Half: { \
+ using scalar_t_out = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::BFloat16: { \
+ using scalar_t_out = at::BFloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+ } \
+ break; \
+ } \
+ case at::ScalarType::Half: { \
+ using scalar_t_in = at::Half; \
+ using scalar_t_out = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::BFloat16: { \
+ using scalar_t_in = at::BFloat16; \
+ using scalar_t_out = at::BFloat16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+ }
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
@@ -81,222 +74,191 @@
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
-// // Enable dispatch switch statements to take *this directly for post-3aeb78
+// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
- switch (TYPE) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
+#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+ switch (TYPE) { \
+ case at::ScalarType::Float: { \
+ using scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Half: { \
+ using scalar_t_##LEVEL = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ }
-#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
- switch (TYPE) \
- { \
- case at::ScalarType::Float: \
- { \
- using scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Byte: \
- { \
- using scalar_t_##LEVEL = uint8_t; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
+#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
+ switch (TYPE) { \
+ case at::ScalarType::Float: { \
+ using scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Half: { \
+ using scalar_t_##LEVEL = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Byte: { \
+ using scalar_t_##LEVEL = uint8_t; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ }
-#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
- switch (TYPE) \
- { \
- case at::ScalarType::Double: \
- { \
- using scalar_t_##LEVEL = double; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Float: \
- { \
- using scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Half: \
- { \
- using scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+ switch (TYPE) { \
+ case at::ScalarType::Double: { \
+ using scalar_t_##LEVEL = double; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Float: { \
+ using scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Half: { \
+ using scalar_t_##LEVEL = at::Half; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ }
-#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
- switch (TYPE) \
- { \
- case at::ScalarType::Double: \
- { \
- using scalar_t_##LEVEL = double; \
- __VA_ARGS__; \
- break; \
- } \
- case at::ScalarType::Float: \
- { \
- using scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
- }
+#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
+ switch (TYPE) { \
+ case at::ScalarType::Double: { \
+ using scalar_t_##LEVEL = double; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case at::ScalarType::Float: { \
+ using scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ }
-#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
- if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) \
- { \
- using g_scalar_t_##LEVEL = float; \
- using p_scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- } \
- else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half) \
- { \
- using g_scalar_t_##LEVEL = float; \
- using p_scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- } \
- else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float) \
- { \
- using g_scalar_t_##LEVEL = at::Half; \
- using p_scalar_t_##LEVEL = float; \
- __VA_ARGS__; \
- } \
- else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) \
- { \
- using g_scalar_t_##LEVEL = at::Half; \
- using p_scalar_t_##LEVEL = at::Half; \
- __VA_ARGS__; \
- } \
- else \
- { \
- AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'"); \
- } \
+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...) \
+ if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) { \
+ using g_scalar_t_##LEVEL = float; \
+ using p_scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ } else if (GTYPE == at::ScalarType::Float && \
+ PTYPE == at::ScalarType::Half) { \
+ using g_scalar_t_##LEVEL = float; \
+ using p_scalar_t_##LEVEL = at::Half; \
+ __VA_ARGS__; \
+ } else if (GTYPE == at::ScalarType::Half && \
+ PTYPE == at::ScalarType::Float) { \
+ using g_scalar_t_##LEVEL = at::Half; \
+ using p_scalar_t_##LEVEL = float; \
+ __VA_ARGS__; \
+ } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
+ using g_scalar_t_##LEVEL = at::Half; \
+ using p_scalar_t_##LEVEL = at::Half; \
+ __VA_ARGS__; \
+ } else { \
+ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
+ "'"); \
+ }
template
-__device__ __forceinline__ T reduce_block_into_lanes(T *x,
- T val,
- int lanes = 1,
- bool share_result = false) // lanes is intended to be <= 32.
+__device__ __forceinline__ T reduce_block_into_lanes(
+ T *x, T val, int lanes = 1,
+ bool share_result = false) // lanes is intended to be <= 32.
{
- int tid = threadIdx.x + threadIdx.y * blockDim.x;
- int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
+ int tid = threadIdx.x + threadIdx.y * blockDim.x;
+ int blockSize =
+ blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
- if (blockSize >= 64)
- {
- x[tid] = val;
- __syncthreads();
- }
+ if (blockSize >= 64) {
+ x[tid] = val;
+ __syncthreads();
+ }
#pragma unroll
- for (int i = (blockSize >> 1); i >= 64; i >>= 1)
- {
- if (tid < i)
- x[tid] = x[tid] + x[tid + i];
- __syncthreads();
- }
+ for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+ if (tid < i) x[tid] = x[tid] + x[tid + i];
+ __syncthreads();
+ }
- T final;
+ T final;
- if (tid < 32)
- {
- if (blockSize >= 64)
- final = x[tid] + x[tid + 32];
- else
- final = val;
- // __SYNCWARP();
+ if (tid < 32) {
+ if (blockSize >= 64)
+ final = x[tid] + x[tid + 32];
+ else
+ final = val;
+ // __SYNCWARP();
#pragma unroll
- for (int i = 16; i >= lanes; i >>= 1)
- final = final + __shfl_down_sync(0xffffffff, final, i);
- }
+ for (int i = 16; i >= lanes; i >>= 1)
+ final = final + __shfl_down_sync(0xffffffff, final, i);
+ }
- if (share_result)
- {
- if (tid < lanes)
- x[tid] = final; // EpilogueOp
- // Make sure the smem result is visible to all warps.
- __syncthreads();
- }
+ if (share_result) {
+ if (tid < lanes) x[tid] = final; // EpilogueOp
+ // Make sure the smem result is visible to all warps.
+ __syncthreads();
+ }
- return final;
+ return final;
}
template
-__device__ __forceinline__ T reduce_block_into_lanes_max_op(T *x,
- T val,
- int lanes = 1,
- bool share_result = false) // lanes is intended to be <= 32.
+__device__ __forceinline__ T reduce_block_into_lanes_max_op(
+ T *x, T val, int lanes = 1,
+ bool share_result = false) // lanes is intended to be <= 32.
{
- int tid = threadIdx.x + threadIdx.y * blockDim.x;
- int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
+ int tid = threadIdx.x + threadIdx.y * blockDim.x;
+ int blockSize =
+ blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
- if (blockSize >= 64)
- {
- x[tid] = val;
- __syncthreads();
- }
+ if (blockSize >= 64) {
+ x[tid] = val;
+ __syncthreads();
+ }
#pragma unroll
- for (int i = (blockSize >> 1); i >= 64; i >>= 1)
- {
- if (tid < i)
- x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
- __syncthreads();
- }
+ for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
+ if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
+ __syncthreads();
+ }
- T final;
+ T final;
- if (tid < 32)
- {
- if (blockSize >= 64)
- final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
- else
- final = val;
- // __SYNCWARP();
+ if (tid < 32) {
+ if (blockSize >= 64)
+ final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
+ else
+ final = val;
+ // __SYNCWARP();
#pragma unroll
- for (int i = 16; i >= lanes; i >>= 1)
- final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
- }
+ for (int i = 16; i >= lanes; i >>= 1)
+ final =
+ fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
+ }
- if (share_result)
- {
- if (tid < lanes)
- x[tid] = final; // EpilogueOp
- // Make sure the smem result is visible to all warps.
- __syncthreads();
- }
+ if (share_result) {
+ if (tid < lanes) x[tid] = final; // EpilogueOp
+ // Make sure the smem result is visible to all warps.
+ __syncthreads();
+ }
- return final;
-}
\ No newline at end of file
+ return final;
+}
diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 7bd646d3935f..d793815ed681 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -1,8 +1,6 @@
"""
-Fused Attention
-===============
-This is a Triton implementation of the Flash Attention algorithm
-(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
+A general attention module using the flash attention kernels from xformers:
+https://github.com/facebookresearch/xformers/tree/main/xformers/ops/fmha
"""
import math
@@ -11,6 +9,159 @@
import torch
+try:
+ from xformers.ops.fmha import memory_efficient_attention
+ HAS_MEM_EFF_ATTN = True
+except ImportError:
+ HAS_MEM_EFF_ATTN = False
+ print('please install xformers from https://github.com/facebookresearch/xformers')
+
+if HAS_MEM_EFF_ATTN:
+
+ from typing import Optional
+
+ from einops import rearrange
+ from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
+
+ from .scaled_softmax import AttnMaskType
+
+ allow_alibi = True
+ for op in MemoryEfficientAttentionCutlassOp:
+ allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
+
+ class Unpad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+ ctx.save_for_backward(indices)
+ # [b, s, ...]
+ assert tensor.ndim >= 3
+ ctx.bsz = tensor.shape[0]
+ out = rearrange(tensor, 'b s ... -> (b s) ...')
+ ctx.shape = out.shape
+ # [1, ntokens, ...]
+ return out[indices].unsqueeze(0)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [b*s, ...]
+ grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+ grad[indices] = grad_output.squeeze(0)
+ grad = rearrange(grad, '(b s) ... -> b s ...', b=ctx.bsz)
+ # [b, s, ...]
+ return grad, None
+
+ class Repad(torch.autograd.Function):
+ """
+ Adapted from
+ https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+ """
+
+ @staticmethod
+ def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+ ctx.save_for_backward(indices)
+ # [ntokens, ...]
+ tensor = tensor.squeeze(0)
+ out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+ # [b*s, ...]
+ out[indices] = tensor
+ # [b, s, ...]
+ out = rearrange(out, '(b s) ... -> b s ...', b=batch_size)
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ indices, = ctx.saved_tensors
+ # [b*s, ...]
+ grad_output = rearrange(grad_output, 'b s ... -> (b s) ...')
+ grad = grad_output[indices]
+ # [1, ntokens, ...]
+ return grad.unsqueeze(0), None, None, None
+
+ class ColoAttention(torch.nn.Module):
+
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+ super().__init__()
+ assert embed_dim % num_heads == 0, \
+ f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
+ self.scale = 1 / math.sqrt(embed_dim // num_heads)
+ self.dropout = dropout
+
+ @staticmethod
+ def get_seq_info_from_mask(attn_mask: torch.Tensor):
+ indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()
+ seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten().tolist()
+ return indices, seqlens
+
+ @staticmethod
+ def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+ return Unpad.apply(tensor, indices)
+
+ @staticmethod
+ def repad(tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
+ return Repad.apply(tensor, indices, batch_size, seq_len)
+
+ def forward(self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ attn_mask_type: Optional[AttnMaskType] = None,
+ bias: Optional[torch.Tensor] = None):
+ batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
+ attn_bias = None
+ if attn_mask_type == AttnMaskType.padding: # bert style
+ assert attn_mask is not None, \
+ f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
+ assert attn_mask.dim() == 2, \
+ "attention mask is supposed to have shape (batch_size, seq_len), " + \
+ f"but got {attn_mask.dim()} dimensions."
+ if tgt_len == src_len:
+ q_indices, q_seqlen = self.get_seq_info_from_mask(attn_mask)
+ kv_seqlen = None
+ if batch_size > 1:
+ query, key, value = self.unpad(torch.stack([query, key, value], dim=2), q_indices).unbind(dim=2)
+ else:
+ q_indices = torch.arange(batch_size * tgt_len, dtype=torch.int32, device=query.device)
+ q_seqlen = torch.LongTensor([tgt_len] * batch_size, device=query.device)
+ kv_indices, kv_seqlen = self.get_seq_info_from_mask(attn_mask)
+ if batch_size > 1:
+ query = rearrange(query, "b s ... -> c (b s) ...", c=1)
+ key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
+ attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+ elif attn_mask_type == AttnMaskType.causal: # gpt style
+ attn_bias = LowerTriangularMask()
+
+ if bias is not None: # alibi / relative position emebedding
+ assert allow_alibi, "flash attention with bias is not supported in this system."
+ assert attn_mask_type == AttnMaskType.causal, \
+ "attention with bias is only supported for causal attention so far."
+ attn_bias = attn_bias.add_bias(bias)
+
+ out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
+
+ if attn_mask_type == AttnMaskType.padding and batch_size > 1:
+ out = self.repad(out, q_indices, batch_size, tgt_len)
+
+ out = rearrange(out, 'b s h d -> b s (h d)')
+ return out
+
+
+##########################################################################
+# the flash attention functions below that are copied
+# from the OpenAI/triton repository will be deprecated
+# You can find the repository in Triton https://github.com/openai/triton
+# You can find the source file in https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
+# Reference:
+# 1. Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf
+# 2. Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf
+
def triton_cuda_check():
cuda_home = os.getenv("CUDA_HOME", default="/usr/local/cuda")
@@ -48,15 +199,9 @@ def triton_cuda_check():
HAS_FLASH_ATTN = False
print('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-try:
- from xformers.ops.fmha import memory_efficient_attention
- HAS_MEM_EFF_ATTN = True
-except ImportError:
- HAS_MEM_EFF_ATTN = False
- print('please install xformers from https://github.com/facebookresearch/xformers')
-
if HAS_TRITON:
-
+ # the following functions are adapted from the OpenAI Triton tutorial
+ # https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
@triton.jit
def _fwd_kernel(
Q,
@@ -417,25 +562,6 @@ def triton_flash_attention(q, k, v, sm_scale):
if HAS_FLASH_ATTN:
- from einops import rearrange
-
- class MaskedFlashAttention(torch.nn.Module):
-
- def __init__(self, num_attention_heads: int, attention_head_size: int, attention_dropout: float) -> None:
- super().__init__()
- self.num_attention_heads = num_attention_heads
- self.attention_head_size = attention_head_size
- self.attention_func = FlashAttention(softmax_scale=math.sqrt(attention_head_size),
- attention_dropout=attention_dropout)
-
- def forward(self, query_key_value: torch.Tensor, attention_mask: torch.Tensor, causal=False):
- if attention_mask.dtype is not torch.bool:
- attention_mask = attention_mask.bool()
- qkv = rearrange(query_key_value, 'b s (three h d) -> b s three h d', three=3, h=self.num_attention_heads)
- context, _ = self.attention_func(qkv, key_padding_mask=attention_mask, causal=causal)
- context = rearrange(context, 'b s h d -> b s (h d)')
- return context
-
def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len, dropout_p=0., causal=False):
"""
Arguments:
@@ -506,20 +632,4 @@ def flash_attention_q_k_v(q, k, v, sm_scale, batch_size, q_seqlen, kv_seqlen, dr
causal)
-if HAS_MEM_EFF_ATTN:
-
- from einops import rearrange
- from xformers.ops.fmha import LowerTriangularMask
-
- class MemoryEfficientAttention(torch.nn.Module):
-
- def __init__(self, hidden_size: int, num_attention_heads: int, attention_dropout: float = 0.0):
- super().__init__()
- attention_head_size = hidden_size // num_attention_heads
- self.scale = 1 / attention_head_size**0.5
- self.dropout = attention_dropout
-
- def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor):
- context = memory_efficient_attention(query, key, value, attention_mask, self.dropout, self.scale)
- context = rearrange(context, 'b s h d -> b s (h d)')
- return context
+##########################################################################
diff --git a/colossalai/kernel/cuda_native/layer_norm.py b/colossalai/kernel/cuda_native/layer_norm.py
index f1b5efa4ec8c..40355a41ed0d 100644
--- a/colossalai/kernel/cuda_native/layer_norm.py
+++ b/colossalai/kernel/cuda_native/layer_norm.py
@@ -9,24 +9,31 @@
from torch.nn import init
from torch.nn.parameter import Parameter
+from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+
+try:
+ from colossalai._C import layer_norm
+except ImportError:
+ layer_norm = None
+
class FusedLayerNormAffineFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input, weight, bias, normalized_shape, eps):
- try:
- import colossalai._C.layer_norm
- except ImportError:
- raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
-
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
- output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
- ctx.eps)
+
+ global layer_norm
+ if layer_norm is None:
+
+ layer_norm = LayerNormBuilder().load()
+ output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
+ ctx.layernorm_op = layer_norm
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@@ -34,15 +41,10 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
- try:
- import colossalai._C.layer_norm
- except ImportError:
- raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
-
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
- = colossalai._C.layer_norm.backward_affine(
+ = layer_norm.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 9e147b4199ec..24e458bb3ea5 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -1,11 +1,20 @@
-"""This code from NVIDIA Megatron
- with some changes. """
+# This code from NVIDIA Megatron:
+# with minor changes.
import enum
import torch
import torch.nn as nn
+from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
+from colossalai.kernel.op_builder.scaled_upper_triangle_masked_softmax import ScaledUpperTrainglemaskedSoftmaxBuilder
+
+try:
+ from colossalai._C import scaled_masked_softmax, scaled_upper_triang_masked_softmax
+except ImportError:
+ scaled_masked_softmax = None
+ scaled_upper_triang_masked_softmax = None
+
class AttnMaskType(enum.Enum):
padding = 1
@@ -23,7 +32,9 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, scale):
- from colossalai.kernel import scaled_upper_triang_masked_softmax
+ global scaled_upper_triang_masked_softmax
+ if scaled_upper_triang_masked_softmax:
+ scaled_upper_triang_masked_softmax = ScaledUpperTrainglemaskedSoftmaxBuilder().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
@@ -33,8 +44,6 @@ def forward(ctx, inputs, scale):
@staticmethod
def backward(ctx, output_grads):
- from colossalai.kernel import scaled_upper_triang_masked_softmax
-
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
@@ -52,28 +61,23 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
- try:
- import colossalai._C.scaled_masked_softmax
- except ImportError:
- raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
-
scale_t = torch.tensor([scale])
- softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
+ # build and load kernel if not pre-built
+ global scaled_masked_softmax
+ if scaled_masked_softmax is None:
+ scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
+
+ softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
- try:
- import colossalai._C.scaled_masked_softmax
- except ImportError:
- raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
-
softmax_results, scale_t = ctx.saved_tensors
- input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
- return input_grads, None, None
+ input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
+ return input_grads, None, None, None
class FusedScaleMaskSoftmax(nn.Module):
@@ -111,7 +115,6 @@ def __init__(
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
-
assert (self.scale is None or softmax_in_fp32), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
@@ -176,11 +179,10 @@ def forward_torch_softmax(self, input, mask):
return probs
- @staticmethod
- def get_batch_per_block(sq, sk, b, np):
- try:
- import colossalai._C.scaled_masked_softmax
- except ImportError:
- raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
+ def get_batch_per_block(self, sq, sk, b, np):
+ # build and load kernel if not pre-built
+ global scaled_masked_softmax
+ if scaled_masked_softmax is None:
+ scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
- return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+ return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
diff --git a/colossalai/kernel/jit/bias_gelu.py b/colossalai/kernel/jit/bias_gelu.py
index e6da70c40b42..33b4ac32b044 100644
--- a/colossalai/kernel/jit/bias_gelu.py
+++ b/colossalai/kernel/jit/bias_gelu.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
diff --git a/colossalai/nn/_ops/view.py b/colossalai/nn/_ops/view.py
index 3197e7568d6f..3c0bc52337ce 100644
--- a/colossalai/nn/_ops/view.py
+++ b/colossalai/nn/_ops/view.py
@@ -1,97 +1,96 @@
-import math
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
-from typing import Optional, Union
-
-
-def _all_int(my_iter):
- return all(isinstance(i, int) for i in my_iter)
-
-
-def _get_valid_shape(shape):
- if isinstance(shape, list):
- if _all_int(shape):
- return tuple(shape)
- else:
- raise RuntimeError("expects type(int) but finds an other type")
- elif isinstance(shape, tuple):
- if _all_int(shape):
- return shape
- else:
- return _get_valid_shape(shape[0])
- else:
- raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
-
-
-def _shape_infer(org_sp, tgt_sp):
- cnt = 0
- pos = 0
- for idx, dim in enumerate(tgt_sp):
- if dim < -1:
- raise RuntimeError("invalid shape dimension {}".format(dim))
- elif dim == -1:
- cnt += 1
- pos = idx
-
- if cnt > 1:
- raise RuntimeError("only one dimension can be inferred")
-
- org_prod = math.prod(org_sp)
- tgt_prod = math.prod(tgt_sp)
-
- if cnt == 0:
- if org_prod != tgt_prod:
- raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
- else:
- return tgt_sp
- elif org_prod % tgt_prod != 0:
- raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
-
- infer_dim = -(org_prod // tgt_prod)
- return tgt_sp[: pos] + (infer_dim,) + tgt_sp[pos + 1:]
-
-
-@colo_op_impl(torch.Tensor.view)
-def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
- """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
- Changes the shape of the current tensor.
- """
- assert isinstance(self, ColoTensor)
- # apply original `view` function for replicated colo tensors
- if self.is_replicate():
- return self.view(*shape)
-
- cur_sp = self.size()
- org_sp = self.size_global()
- # parse the passed arguments
- tgt_sp = _get_valid_shape(shape)
- # get the correct shape from inference
- inf_sp = _shape_infer(org_sp, tgt_sp)
-
- if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
- new_shape = (cur_sp[0],) + tgt_sp[1:]
- res = self.view(*new_shape)
- elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
- new_shape = tgt_sp[:-1] + (cur_sp[-1],)
- res = self.view(*new_shape)
- else:
- replicated_t = self.redistribute(dist_spec=ReplicaSpec())
- return ColoTensor.from_torch_tensor(
- tensor=replicated_t.view(*shape),
- spec=ColoTensorSpec(self.get_process_group()))
-
- return ColoTensor.from_torch_tensor(
- tensor=res,
- spec=ColoTensorSpec(
- pg=self.get_process_group(),
- dist_attr=self.dist_spec))
-
-
-@colo_op_impl(torch.Tensor.size)
-def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
- size = self.size_global()
- if dim is None:
- return size
- else:
- return size[dim]
+import operator
+from functools import reduce
+from typing import Optional, Union
+
+import torch
+
+from colossalai.tensor import ColoTensor, ColoTensorSpec, ReplicaSpec
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+
+def _all_int(my_iter):
+ return all(isinstance(i, int) for i in my_iter)
+
+
+def _get_valid_shape(shape):
+ if isinstance(shape, list):
+ if _all_int(shape):
+ return tuple(shape)
+ else:
+ raise RuntimeError("expects type(int) but finds an other type")
+ elif isinstance(shape, tuple):
+ if _all_int(shape):
+ return shape
+ else:
+ return _get_valid_shape(shape[0])
+ else:
+ raise RuntimeError("expects an iterable array but finds '{}'".format(type(shape)))
+
+
+def _shape_infer(org_sp, tgt_sp):
+ cnt = 0
+ pos = 0
+ for idx, dim in enumerate(tgt_sp):
+ if dim < -1:
+ raise RuntimeError("invalid shape dimension {}".format(dim))
+ elif dim == -1:
+ cnt += 1
+ pos = idx
+
+ if cnt > 1:
+ raise RuntimeError("only one dimension can be inferred")
+
+ org_prod = reduce(operator.mul, org_sp, 1)
+ tgt_prod = reduce(operator.mul, tgt_sp, 1)
+
+ if cnt == 0:
+ if org_prod != tgt_prod:
+ raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+ else:
+ return tgt_sp
+ elif org_prod % tgt_prod != 0:
+ raise RuntimeError("shape '{}' is invalid for input of size {}".format(tgt_sp, org_prod))
+
+ infer_dim = -(org_prod // tgt_prod)
+ return tgt_sp[:pos] + (infer_dim,) + tgt_sp[pos + 1:]
+
+
+@colo_op_impl(torch.Tensor.view)
+def colo_view(self: ColoTensor, *shape) -> 'ColoTensor':
+ """Handles ``__torch_function__`` dispatch for ``torch.Tensor.view``.
+ Changes the shape of the current tensor.
+ """
+ assert isinstance(self, ColoTensor)
+ # apply original `view` function for replicated colo tensors
+ if self.is_replicate():
+ return self.view(*shape)
+
+ cur_sp = self.size()
+ org_sp = self.size_global()
+ # parse the passed arguments
+ tgt_sp = _get_valid_shape(shape)
+ # get the correct shape from inference
+ inf_sp = _shape_infer(org_sp, tgt_sp)
+
+ if self.is_shard_1drow() and org_sp[0] == inf_sp[0]:
+ new_shape = (cur_sp[0],) + tgt_sp[1:]
+ res = self.view(*new_shape)
+ elif self.is_shard_1dcol() and org_sp[-1] == inf_sp[-1]:
+ new_shape = tgt_sp[:-1] + (cur_sp[-1],)
+ res = self.view(*new_shape)
+ else:
+ replicated_t = self.redistribute(dist_spec=ReplicaSpec())
+ return ColoTensor.from_torch_tensor(tensor=replicated_t.view(*shape),
+ spec=ColoTensorSpec(self.get_process_group()))
+
+ return ColoTensor.from_torch_tensor(tensor=res,
+ spec=ColoTensorSpec(pg=self.get_process_group(), dist_attr=self.dist_spec))
+
+
+@colo_op_impl(torch.Tensor.size)
+def colo_size(self: ColoTensor, dim: Optional[int] = None) -> Union[torch.Size, int]:
+ size = self.size_global()
+ if dim is None:
+ return size
+ else:
+ return size[dim]
diff --git a/colossalai/nn/layer/colossalai_layer/_utils.py b/colossalai/nn/layer/colossalai_layer/_utils.py
index 4283e5fe09b5..677cb0e7ac42 100644
--- a/colossalai/nn/layer/colossalai_layer/_utils.py
+++ b/colossalai/nn/layer/colossalai_layer/_utils.py
@@ -1,38 +1,41 @@
-import torch.nn as nn
-from torch import Tensor
-
-from ..parallel_2d._operation import split_batch_2d
-from ..parallel_2p5d._operation import split_batch_2p5d
-from ..parallel_3d._operation import split_batch_3d
-from ..utils import get_tensor_parallel_mode
-
-_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
-
-
-def partition_batch(input_) -> Tensor:
- tensor_parallel_mode = get_tensor_parallel_mode()
- if tensor_parallel_mode in _parallel_split_batch:
- if isinstance(input_, dict):
- return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
- else:
- return _parallel_split_batch[tensor_parallel_mode](input_)
- else:
- return input_
-
-
-class ColossalaiModule(nn.Module):
-
- def __init__(self, module: nn.Module, **kwargs):
- super().__init__()
- # copy values
- self.__dict__ = module.__dict__.copy()
- # copy methods
- for name, attr in module.__class__.__dict__.items():
- if name not in ['__init__', 'forward'] and callable(attr):
- setattr(self, name, getattr(module, name))
- self._forward_func = module.forward
- for k, v in kwargs.items():
- setattr(self, k, v)
-
- def forward(self, *args):
- return self._forward_func(*args)
+import torch.nn as nn
+from torch import Tensor
+
+from ..parallel_2d._operation import split_batch_2d
+from ..parallel_2p5d._operation import split_batch_2p5d
+from ..parallel_3d._operation import split_batch_3d
+from ..utils import get_tensor_parallel_mode
+
+_parallel_split_batch = {'2d': split_batch_2d, '2.5d': split_batch_2p5d, '3d': split_batch_3d}
+
+
+def partition_batch(input_) -> Tensor:
+ tensor_parallel_mode = get_tensor_parallel_mode()
+ if tensor_parallel_mode in _parallel_split_batch:
+ if isinstance(input_, dict):
+ return {k: _parallel_split_batch[tensor_parallel_mode](v) for k, v in input_.items()}
+ else:
+ return _parallel_split_batch[tensor_parallel_mode](input_)
+ else:
+ return input_
+
+
+class ColossalaiModule(nn.Module):
+
+ def __init__(self, module: nn.Module, **kwargs):
+ super().__init__()
+ self.module = module
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def __getattr__(self, name: str):
+ if name == 'module':
+ return super().__getattr__(name)
+ elif hasattr(self.module, name):
+ return getattr(self.module, name)
+ elif name in self.__dict__:
+ return self.__dict__[name]
+ raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
+
+ def forward(self, *args):
+ return self.module(*args)
diff --git a/colossalai/nn/layer/colossalai_layer/dropout.py b/colossalai/nn/layer/colossalai_layer/dropout.py
index cc2d9a0a70fd..0c049cb3f408 100644
--- a/colossalai/nn/layer/colossalai_layer/dropout.py
+++ b/colossalai/nn/layer/colossalai_layer/dropout.py
@@ -1,30 +1,31 @@
-import torch.nn as nn
-from colossalai.context import ParallelMode, seed
-
-from ..parallel_1d import *
-from ..utils import get_tensor_parallel_mode
-from ._utils import ColossalaiModule
-
-
-class Dropout(ColossalaiModule):
- """Dropout layer of colossalai.
-
- Args:
- p (float, optional): probability of an element to be zeroed, defaults 0.5.
- inplace (bool, optional): whether to do dropout in-place, default to be False.
- """
-
- def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
- tensor_parallel = get_tensor_parallel_mode()
- if tensor_parallel == "1d":
- drop = Dropout1D(p, inplace)
- else:
- drop = nn.Dropout(p, inplace)
- super().__init__(drop, tensor_parallel=tensor_parallel)
-
- def forward(self, *args):
- if self.tensor_parallel in [None, '1d']:
- return self._forward_func(*args)
- else:
- with seed(ParallelMode.TENSOR):
- return self._forward_func(*args)
+import torch.nn as nn
+
+from colossalai.context import ParallelMode, seed
+
+from ..parallel_1d import *
+from ..utils import get_tensor_parallel_mode
+from ._utils import ColossalaiModule
+
+
+class Dropout(ColossalaiModule):
+ """Dropout layer of colossalai.
+
+ Args:
+ p (float, optional): probability of an element to be zeroed, defaults 0.5.
+ inplace (bool, optional): whether to do dropout in-place, default to be False.
+ """
+
+ def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
+ tensor_parallel = get_tensor_parallel_mode()
+ if tensor_parallel == "1d":
+ drop = Dropout1D(p, inplace)
+ else:
+ drop = nn.Dropout(p, inplace)
+ super().__init__(drop, tensor_parallel=tensor_parallel)
+
+ def forward(self, *args):
+ if self.tensor_parallel in [None, '1d']:
+ return super().forward(*args)
+ else:
+ with seed(ParallelMode.TENSOR):
+ return super().forward(*args)
diff --git a/colossalai/nn/optimizer/README.md b/colossalai/nn/optimizer/README.md
index 268e37d57997..09395d08b93e 100644
--- a/colossalai/nn/optimizer/README.md
+++ b/colossalai/nn/optimizer/README.md
@@ -2,30 +2,30 @@
## Introduction
-Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI),
-which has been accepted as official tutorials by top conference [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc.
+Welcome to the large-scale deep learning optimization techniques of [Colossal-AI](https://github.com/hpcaitech/ColossalAI),
+which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
many advanced technologies such as multi-dimensional tensor parallelism, sequence parallelism, heterogeneous memory management,
-large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and
+large-scale optimization, adaptive task scheduling, etc. By using Colossal-AI, we could help users to efficiently and
quickly deploy large AI model training and inference, reducing large AI model training budgets and scaling down the labor cost of learning and deployment.
### 🚀 Quick Links
[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) |
-[**Paper**](https://arxiv.org/abs/2110.14883) |
-[**Documentation**](https://www.colossalai.org/) |
-[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) |
+[**Paper**](https://arxiv.org/abs/2110.14883) |
+[**Documentation**](https://www.colossalai.org/) |
+[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) |
[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
## Table of Content
-Large transformer models display promising performance on a wide spectrum of AI applications.
+Large transformer models display promising performance on a wide spectrum of AI applications.
Both academia and industry are scaling DL training on larger clusters. However, degrading generalization performance, non-negligible communication overhead, and increasing model size prevent DL researchers and engineers from exploring large-scale AI models.
-We aim to provide a clear sketch of the optimizations for large-scale deep learning with regard to model accuracy and model efficiency.
+We aim to provide a clear sketch of the optimizations for large-scale deep learning with regard to model accuracy and model efficiency.
One way to achieve the goal of maintaining or improving the model accuracy in the large-scale setting while maintaining compute efficiency is to design algorithms that
are less communication and memory hungry. Notably, they are not mutually exclusive but can
be optimized jointly to further speed up training.
@@ -51,7 +51,7 @@ be optimized jointly to further speed up training.
- Memory Efficiency
- Mix-Precision Training
- Memory-Efficient Methods, e.g. ZeRO, Gemini, etc.
-
+
Some of the above are still under development. **If you wish to make a contribution to this repository, please read the `Contributing` section below.**
## Discussion
@@ -63,7 +63,7 @@ If you encounter any problem while running these optimizers, you may want to rai
## Contributing
-This project welcomes constructive ideas and implementations from the community.
+This project welcomes constructive ideas and implementations from the community.
### Update an Optimizer
diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index a8c3522793d8..54036973e1e3 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -19,7 +19,7 @@ class CPUAdam(NVMeOptimizer):
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
- Requires ColossalAI to be installed via ``pip install .``.
+ `CPUAdam` requires CUDA extensions which can be built during installation or runtime.
This version of CPU Adam accelates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required.
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 2f6bde5ca1ab..987af8a968b7 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -1,4 +1,11 @@
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_adam.py
+'''
+Copyright 2020 The Microsoft DeepSpeed Team
+
+Copyright NVIDIA/apex
+This file is adapted from fused adam in NVIDIA/apex, commit a109f85
+Licensed under the MIT License.
+'''
import torch
from colossalai.registry import OPTIMIZERS
@@ -9,8 +16,7 @@
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
- Currently GPU-only. Requires ColossalAI to be installed via
- ``pip install .``.
+ `FusedAdam` requires CUDA extensions which can be built during installation or runtime.
This version of fused Adam implements 2 fusions.
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 891a76da73dd..72520064e98b 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -9,8 +9,7 @@
class FusedLAMB(torch.optim.Optimizer):
"""Implements LAMB algorithm.
- Currently GPU-only. Requires ColossalAI to be installed via
- ``pip install .``.
+ `FusedLAMB` requires CUDA extensions which can be built during installation or runtime.
This version of fused LAMB implements 2 fusions.
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 41e6d524895a..468713b223c1 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -10,8 +10,7 @@
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
- Currently GPU-only. Requires ColossalAI to be installed via
- ``pip install .``.
+ `FusedSGD` requires CUDA extensions which can be built during installation or runtime.
This version of fused SGD implements 2 fusions.
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 5196d4338441..1d0fb92de499 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -19,7 +19,7 @@ class HybridAdam(NVMeOptimizer):
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
- Requires ColossalAI to be installed via ``pip install .``
+ `HybriadAdam` requires CUDA extensions which can be built during installation or runtime.
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/nn/optimizer/zero_optimizer.py
index 2786d4496a8e..422ebb7a3944 100644
--- a/colossalai/nn/optimizer/zero_optimizer.py
+++ b/colossalai/nn/optimizer/zero_optimizer.py
@@ -1,4 +1,6 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import math
+import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple
@@ -12,7 +14,7 @@
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable, get_current_device, is_ddp_ignored
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@@ -64,7 +66,8 @@ def __init__(self,
**defaults: Any):
super().__init__(optim)
assert isinstance(module, ZeroDDP)
- assert type(optim) in _AVAIL_OPTIM_LIST, "you should use the optimizer in the available list"
+ assert type(optim) in _AVAIL_OPTIM_LIST, "You should use an optimizer in the available list:\n" \
+ f"{_AVAIL_OPTIM_LIST}"
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
@@ -78,8 +81,16 @@ def __init__(self,
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
- params_list = [p for p in module.parameters() if not getattr(p, '_ddp_to_ignore', False)]
- for p, fp32_p in zip(params_list, module.fp32_params):
+ ddp_param_list = []
+ for name, param in module.named_parameters():
+ if is_ddp_ignored(param):
+ if param.requires_grad:
+ warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
+ "You should handle its optimizer update by yourself!")
+ else:
+ ddp_param_list.append(param)
+
+ for p, fp32_p in zip(ddp_param_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
@@ -126,7 +137,7 @@ def _update_fp16_params(self):
for group in self.param_groups:
for fake_param in group['params']:
assert fake_param.grad is None
- fake_param.data = none_tensor
+ fake_param.data = none_tensor.to(fake_param.device)
for chunk16 in self.chunk16_set:
chunk16.optim_update()
@@ -140,6 +151,10 @@ def _check_overflow(self):
return self._found_overflow.item() > 0
+ def _clear_global_norm(self) -> None:
+ for c16 in self.chunk16_set:
+ c16.l2_norm = None
+
def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
@@ -201,6 +216,7 @@ def step(self, *args, **kwargs):
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
self._logger.info(f'Found overflow. Skip step')
+ self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
self._update_fp16_params()
return
@@ -285,12 +301,15 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter):
fake_params_list = list()
for param in group['params']:
+ if is_ddp_ignored(param):
+ continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]:
continue
- fake_param = torch.nn.Parameter(torch.empty([0]))
+ grad_device = self.module.grads_device[param]
+ fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device))
self.param_to_chunk32[fake_param] = chunk16.paired_chunk
self.param_to_range[fake_param] = range_pair
diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py
index 0c369bfce22f..2afc8f18c36f 100644
--- a/colossalai/nn/parallel/__init__.py
+++ b/colossalai/nn/parallel/__init__.py
@@ -1,4 +1,5 @@
from .data_parallel import ColoDDP, ZeroDDP
from .gemini_parallel import GeminiDDP
+from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP']
+__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index e3bb83347d21..a9d001bd0a9c 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -5,6 +5,7 @@
import torch
import torch.distributed as dist
+import torch.nn as nn
from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
@@ -12,12 +13,14 @@
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
+from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer
+from .utils import get_static_torch_model
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -80,7 +83,7 @@ def __init__(self,
self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters():
- if getattr(p, '_ddp_to_ignore', False):
+ if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
@@ -115,7 +118,7 @@ def backward(self, loss: torch.Tensor):
if self.rebuild_bucket:
self.reducer.free()
for p in self.module.parameters():
- if getattr(p, '_ddp_to_ignore', False):
+ if is_ddp_ignored(p):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad
@@ -199,24 +202,32 @@ class ZeroDDP(ColoDDP):
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
- force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
+ force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
+ Defaults to False.
+ strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
+ Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
- force_outputs_fp32: bool = False) -> None:
+ force_outputs_fp32: bool = False,
+ strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
- self.fp32_params: List[ColoTensor] = []
+ self.fp32_params: List[ColoTensor] = list()
+ self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
- self.grads_device: Dict[torch.Tensor, torch.device] = {}
+ self.grads_device: Dict[torch.Tensor, torch.device] = dict()
+ self.param2name: Dict[nn.Parameter, str] = dict()
+ self.name2param: Dict[str, nn.Parameter] = dict()
- cpu_offload = self.gemini_manager.policy_name != 'cuda'
+ self._cast_buffers()
+ self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
@@ -228,64 +239,78 @@ def __init__(self,
for p in module.parameters():
param_order.append(p)
- for p in param_order.generate():
- assert isinstance(p, ColoParameter)
+ self._init_chunks(param_order=param_order,
+ strict_ddp_mode=strict_ddp_mode,
+ cpu_offload=self.gemini_manager.policy_name != 'cuda',
+ pin_memory=pin_memory)
- if getattr(p, '_ddp_to_ignore', False):
- p.data = p.data.half()
- continue
+ for name, param in module.named_parameters():
+ self.param2name[param] = name
+ for m_name, m_var in module.named_modules():
+ for p_name, p_var in m_var.named_parameters(recurse=False):
+ param_name = m_name + '.' + p_name if m_name else p_name
+ self.name2param[param_name] = p_var
- fp32_data = p.data.float()
- fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
- p.data = p.data.half()
- dp_world_size = p.process_group.dp_world_size()
- self.chunk_manager.register_tensor(tensor=p,
- group_type='fp16_param',
- config_key=dp_world_size,
- cpu_offload=cpu_offload,
- pin_memory=pin_memory)
- self.chunk_manager.register_tensor(tensor=fp32_p,
- group_type='fp32_param',
- config_key=dp_world_size,
- cpu_offload=cpu_offload,
- pin_memory=pin_memory)
- self.fp32_params.append(fp32_p)
- self.grads_device[p] = self.gemini_manager.default_device
- self.chunk_manager.close_all_groups()
- self._cast_buffers()
-
- params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
- for p, fp32_p in zip(params_list, self.fp32_params):
- chunk_16 = self.chunk_manager.get_chunk(p)
- chunk_32 = self.chunk_manager.get_chunk(fp32_p)
- chunk_32.init_pair(chunk_16)
-
- # keep gathered chunks are in CUDA
- if chunk_16.keep_gathered:
- self.grads_device[p] = get_current_device()
-
- self._logger = get_dist_logger()
+ def _post_forward(self):
+ """This function is only triggered for inference.
+ """
+ access_list = list(self.chunk_manager.accessed_chunks)
+ # we need to scatter all accessed chunks and move them to their original places
+ for chunk in access_list:
+ if chunk.keep_gathered:
+ self.chunk_manager.fake_release_chunk(chunk)
+ else:
+ assert chunk.can_release
+ self.chunk_manager.release_chunk(chunk)
+ first_param = next(iter(chunk.tensors_info))
+ self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
+ assert self.chunk_manager.accessed_mem == 0
+ # reset all recorded attributes
+ self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
+ # check whether we are in a inference mode
+ grad_flag = torch.is_grad_enabled()
+ if not grad_flag:
+ assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
+ ), "You should run a completed iteration as your warmup iter"
+
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
+ # scatter chunks in the inference mode
+ if not grad_flag:
+ self._post_forward()
+
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
- if getattr(p, '_ddp_to_ignore', False):
+ if is_ddp_ignored(p):
continue
p.grad = None
+ def _pre_backward(self):
+ # set a visit label for all parameters
+ # the label is used to check whether the parameter is correctly reduced
+ for param in self.param2name:
+ if not is_ddp_ignored(param):
+ setattr(param, "_gemini_reduced", False)
+
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
+ error_params = ["Reduction failed at followed parameters:"]
+ for param in self.param2name:
+ if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
+ error_params.append(self.param2name[param])
+ error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
- "The most possible reason is that the model is not compatible with ZeroDDP.")
+ "The most possible reason is that the model is not compatible with ZeroDDP.\n",
+ f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
@@ -293,6 +318,7 @@ def _post_backward(self):
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
+ self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
@@ -307,7 +333,9 @@ def grad_handle(self, p, grad):
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
- assert chunk.tensors_info[p].state == TensorState.HOLD_AFTER_BWD
+ if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
+ raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
+ "Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
@@ -332,21 +360,19 @@ def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
- r"""Returns a dictionary containing a whole state of the module.
+ """Returns a dictionary containing a whole state of the module.
- Both parameters and persistent buffers (e.g. running averages) are
- included. Keys are corresponding parameter and buffer names.
+ Both parameters and persistent buffers (e.g. running averages) are included.
+ Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
+ Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
+ are shared with other parameters which have been included in the dictionary.
+ When you need to load the state dict, you should set the argument `strict` to False.
+
Returns:
dict:
a dictionary containing a whole state of the module
-
- Example:
-
- >>> module.state_dict().keys()
- ['bias', 'weight']
-
"""
if destination is None:
destination = OrderedDict()
@@ -404,13 +430,24 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+ # get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
- # TODO: (HELSON) deal with ddp ignored parameters
- for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
- if p is not None:
- assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
- record_parameter = param_to_save_data[fp32_p]
- destination[prefix + name] = record_parameter
+ # get the mapping between copies and fp16 parameters
+ p_mapping = dict()
+ for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+ name = self.param2name[p]
+ assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+ record_parameter = param_to_save_data[fp32_p]
+ p_mapping[p] = record_parameter
+ for name, param in self.name2param.items():
+ if param is not None:
+ if is_ddp_ignored(param):
+ # deal with ddp ignored parameters
+ destination[prefix + name] = param if keep_vars else param.detach()
+ else:
+ destination[prefix + name] = p_mapping[param]
+ del p_mapping
+ del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
@@ -542,9 +579,15 @@ def load(param_name, dest_tensor, copy_func):
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
+ for name, param in self.named_parameters():
+ if is_ddp_ignored(param):
+ # deal with ddp ignored parameters
+ load(name, param, param.copy_)
+
fp32_to_name = dict()
- for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
+ for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
+ name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
@@ -591,6 +634,60 @@ def load_fp32_parameter(chunk_slice, data):
if input_name not in local_state:
unexpected_keys.append(key)
+ def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
+ ddp_pg = ColoProcessGroup()
+ for p in param_order.generate():
+ assert isinstance(p, ColoParameter)
+
+ # gather sharded parameters in the strict ddp mode
+ if strict_ddp_mode:
+ if not p.is_replicate():
+ p.set_dist_spec(ReplicaSpec())
+ p.set_process_group(pg=ddp_pg)
+
+ # ignore the parameters with no gradient
+ if not p.requires_grad:
+ self.set_params_to_ignore([p])
+
+ # move ignored parameters to CUDA
+ if is_ddp_ignored(p):
+ p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+ continue
+
+ # create a fp32 parameter
+ fp32_data = p.data.float()
+ fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
+ # create a fp16 parameter
+ p.data = p.data.half()
+
+ # register the fp16 parameter and fp32 parameter in the chunk manager
+ dp_world_size = p.process_group.dp_world_size()
+ self.chunk_manager.register_tensor(tensor=p,
+ group_type='fp16_param',
+ config_key=dp_world_size,
+ cpu_offload=cpu_offload,
+ pin_memory=pin_memory)
+ self.chunk_manager.register_tensor(tensor=fp32_p,
+ group_type='fp32_param',
+ config_key=dp_world_size,
+ cpu_offload=cpu_offload,
+ pin_memory=pin_memory)
+
+ self.fp16_params.append(p)
+ self.fp32_params.append(fp32_p)
+ self.grads_device[p] = self.gemini_manager.default_device
+
+ self.chunk_manager.close_all_groups()
+
+ for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+ chunk_16 = self.chunk_manager.get_chunk(p)
+ chunk_32 = self.chunk_manager.get_chunk(fp32_p)
+ chunk_32.init_pair(chunk_16)
+
+ # keep gathered chunks are in CUDA
+ if chunk_16.keep_gathered:
+ self.grads_device[p] = get_current_device()
+
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py
index cd5ef424a1d9..2c6e15d91736 100644
--- a/colossalai/nn/parallel/gemini_parallel.py
+++ b/colossalai/nn/parallel/gemini_parallel.py
@@ -17,9 +17,10 @@ def __init__(self,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
+ strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
- min_chunk_size_mb: Optional[float] = None,
+ min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
@@ -48,10 +49,15 @@ def __init__(self,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
+ # some ugly hotfix for the compatibility with Lightning
+ if search_range_mb is None:
+ search_range_mb = 32
+
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
- min_chunk_size_mb=min_chunk_size_mb)
+ min_chunk_size_mb=min_chunk_size_mb,
+ strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
- super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32)
+ super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
diff --git a/colossalai/nn/parallel/utils.py b/colossalai/nn/parallel/utils.py
index 1205cbc3a658..08fdb6026e38 100644
--- a/colossalai/nn/parallel/utils.py
+++ b/colossalai/nn/parallel/utils.py
@@ -47,30 +47,29 @@ def _get_shallow_copy_model(model: nn.Module):
"""Get a shallow copy of the given model. Each submodule is different from the original submodule.
But the new submodule and the old submodule share all attributes.
"""
- name_to_module = dict()
+ old_to_new = dict()
for name, module in _get_dfs_module_list(model):
new_module = copy(module)
new_module._modules = OrderedDict()
for subname, submodule in module._modules.items():
if submodule is None:
continue
- full_name = name + ('.' if name else '') + subname
- setattr(new_module, subname, name_to_module[full_name])
- name_to_module[name] = new_module
- return name_to_module['']
+ setattr(new_module, subname, old_to_new[submodule])
+ old_to_new[module] = new_module
+ return old_to_new[model]
-def get_static_torch_model(gemini_ddp_model,
+def get_static_torch_model(zero_ddp_model,
device=torch.device("cpu"),
dtype=torch.float32,
only_rank_0=True) -> torch.nn.Module:
- """Get a static torch.nn.Module model from the given GeminiDDP module.
- You should notice that the original GeminiDDP model is not modified.
+ """Get a static torch.nn.Module model from the given ZeroDDP module.
+ You should notice that the original ZeroDDP model is not modified.
Thus, you can use the original model in further training.
But you should not use the returned torch model to train, this can cause unexpected errors.
Args:
- gemini_ddp_model (GeminiDDP): a gemini ddp model
+ zero_ddp_model (ZeroDDP): a zero ddp model
device (torch.device): the device of the final torch model
dtype (torch.dtype): the dtype of the final torch model
only_rank_0 (bool): if True, only rank0 has the coverted torch model
@@ -78,16 +77,14 @@ def get_static_torch_model(gemini_ddp_model,
Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
"""
- from colossalai.nn.parallel import GeminiDDP
- assert isinstance(gemini_ddp_model, GeminiDDP)
+ from colossalai.nn.parallel import ZeroDDP
+ assert isinstance(zero_ddp_model, ZeroDDP)
- state_dict = gemini_ddp_model.state_dict(only_rank_0=only_rank_0)
- colo_model = gemini_ddp_model.module
+ state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)
+ colo_model = zero_ddp_model.module
torch_model = _get_shallow_copy_model(colo_model)
if not only_rank_0 or dist.get_rank() == 0:
- # record the mapping relationship between colo parameters and torch parameters
- colo_to_torch = dict()
for (name, colo_module), (_, torch_module) in \
zip(_get_dfs_module_list(colo_model), _get_dfs_module_list(torch_model)):
# clean the parameter list of the new torch module
@@ -95,17 +92,10 @@ def get_static_torch_model(gemini_ddp_model,
for sufix_param_name, param in colo_module.named_parameters(recurse=False):
# get the full name of the parameter
full_param_name = name + ('.' if name else '') + sufix_param_name
-
- if full_param_name not in state_dict:
- # this means the parameter is shared by multiple modules
- # we should use colo_to_torch to get the torch parameter created before
- assert param in colo_to_torch, f"can not find parameter `{full_param_name}` in the GeminiDDP module"
- torch_param = colo_to_torch[param]
- else:
- # we meet the parameter the first time, just use the state dict to get the data
- state_param = state_dict[full_param_name]
- torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
- colo_to_torch[param] = torch_param
+ assert full_param_name in state_dict, \
+ f"Can not find parameter `{full_param_name}` in the GeminiDDP module"
+ state_param = state_dict[full_param_name]
+ torch_param = torch.nn.Parameter(state_param.data.to(device=device, dtype=dtype))
setattr(torch_module, sufix_param_name, torch_param)
dist.barrier()
diff --git a/colossalai/nn/parallel/zero_wrapper.py b/colossalai/nn/parallel/zero_wrapper.py
new file mode 100644
index 000000000000..be8d1da7c24e
--- /dev/null
+++ b/colossalai/nn/parallel/zero_wrapper.py
@@ -0,0 +1,109 @@
+from copy import copy
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+
+from .gemini_parallel import GeminiDDP
+
+
+def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
+ """This wrapper function is used to wrap your training model for ZeRO DDP.
+
+ Example:
+
+ >>> with ColoInitContext():
+ >>> my_model = Bert()
+ >>> my_optim = SGD(my_model.parameters(), lr = 1e-3)
+ >>> zero_model = zero_model_wrapper(my_model, zero_stage=1)
+ >>> zero_optim = zero_optim_wrapper(zero_model, my_optim)
+
+ Args:
+ model (nn.Module): The model used in ZeRO DDP.
+ zero_stage (int, optional): The stage of ZeRO DDP. You can find more information in ZeRO's paper.
+ https://arxiv.org/abs/1910.02054
+ gemini_config (dict, optional): The configuration dictionary of `GeminiDDP`. `GeminiDDP` is enabled
+ when the stage is set to 3. You can set the arguemnts of `GeminiDDP` in the gemini_config.
+ Here is an example where we set the device of the model, the placement policy of Gemini, and the
+ size of hidden dimension to help Gemini find out a unified chunk size.
+
+ Example:
+
+ >>> config_dict = dict(device=torch.cuda.current_device(), hidden_dim=1024, placement_policy='auto')
+ >>> model = zero_model_wrapper(model, zero_stage=3, gemini_config=config_dict)
+ """
+ assert zero_stage in [1, 2, 3], "The stage of ZeRO should be 1, 2 or 3"
+
+ if gemini_config is None:
+ gemini_config = dict()
+
+ if zero_stage in [1, 2]:
+ wrapped_model = model
+ else:
+ wrapped_model = GeminiDDP(model, **gemini_config)
+
+ setattr(wrapped_model, "_colo_zero_stage", zero_stage)
+
+ return wrapped_model
+
+
+def zero_optim_wrapper(model: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ initial_scale: float = 2**16,
+ growth_factor: float = 2,
+ backoff_factor: float = 0.5,
+ growth_interval: int = 1000,
+ hysteresis: int = 2,
+ min_scale: float = 1,
+ max_scale: float = 2**32,
+ max_norm: float = 0.0,
+ norm_type: float = 2.0,
+ optim_config: Optional[Dict] = None):
+ """This wrapper function is used to wrap your training optimizer for ZeRO DDP.
+
+ Args:
+ model (nn.Module): Your model wrapped by `zero_model_wrapper`
+ optimizer (torch.optim.Optimizer): Your initialized optimizer
+ initial_scale (float, optional): initial_scale used by DynamicGradScaler.
+ min_scale (float, optional): min_scale used by DynamicGradScaler.
+ growth_factor (float, optional): growth_factor used by DynamicGradScaler.
+ backoff_factor (float, optional): backoff_factor used by DynamicGradScaler.
+ growth_interval (float, optional): growth_interval used by DynamicGradScaler.
+ hysteresis (float, optional): hysteresis used by DynamicGradScaler.
+ max_scale (int, optional): max_scale used by DynamicGradScaler.
+ max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
+ clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
+ norm_type (float, optional): norm_type used for `clip_grad_norm`.
+ optim_config (dict, optinoal): The configuration used for the ZeRO optimizer.
+ Example:
+
+ >>> zero2_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True)
+ >>> optim = zero_optim_wrapper(model, optim, optim_config=zero2_config)
+ """
+ assert hasattr(model, "_colo_zero_stage"), "You should use `zero_ddp_wrapper` first"
+ zero_stage = getattr(model, "_colo_zero_stage")
+
+ assert norm_type == 2.0, "Current ZeRO optimizers only support 'norm_type=2'"
+
+ if optim_config is None:
+ config_dict = dict()
+ else:
+ config_dict = copy(optim_config)
+
+ config_dict['initial_scale'] = initial_scale
+ config_dict['growth_factor'] = growth_factor
+ config_dict['backoff_factor'] = backoff_factor
+ config_dict['growth_interval'] = growth_interval
+ config_dict['hysteresis'] = hysteresis
+ config_dict['min_scale'] = min_scale
+ config_dict['max_scale'] = max_scale
+
+ if zero_stage in [1, 2]:
+ from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer
+ config_dict['partition_grad'] = zero_stage == 2
+ config_dict['clip_grad_norm'] = max_norm
+ return LowLevelZeroOptimizer(optimizer, **config_dict)
+ else:
+ from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+ config_dict['clipping_norm'] = max_norm
+ return ZeroOptimizer(optimizer, model, **config_dict)
diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py
index 4739cdaa9bd3..2d7e25c82e7b 100644
--- a/colossalai/pipeline/rpc/_pipeline_base.py
+++ b/colossalai/pipeline/rpc/_pipeline_base.py
@@ -211,7 +211,7 @@ def _get_output_all(self, key: UniqueKey, ref_use=False, rank=None):
refcount = 0
with self.output_list_condition_lock:
- if refcount < lifecycle:
+ if refcount <= lifecycle:
self.output_list[key] = output_work_item
self.output_list_condition_lock.notify_all()
@@ -390,7 +390,7 @@ def _subscribe_producer(self, microbatch_id: int, forward_only: bool):
subscribe_forward_futures[target_index] = []
else:
subscribe_forward_futures[target_index] = producer_worker_rref.rpc_async().get_output_by_key(
- producer_output_key, rank=self.pp_rank)
+ producer_output_key, rank=self.pp_rank, offsets=offsets)
else:
for i in range(producer_num):
@@ -1115,7 +1115,8 @@ def _init_worker(self) -> None:
# let each worker know global worker rref (include itself)
sync_futs = []
for pp_rank in self.pp_rank_to_worker_rref:
- fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async().sync_global_worker_rrefs(self.pp_rank_to_worker_rref)
+ fut = self.pp_rank_to_worker_rref[pp_rank].rpc_async(timeout=0).sync_global_worker_rrefs(
+ self.pp_rank_to_worker_rref)
sync_futs.append(fut)
for fut in sync_futs:
diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py
index e6aa961f19bc..0d572231d378 100644
--- a/colossalai/pipeline/rpc/_pipeline_schedule.py
+++ b/colossalai/pipeline/rpc/_pipeline_schedule.py
@@ -29,9 +29,6 @@ def _get_work_item_key(self) -> UniqueKey:
target_key = UniqueKey(target_microbatch_id, target_phase)
- with self.work_list_condition_lock:
- self.work_list_condition_lock.wait_for(lambda: target_key in self.work_list)
-
return target_key
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 92220d9e2a38..b384579feb35 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -71,7 +71,7 @@ def from_torch_tensor(tensor: torch.Tensor,
return tensor
def __repr__(self):
- return f'ColoParameter: {ColoTensor.__repr__(self)}'
+ return super(ColoParameter, self).__repr__()
@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 3712d6a0acea..bbed8847abbc 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,5 +1,6 @@
+import operator
from copy import copy
-from functools import lru_cache
+from functools import lru_cache, reduce
from typing import Callable, Optional, Set
import torch
@@ -188,7 +189,12 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
return _convert_output(ret, colo_spec)
def __repr__(self):
- return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}\n{self.compute_spec}'
+ output_list = [super(ColoTensor, self).__repr__()]
+ output_list.append(str(self.process_group))
+ output_list.append(str(self.dist_spec))
+ if self.compute_spec is not None:
+ output_list.append(str(self.compute_spec))
+ return "\n".join(output_list)
def _redistribute(self, dist_spec: _DistSpec) -> None:
"""_redistribute
@@ -303,6 +309,11 @@ def size_global(self, *args) -> torch.Size:
else:
return size_list[args[0]]
+ def numel_global(self):
+ """Returns the number of elements in the tensor when it's replicated.
+ """
+ return reduce(operator.mul, self.size_global(), 1)
+
# Some API for dist spec check
def is_replicate(self):
diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py
index 3c9e0fd56696..0d8de1062d42 100644
--- a/colossalai/tensor/comm_spec.py
+++ b/colossalai/tensor/comm_spec.py
@@ -429,6 +429,7 @@ def __repr__(self):
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
+ res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
@@ -437,6 +438,7 @@ def __repr__(self):
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
+ res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
@@ -463,7 +465,7 @@ def get_comm_cost(self):
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
# give a tiny cost to shard
- backward_communication_cost = 10
+ backward_communication_cost = 100
if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
@@ -481,13 +483,13 @@ def get_comm_cost(self):
if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
# give a tiny cost to shard
- forward_communication_cost = 10
+ forward_communication_cost = 100
backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD:
# no need for axis because all devices are used in mix_gather
forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size)
- backward_communication_cost = 10
+ backward_communication_cost = 100
if self.forward_only:
cost_dict["forward"] = forward_communication_cost
diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py
index a9774c34c01b..73328285ee93 100644
--- a/colossalai/tensor/compute_spec.py
+++ b/colossalai/tensor/compute_spec.py
@@ -9,9 +9,9 @@ class ComputePattern(Enum):
class ComputeSpec(object):
- """ComputeSpec
+ """ComputeSpec
The Specification for compuattion pattern
-
+
Args:
compute_pattern (ComputePattern): an Enum instance for compute pattern.
"""
@@ -23,7 +23,7 @@ def __init__(self, compute_pattern: ComputePattern) -> None:
self.output_replicate = True
def __repr__(self):
- return f'Compute pattern: {self.compute_pattern}'
+ return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py b/colossalai/tensor/d_tensor/__init__.py
similarity index 100%
rename from examples/tutorial/stable_diffusion/ldm/modules/distributions/__init__.py
rename to colossalai/tensor/d_tensor/__init__.py
diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py
new file mode 100644
index 000000000000..765d8ec1b01a
--- /dev/null
+++ b/colossalai/tensor/d_tensor/comm_spec.py
@@ -0,0 +1,310 @@
+from enum import Enum
+from typing import Dict
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ReduceOp
+
+__all__ = [
+ 'CollectiveCommPattern',
+ 'CommSpec',
+]
+
+
+class CollectiveCommPattern(Enum):
+ GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
+ ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
+ SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
+ ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
+ IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
+ MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
+
+
+class CommSpec:
+ '''
+ Communication spec is used to record the communication action. It converts the communication spec
+ to real action which will be used in runtime. It contains comm_pattern to determine the
+ communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
+ to determine the buffer shape, and logical_process_axis
+
+ Argument:
+ comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
+ process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
+ gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
+ shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
+ logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
+ '''
+
+ def __init__(self,
+ comm_pattern: CollectiveCommPattern,
+ process_groups_dict: Dict,
+ gather_dim: int = None,
+ shard_dim: int = None,
+ logical_process_axis: int = None):
+ self.comm_pattern = comm_pattern
+ self.gather_dim = gather_dim
+ self.shard_dim = shard_dim
+ self.logical_process_axis = logical_process_axis
+ self.process_groups_dict = process_groups_dict
+
+ def __repr__(self):
+ res_list = ["CommSpec:("]
+ if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
+ res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ")
+ res_list.append(f"gather_dim:{self.gather_dim}, ")
+ res_list.append(f"shard_dim:{self.gather_dim}, ")
+ res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+ elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
+ res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ")
+ res_list.append(f"gather_dim:{self.gather_dim}, ")
+ res_list.append(f"shard_dim:{self.shard_dim}, ")
+ res_list.append(f"logical_process_axis: {self.logical_process_axis})")
+ elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
+ res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ")
+ res_list.append(f"gather_dim:{self.gather_dim}, ")
+ res_list.append(f"shard_dim:{self.shard_dim}, ")
+ res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+ elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
+ res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ")
+ res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+ elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
+ res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
+ res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+
+ return ''.join(res_list)
+
+ def covert_spec_to_action(self, tensor):
+ '''
+ Convert CommSpec into runtime action, implement real collection communication to target tensor.
+ The collection communication action is directed by the CommSpec.
+
+ Argument:
+ tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
+ '''
+ if self.comm_pattern in pattern_to_func_dict:
+ tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
+ else:
+ tensor = tensor
+ return tensor
+
+
+def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
+ '''
+ Implement all gather operation on device mesh based on information provided by comm_spec.
+ '''
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ tensor_list = [
+ torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+ ]
+ # without this contiguous operation, the all gather may get some unexpected results.
+ tensor = tensor.contiguous()
+ dist.all_gather(tensor_list, tensor, group=process_group)
+ output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+ return output
+
+
+def _split(tensor: torch.Tensor, comm_spec: CommSpec):
+ '''
+ Implement shard operation on device mesh based on information provided by comm_spec.
+ '''
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, _ in process_groups_list:
+ if dist.get_rank() in rank_list:
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ start = length * rank_list.index(dist.get_rank())
+ output = torch.narrow(tensor, dim, start, length).contiguous()
+ return output
+
+
+def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
+ '''
+ Implement all to all operation on device mesh based on information provided by comm_spec.
+ '''
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ new_shape = list(tensor.shape)
+ new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
+ new_shape = torch.Size(new_shape)
+ output_tensor_list = [
+ torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
+ ]
+ dim = comm_spec.shard_dim
+ length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
+ input_tensor_list = [
+ torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
+ ]
+ group = process_group
+ dist.all_to_all(output_tensor_list, input_tensor_list, group)
+ output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+ return output
+
+
+def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
+ '''
+ Implement all reduce operation on device mesh based on information provided by comm_spec.
+ '''
+ process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
+ for rank_list, process_group in process_groups_list:
+ if dist.get_rank() in rank_list:
+ if not tensor.is_contiguous():
+ tensor = tensor.contiguous()
+ dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+ return tensor
+
+
+class _ReduceGrad(torch.autograd.Function):
+ """
+ A customized communication operation which forward is an identity operation,
+ backward is all_reduce operation.
+
+ Args:
+ input_: input matrix.
+ comm_spec: comm_spec will give information like process group, rank list, etc.
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return input_
+
+ @staticmethod
+ def forward(ctx, input_, comm_spec):
+ ctx.comm_spec = comm_spec
+ return input_
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _all_reduce(grad_output, ctx.comm_spec), None
+
+
+class _ReduceInput(torch.autograd.Function):
+ """
+ A customized communication operation which forward is all_reduce operation,
+ backward is an identity operation.
+
+ Args:
+ input_: input matrix.
+ comm_spec: comm_spec will give information like process group, rank list, etc.
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _all_reduce(input_)
+
+ @staticmethod
+ def forward(ctx, input_, comm_spec):
+ return _all_reduce(input_, comm_spec)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output, None
+
+
+class _SplitForwardGatherBackward(torch.autograd.Function):
+ """
+ A customized communication operation which forward is split operation,
+ backward is an all gather operation.
+
+ Args:
+ input_: input matrix.
+ comm_spec: comm_spec will give information like process group, rank list, etc.
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _split(input_)
+
+ @staticmethod
+ def forward(ctx, input_, comm_spec):
+ ctx.comm_spec = comm_spec
+ return _split(input_, comm_spec)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _all_gather(grad_output, ctx.comm_spec), None
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+ """
+ A customized communication operation which forward is an all gather operation,
+ backward is split operation.
+
+ Args:
+ input_: input matrix.
+ comm_spec: comm_spec will give information like process group, rank list, etc.
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _all_gather(input_)
+
+ @staticmethod
+ def forward(ctx, input_, comm_spec):
+ ctx.comm_spec = comm_spec
+ return _all_gather(input_, comm_spec)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return _split(grad_output, ctx.comm_spec), None
+
+
+class _AllToAll(torch.autograd.Function):
+ """
+ A customized communication operation which forward is an all to all operation,
+ backward is an all to all operation.
+
+ Args:
+ input_: input matrix.
+ comm_spec: comm_spec will give information like process group, rank list, etc.
+ """
+
+ @staticmethod
+ def symbolic(graph, input_):
+ return _all_to_all(input_)
+
+ @staticmethod
+ def forward(ctx, input_, comm_spec):
+ output = _all_to_all(input_, comm_spec)
+ comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
+ process_groups_dict=comm_spec.process_groups_dict,
+ gather_dim=comm_spec.shard_dim,
+ shard_dim=comm_spec.gather_dim,
+ logical_process_axis=comm_spec.logical_process_axis)
+ ctx.comm_spec = comm_spec_for_backward
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ return _all_to_all(grad_outputs, ctx.comm_spec), None
+
+
+def reduce_grad(input_, comm_spec):
+ return _ReduceGrad.apply(input_, comm_spec)
+
+
+def reduce_input(input_, comm_spec):
+ return _ReduceInput.apply(input_, comm_spec)
+
+
+def split_forward_gather_backward(input_, comm_spec):
+ return _SplitForwardGatherBackward.apply(input_, comm_spec)
+
+
+def gather_forward_split_backward(input_, comm_spec):
+ return _GatherForwardSplitBackward.apply(input_, comm_spec)
+
+
+def all_to_all(input_, comm_spec):
+ return _AllToAll.apply(input_, comm_spec)
+
+
+pattern_to_func_dict = {
+ CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward,
+ CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all,
+ CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward,
+ CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input,
+ CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad,
+}
diff --git a/colossalai/tensor/d_tensor/d_tensor.py b/colossalai/tensor/d_tensor/d_tensor.py
new file mode 100644
index 000000000000..c1fe9d50a048
--- /dev/null
+++ b/colossalai/tensor/d_tensor/d_tensor.py
@@ -0,0 +1,142 @@
+from typing import Optional
+
+import torch
+from torch.utils._pytree import tree_map
+
+from .layout import Layout
+from .layout_converter import LayoutConverter, to_global
+from .sharding_spec import ShardingSpec
+
+layout_converter = LayoutConverter()
+
+
+class DTensor(torch.Tensor):
+
+ def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
+ self.local_tensor = local_tensor
+ self.data_type = local_tensor.dtype
+ self.entire_shape = local_tensor.shape
+ self.dist_layout = dist_layout
+ self._apply_layout()
+
+ @staticmethod
+ def __new__(cls, local_tensor, layout):
+ return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
+
+ def __repr__(self):
+ return f"DTensor({self.to_global()}, {self.dist_layout})"
+
+ def __str__(self):
+ return self.__repr__()
+
+ def layout_convert(self, target_layout):
+ '''
+ Convert the layout of the tensor from source_spec to target_spec.
+ '''
+ self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
+ self.dist_layout = target_layout
+
+ def _apply_layout(self):
+ '''
+ Apply the layout to the local tensor during initializing process.
+ '''
+ source_spec = construct_default_sharding_spec(self.local_tensor)
+ source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
+ device_type=self.dist_layout.device_type,
+ sharding_spec=source_spec,
+ entire_shape=self.entire_shape)
+ self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ def filter_arg(arg):
+ if isinstance(arg, DTensor):
+ return arg.local_tensor
+ else:
+ return arg
+
+ args = tree_map(filter_arg, args)
+ kwargs = tree_map(filter_arg, kwargs)
+ # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
+ # and op type.
+
+ return func(*args, **kwargs)
+
+ @property
+ def device_mesh(self):
+ '''
+ Return the device mesh of the tensor.
+ '''
+ return self.dist_layout.device_mesh
+
+ @property
+ def sharding_spec(self):
+ '''
+ Return the sharding specification of the tensor.
+ '''
+ return self.dist_layout.sharding_spec
+
+ def to(self, *args, **kwargs):
+ '''
+ Move the tensor to a new device or convert the tensor to a new dtype.
+ '''
+ self.local_tensor = self.local_tensor.to(*args, **kwargs)
+ self.data_type = self.local_tensor.dtype
+ self.dist_layout.device_type = self.local_tensor.device
+ # TODO: update the device mesh process groups or we should just cache
+ # both the cpu process groups and the cuda process groups?
+ return self
+
+ def to_local(self):
+ '''
+ Return the local tensor in this rank.
+ '''
+ return self.local_tensor
+
+ def to_global(self):
+ '''
+ Recover the global tensor from the distributed tensor.
+
+ Note: This function will all_gather the local tensor to the global tensor and it
+ will not change the layout of the DTensor. This function is mainly used for debugging or
+ check the correctness of the distributed tensor.
+ '''
+ return to_global(self.local_tensor, self.dist_layout)
+
+
+def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
+ '''
+ Distribute the local tensor to the distributed tensor according to the dist_layout specified.
+
+ Args:
+ local_tensor: tensor to be distributed.
+ dist_layout: the layout specification of the distributed tensor.
+
+ Returns:
+ A 'DTensor' object.
+ '''
+ return DTensor(local_tensor, dist_layout)
+
+
+def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
+ '''
+ This function converts all the parameters in the module to DTensor(DParam).
+
+ Note: This function is subject to future change as the DParam has not been implemented yet.
+ '''
+ for name, param in module.named_parameters():
+ if param is not None and not isinstance(param, DTensor):
+ # TODO: we could convert the parameter to DParam here,
+ # the type of the parameter could be an optional argument.
+ setattr(module, name, torch.nn.Parameter(partition_fn(name, param.data)))
+ return module
+
+
+def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
+ '''
+ Construct the default sharding specification for the tensor.
+ '''
+ return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py
new file mode 100644
index 000000000000..72a2694a1eaf
--- /dev/null
+++ b/colossalai/tensor/d_tensor/layout.py
@@ -0,0 +1,68 @@
+import operator
+from dataclasses import dataclass
+from functools import reduce
+
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+
+from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
+from .sharding_spec import ShardingSpec
+
+
+class Layout:
+ """Layout of a tensor.
+
+ Attributes:
+ device_mesh: the device mesh to store the tensor distributedly.
+ device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
+ sharding_spec: the sharding specification to describe how the tensor is sharded.
+ entire_shape: the entire shape of the global tensor.
+ """
+
+ def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
+ entire_shape: torch.Size):
+ self.device_mesh = device_mesh
+ self.device_type = device_type
+ self.sharding_spec = sharding_spec
+ self.entire_shape = entire_shape
+ self._sanity_check()
+
+ def __hash__(self) -> int:
+ return hash(f'{self.sharding_spec}')
+
+ def get_sharded_shape_per_device(self):
+ sharded_shape = list(self.entire_shape)
+ for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
+ mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+ shard_partitions = reduce(operator.mul, mesh_list, 1)
+ assert sharded_shape[
+ dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
+ sharded_shape[dim] //= shard_partitions
+ return torch.Size(sharded_shape)
+
+ def _sanity_check(self):
+ sharding_spec = self.sharding_spec
+
+ # make sure all axes in logical device mesh only be used once
+ dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
+ for dim, shard_list in sharding_spec.dim_partition_dict.items():
+ for element in shard_list:
+ if element in dim_check_list:
+ dim_check_list.remove(element)
+ else:
+ raise DuplicatedShardingDimensionError(
+ f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
+
+ # make sure that the sharding for a dimension is divisible by the number of devices
+ for dim, shard_list in sharding_spec.dim_partition_dict.items():
+ tensor_dim_size = self.entire_shape[dim]
+ num_devices = 1
+
+ for element in shard_list:
+ num_devices *= self.device_mesh.mesh_shape[element]
+
+ if tensor_dim_size % num_devices != 0:
+ raise ShardingNotDivisibleError(
+ f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
+ )
diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py
new file mode 100644
index 000000000000..cf02aac309f4
--- /dev/null
+++ b/colossalai/tensor/d_tensor/layout_converter.py
@@ -0,0 +1,556 @@
+import math
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.tensor.d_tensor.comm_spec import *
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+
+from .sharding_spec import ShardingSpec
+from .utils import get_comm_cost
+
+__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options']
+
+
+@dataclass
+class LayoutConverterOptions:
+ """
+ LayoutConverterOptions is a dataclass which specifies the preferences for layout converting.
+ """
+ # TODO: layout converter option is not implemented yet
+ pass
+
+
+def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
+ layout_converter = LayoutConverter()
+ global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
+ global_layout = Layout(device_mesh=layout.device_mesh,
+ device_type=layout.device_type,
+ sharding_spec=global_sharding_spec,
+ entire_shape=layout.entire_shape)
+ with torch.no_grad():
+ global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
+ return global_tensor
+
+
+def set_layout_converting_options(options: LayoutConverterOptions):
+ """
+ Configure the shape consistency manager via function call.
+ """
+ manager = LayoutConverter()
+ manager.options = options
+
+
+class LayoutConverter(metaclass=SingletonMeta):
+
+ def __init__(self):
+ self._options = None
+ self._forward_only = False
+ self.cached_solution = {}
+
+ @property
+ def options(self):
+ return self._options
+
+ @options.setter
+ def options(self, options_: LayoutConverterOptions):
+ assert isinstance(options_, LayoutConverterOptions)
+ self._options = options_
+
+ @property
+ def forward_only(self):
+ return self._forward_only
+
+ @forward_only.setter
+ def forward_only(self, value):
+ assert isinstance(value, bool)
+ self._forward_only = value
+
+ def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
+ '''
+ Get all valid layouts from source_layout with single all-gather operation.
+ For the all-gather operation, we just care about the S dimension.
+
+ Argument:
+ source_layout: the layout to be transformed.
+
+ Return:
+ valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-gather operation.
+
+ Example:
+ layout_converter = LayoutConverter()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ entire_shape = (4, 4, 4)
+ dim_partition_dict = {0: [0], 1: [1]}
+
+ # [S0,S1,R]
+ sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec,
+ entire_shape=entire_shape)
+
+ rst_dict = layout_converter.all_gather_transform_layouts(layout)
+ for layout, comm_spec in rst_dict.items():
+ print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
+
+ Output:
+ [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0)
+ [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
+ '''
+ valid_spec_dict = {}
+ comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
+ source_spec = source_layout.sharding_spec
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
+ for target_pair in source_spec.dim_partition_dict.items():
+ shard_list = all_gather_simulator(target_pair)
+ index = target_pair[0]
+ new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+
+ # We won't add empty list into dim_partition_dict
+ # The key will be popped if the related shard_list is empty
+ if shard_list:
+ new_dim_partition_dict[index] = shard_list
+ else:
+ new_dim_partition_dict.pop(index)
+
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
+ gather_dim = index
+ logical_process_axis = target_pair[1][-1]
+ comm_spec = CommSpec(
+ comm_pattern,
+ process_groups_dict=process_groups_dict,
+ gather_dim=gather_dim,
+ # shard_dim will be used during backward
+ shard_dim=gather_dim,
+ logical_process_axis=logical_process_axis)
+
+ # generate new sharding spec
+ try:
+ new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
+ new_layout = Layout(device_mesh=source_layout.device_mesh,
+ sharding_spec=new_sharding_spec,
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
+
+ valid_spec_dict[new_layout] = comm_spec
+ except LayoutException:
+ pass
+ return valid_spec_dict
+
+ def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
+ '''
+ Get all valid layouts from source_layout with single all-to-all operation.
+ For the all-to-all operation, we just care about the pairs containing S dimension.
+
+ Argument:
+ source_layout(Layout): the layout to be transformed.
+
+ Return:
+ valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-to-all operation.
+
+ Example:
+ layout_converter = LayoutConverter()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ entire_shape = (4, 4, 4)
+ dim_partition_dict = {0: [0], 1: [1]}
+
+ # [S0,S1,R]
+ sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec,
+ entire_shape=entire_shape)
+ rst_dict = layout_converter.all_to_all_transform_layout(layout)
+
+ for layout, comm_spec in rst_dict.items():
+ print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
+
+ Output:
+ [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1)
+ [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0)
+ [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1)
+ '''
+ valid_spec_dict = {}
+ comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
+ source_spec = source_layout.sharding_spec
+ tensor_dims = source_spec.dims
+ for f_index in range(tensor_dims - 1):
+ for b_index in range(f_index + 1, tensor_dims):
+ # skip (R, R) cases
+ if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
+ continue
+ else:
+ if f_index in source_spec.dim_partition_dict:
+ # skip (S01, R) -> (R, S01) is NOT allowed
+ if len(source_spec.dim_partition_dict[f_index]) >= 2:
+ continue
+ f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
+ else:
+ f_target_pair = (f_index, [])
+ if b_index in source_spec.dim_partition_dict:
+ # skip (R, S01) -> (S01, R) is NOT allowed
+ if len(source_spec.dim_partition_dict[b_index]) >= 2:
+ continue
+ b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
+ else:
+ b_target_pair = (b_index, [])
+
+ # skip (S1, S0) -> S10
+ if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]:
+ continue
+ f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair)
+ f_index = f_target_pair[0]
+ b_index = b_target_pair[0]
+
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
+ if len(f_shard_list) < len(f_target_pair[1]):
+ gather_dim = f_index
+ shard_dim = b_index
+ logical_process_axis = f_target_pair[1][-1]
+ else:
+ gather_dim = b_index
+ shard_dim = f_index
+ logical_process_axis = b_target_pair[1][-1]
+ comm_spec = CommSpec(comm_pattern,
+ process_groups_dict,
+ gather_dim=gather_dim,
+ shard_dim=shard_dim,
+ logical_process_axis=logical_process_axis)
+
+ new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+
+ # We won't add empty list into dim_partition_dict
+ # The key will be popped if the related shard_list is empty
+ if f_shard_list:
+ new_dim_partition_dict[f_index] = f_shard_list
+ else:
+ new_dim_partition_dict.pop(f_index)
+ if b_shard_list:
+ new_dim_partition_dict[b_index] = b_shard_list
+ else:
+ new_dim_partition_dict.pop(b_index)
+
+ # generate new sharding spec
+ try:
+ new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
+ new_layout = Layout(device_mesh=source_layout.device_mesh,
+ sharding_spec=new_sharding_spec,
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
+ valid_spec_dict[new_layout] = comm_spec
+ except LayoutException:
+ pass
+
+ return valid_spec_dict
+
+ def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
+ '''
+ Get all valid layouts from source_layout with single shard operation.
+ For the sharding operation, we just care about legal sharding dimensions.
+
+ Argument:
+ source_layout(Layout): the layout to be transformed.
+
+ Return:
+ valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single shard operation.
+
+ Example:
+ layout_converter = LayoutConverter()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ entire_shape = (4, 4, 4)
+
+ dim_partition_dict = {0: [0]}
+
+ # [S0,R,R]
+ sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec,
+ entire_shape=entire_shape)
+ rst_dict = layout_converter.shard_transform_layout(layout)
+
+ for layout, comm_spec in rst_dict.items():
+ print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}')
+
+ Output:
+ [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1)
+ [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1)
+ [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1)
+ '''
+ valid_spec_dict = {}
+ comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
+ source_spec = source_layout.sharding_spec
+ process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+ # legal sharding dims means the mesh_id is still available to use.
+ legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+ for dim, shard_list in source_spec.dim_partition_dict.items():
+ for element in shard_list:
+ legal_sharding_dims.remove(element)
+
+ if len(legal_sharding_dims) == 0:
+ return valid_spec_dict
+
+ tensor_dims = source_spec.dims
+
+ for index in range(tensor_dims):
+ if index not in source_spec.dim_partition_dict:
+ shard_list_list = shard_simulator((index, []), legal_sharding_dims)
+ else:
+ shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims)
+ if not shard_list_list:
+ continue
+ for shard_list in shard_list_list:
+ new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+ new_dim_partition_dict[index] = shard_list
+
+ # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
+ shard_dim = index
+ logical_process_axis = shard_list[-1]
+ comm_spec = CommSpec(comm_pattern,
+ process_groups_dict,
+ gather_dim=shard_dim,
+ shard_dim=shard_dim,
+ logical_process_axis=logical_process_axis)
+
+ # generate new sharding spec
+ try:
+ new_sharding_spec = ShardingSpec(dim_size=source_spec.dims,
+ dim_partition_dict=new_dim_partition_dict)
+ new_layout = Layout(device_mesh=source_layout.device_mesh,
+ sharding_spec=new_sharding_spec,
+ device_type=source_layout.device_type,
+ entire_shape=source_layout.entire_shape)
+ valid_spec_dict[new_layout] = comm_spec
+ except LayoutException:
+ pass
+ return valid_spec_dict
+
+ def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
+ '''
+ Get all valid layouts from source_layout with one step transform.
+
+ Note:
+ all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before,
+ and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive,
+ we could safely put them together.
+
+ Argument:
+ source_layout(Layout): the layout to be transformer.
+
+ Return:
+ valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform.
+ '''
+ valid_spec_dict = {}
+ valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
+ valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
+ valid_spec_dict.update(self.shard_transform_layout(source_layout))
+ return valid_spec_dict
+
+ def layout_converting(self, source_layout: Layout,
+ target_layout: Layout) -> Tuple[List[Layout], List[CommSpec], float]:
+ '''
+ This method will find a path to transform source_layout to target_layout with
+ a greedy algorithm.
+ The basic idea is:
+ Step1:
+ Generate all one-step transform sequences from source_layout.
+ Step2:
+ Pick the 'best' layout following the heuristic function.
+ Step3:
+ Repeat above steps until the source layout transform to target layout.
+
+ Additionally, to avoid repeating the path search in runtime, we cached all solved path
+ in auto parallel strategy building time, which could handle most of cases in runtime.
+
+ Args:
+ source_layout(Layout): the layout to be transformed.
+ target_layout(Layout): the layout to be achieved after a serious of transforms.
+
+ Return:
+ transform_path(List[Layout]): The transform path from source_layout to target_layout,
+ it contains the source_layout and target_layout.
+ comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the layout converting in order.
+
+ Example:
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ entire_shape = (4, 4, 4)
+
+ dim_partition_source = {1: [0, 1]}
+ dim_partition_target = {0: [0, 1]}
+
+ # [R,S01,R]
+ sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
+
+ # [S01,R,R]
+ sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
+
+ transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
+ transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
+ print(transform_path_str)
+
+ output:
+ [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]
+ '''
+ source_spec = source_layout.sharding_spec
+ target_spec = target_layout.sharding_spec
+ MAX_TRANSFORM_STEPS = 20
+ total_steps = 0
+ transform_path = []
+ comm_action_sequence = []
+ spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
+
+ if spec_pairs in self.cached_solution:
+ return self.cached_solution[spec_pairs]
+
+ # We do nothing if the sharding spec is all the same.
+ if source_spec.spec_diff(target_spec) == 0:
+ self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
+ return (
+ transform_path,
+ comm_action_sequence,
+ )
+
+ temp_sharding_layout = source_layout
+
+ transform_path.append(temp_sharding_layout)
+ # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
+ while total_steps <= MAX_TRANSFORM_STEPS:
+ valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_layout)
+ best_difference_score = math.inf
+
+ for layout, comm_spec in valid_transform_spec_dict.items():
+ sharding_spec = layout.sharding_spec
+ spec_difference = sharding_spec.spec_diff(target_spec)
+
+ if spec_difference == 0:
+ transform_path.append(layout)
+ comm_action_sequence.append(comm_spec)
+ self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence)
+ return (transform_path, comm_action_sequence)
+
+ if spec_difference < best_difference_score:
+ temp_sharding_layout = layout
+ temp_comm_spec = comm_spec
+ best_difference_score = spec_difference
+
+ transform_path.append(temp_sharding_layout)
+ comm_action_sequence.append(temp_comm_spec)
+
+ total_steps += 1
+
+ raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
+
+ def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]:
+ '''
+ Get the total communication cost of the layout converting process.
+ '''
+ transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout)
+ total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
+ for layout, comm_spec in zip(transform_path, comm_action_sequence):
+ cost_dict = get_comm_cost(layout, comm_spec, self.forward_only)
+ for key in total_cost:
+ total_cost[key] += cost_dict[key]
+ return total_cost
+
+ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor:
+ '''
+ Apply target_layout to tensor with source layout, the transform path is generated by the
+ layout_converting method.
+
+ Argument:
+ tensor (torch.Tensor): The tensor to be redistributed.
+ source_layout(Layout): The source layout of the tensor.
+ target_layout (Layout): The tensor will be redistributed to the target_layout.
+
+ Example:
+ layout_converter = LayoutConverter()
+ dim_partition_source = {0: [0]}
+ dim_partition_target = {1: [0]}
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ entire_shape = (4, 4, 4)
+
+ # [S0,R,R]
+ sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
+
+ # [R,S0,R]
+ sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
+
+ if rank in (0, 1):
+ sharded_tensor_0 = torch.zeros(2, 1)
+ sharded_tensor_1 = torch.ones(2, 1)
+ # tensor([[0., 1.],
+ # [0., 1.]])
+ tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+ if rank in (2, 3):
+ sharded_tensor_0 = torch.ones(2, 1) * 2
+ sharded_tensor_1 = torch.ones(2, 1) * 3
+ # tensor([[2., 3.],
+ # [2., 3.]])
+ tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+
+ # converted_tensor: [R, S0, R]
+ converted_tensor = layout_converter.apply(tensor_to_comm, source_layout, target_layout)
+ print(converted_tensor)
+
+ Output in rank0 and rank1:
+ tensor([[0.],
+ [0.],
+ [2.],
+ [2.]])
+
+ Output in rank2 and rank3:
+ tensor([[1.],
+ [1.],
+ [3.],
+ [3.]])
+ '''
+ _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
+ for comm_spec in comm_action_sequence:
+ tensor = comm_spec.covert_spec_to_action(tensor)
+ return tensor
diff --git a/colossalai/tensor/d_tensor/misc.py b/colossalai/tensor/d_tensor/misc.py
new file mode 100644
index 000000000000..3bb3f6f1961e
--- /dev/null
+++ b/colossalai/tensor/d_tensor/misc.py
@@ -0,0 +1,14 @@
+class LayoutException(Exception):
+ pass
+
+
+class DuplicatedShardingDimensionError(LayoutException):
+ pass
+
+
+class ShardingNotDivisibleError(LayoutException):
+ pass
+
+
+class ShardingOutOfIndexError(LayoutException):
+ pass
diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py
new file mode 100644
index 000000000000..7591f760cb30
--- /dev/null
+++ b/colossalai/tensor/d_tensor/sharding_spec.py
@@ -0,0 +1,237 @@
+from copy import deepcopy
+from typing import Dict, List
+
+from ..utils import merge_same_dim_mesh_list
+from .misc import ShardingOutOfIndexError
+
+__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec']
+
+ALLGATHER_COST = 20
+SHARD_COST = 5
+STEP_PENALTY = 6
+NAN = 'nan'
+
+
+class DimSpec:
+ '''
+ Sharding spec for single dimension of the sharded tensor decribe the sharding dimension of
+ logical device mesh and give a method to compute the difference between them.
+ This class is used internally in ShardingSpec.
+
+ Argument:
+ shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
+ Otherwise, the element in shard_list means the data will be sharded in that dimension.
+ '''
+
+ def __init__(self, shard_list):
+ self.is_replica = len(shard_list) == 0
+ self.shard_list = shard_list
+ self.build_difference_2d_dict()
+
+ def __eq__(self, other):
+ return str(self) == str(other)
+
+ def __repr__(self):
+ if self.is_replica:
+ return 'R'
+ target = 'S'
+ for dim in self.shard_list:
+ target += str(dim)
+ return target
+
+ def _convert_str_to_shard_list(self, str_spec):
+ '''
+ Conver str_spec into shard_list.
+
+ Argument:
+ str_spec(str): dim spec in str type.
+ '''
+
+ if str_spec == 'R':
+ return []
+ if str_spec == 'S0':
+ return [0]
+ if str_spec == 'S1':
+ return [1]
+ if str_spec == 'S01':
+ return [0, 1]
+
+ def build_difference_2d_dict(self):
+ '''
+ Build a difference maping for 2D device mesh case. It will be used to
+ compute the difference between DimSpec pairs.
+ '''
+
+ source_spec_list = ['R', 'S0', 'S1', 'S01']
+ target_spec_list = ['R', 'S0', 'S1', 'S01']
+ difference_dict = {}
+ for source_spec in source_spec_list:
+ for target_spec in target_spec_list:
+ legal_sharding_dims = []
+ spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
+ source_shard_list = self._convert_str_to_shard_list(source_spec)
+ target_shard_list = self._convert_str_to_shard_list(target_spec)
+
+ # source same as target
+ if source_shard_list == target_shard_list:
+ difference = 0
+
+ # all_gather(source) -> target
+ elif len(source_shard_list
+ ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
+ difference = ALLGATHER_COST
+
+ # shard(source) -> target
+ elif len(source_shard_list) == len(
+ target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[
+ -1] not in source_shard_list:
+ difference = SHARD_COST
+
+ # S1 -> S0 or S0 -> S1
+ elif len(source_shard_list) == len(target_shard_list):
+ # source -> R -> target
+ difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
+
+ # R -> S01
+ elif len(source_shard_list) == len(target_shard_list) - 2:
+ difference = SHARD_COST + STEP_PENALTY + SHARD_COST
+
+ # S01 -> R
+ elif len(source_shard_list) == len(target_shard_list) + 2:
+ difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
+
+ # S1 -> S01
+ elif len(source_shard_list) == len(target_shard_list) - 1:
+ difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
+
+ # S01 -> S1
+ elif len(source_shard_list) == len(target_shard_list) + 1:
+ difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
+
+ else:
+ difference = NAN
+ difference_dict[spec_pair] = difference
+
+ self.difference_dict = difference_dict
+
+ def dim_diff(self, other):
+ '''
+ The difference between two _DimSpec.
+
+ Argument:
+ other(_DimSpec): the dim spec to compare with.
+
+ Return:
+ difference(int): the difference between two _DimSpec.
+
+ Example:
+ dim_spec = _DimSpec([0])
+ other_dim_spec = _DimSpec([0, 1])
+ print(dim_spec.difference(other_dim_spec))
+
+ Output:
+ 5
+ '''
+ difference = self.difference_dict[(str(self), str(other))]
+ return difference
+
+
+class ShardingSpec:
+ '''
+ Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like
+ [R, R, S0, S1], which means
+
+ Argument:
+ dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
+ and the value of the key decribe which logical axis will be sharded in that dimension.
+ sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
+ '''
+
+ def __init__(self,
+ dim_size: int,
+ dim_partition_dict: Dict[int, List[int]] = None,
+ sharding_sequence: List[DimSpec] = None):
+ self.dims = dim_size
+ self.dim_partition_dict = dim_partition_dict
+ self.sharding_sequence = sharding_sequence
+ if self.sharding_sequence is None:
+ assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
+ self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims,
+ dim_partition_dict=self.dim_partition_dict)
+ self.sharding_sequence = self.convert_dict_to_shard_sequence()
+
+ elif self.dim_partition_dict is None:
+ assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
+ self.dim_partition_dict = self.convert_shard_sequence_to_dict()
+
+ self._sanity_check()
+
+ def _sanity_check(self):
+ if len(self.sharding_sequence) > self.dims:
+ raise ShardingOutOfIndexError(
+ f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.')
+
+ if list(self.dim_partition_dict.keys()) and max(list(self.dim_partition_dict.keys())) >= self.dims:
+ raise ShardingOutOfIndexError(
+ f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.'
+ )
+
+ def __repr__(self):
+ res_list = ["ShardingSpec:"]
+ res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
+ return ' '.join(res_list)
+
+ def convert_dict_to_shard_sequence(self):
+ '''
+ Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence.
+ '''
+ sharding_sequence = [DimSpec([])] * self.dims
+ for dim, shard_list in self.dim_partition_dict.items():
+ sharding_sequence[dim] = DimSpec(shard_list)
+ return sharding_sequence
+
+ def convert_shard_sequence_to_dict(self):
+ '''
+ Convert sharding_sequence into dim_partition_dict.
+ '''
+ new_dim_partition_dict = {}
+ for index, dim_spec in enumerate(self.sharding_sequence):
+ if not dim_spec.is_replica:
+ if index not in new_dim_partition_dict:
+ new_dim_partition_dict[index] = []
+ new_dim_partition_dict[index].extend(dim_spec.shard_list)
+ return new_dim_partition_dict
+
+ def spec_diff(self, other):
+ '''
+ This function is a naive version of difference computation. It just simply accumulates difference every dimension between the
+ pair of sharding sequence.
+
+ Example:
+ dim_partition_dict = {0: [0, 1]}
+ # DistSpec:
+ # shard_sequence: S01,R,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
+ dim_partition_dict_to_compare = {0: [0], 1: [1]}
+ # DistSpec:
+ # shard_sequence: S0,S1,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
+ print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
+
+ Output:
+ 25
+
+ Argument:
+ other(ShardingSpec): The ShardingSpec to compared with.
+
+ Return:
+ difference(int): Difference between two ShardingSpec.
+ '''
+ assert len(self.sharding_sequence) == len(
+ other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.'
+ difference = 0
+ for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
+ difference += orig_dim_spec.dim_diff(other_dim_spec)
+ return difference
diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py
new file mode 100644
index 000000000000..644bb6306b42
--- /dev/null
+++ b/colossalai/tensor/d_tensor/utils.py
@@ -0,0 +1,66 @@
+import operator
+from functools import reduce
+from typing import Dict
+
+from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
+from colossalai.tensor.d_tensor.layout import Layout
+
+
+def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]:
+ '''
+ This method is used to compute the communication cost for a given layout and comm_spec.
+
+ For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
+ compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is a tiny cost.
+
+ Args:
+ layout: the layout of the tensor.
+ comm_spec: the comm_spec to instruct the communication operation.
+ forward_only: if it is True, we will just count the forward communication cost.
+ If it is False, we will count both forward and backward communication cost.
+ '''
+ comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1)
+ device_mesh = layout.device_mesh
+ comm_pattern = comm_spec.comm_pattern
+ logical_process_axis = comm_spec.logical_process_axis
+ cost_dict = {}
+
+ if comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
+ # the comm size for all gather is the size of the gathered tensor
+ gather_dim = comm_spec.gather_dim
+ all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
+ all_gather_size = device_mesh.mesh_shape[all_gather_axis]
+ comm_size_for_all_gather = comm_size * all_gather_size
+ forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
+ # give a tiny cost to shard
+ backward_communication_cost = 100
+
+ if comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
+ forward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
+ # grad should have same shape as input tensor
+ # all to all operation has same logical process axis as forward.
+ backward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
+
+ if comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
+ forward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
+ backward_communication_cost = 0
+
+ if comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
+ forward_communication_cost = 0
+ backward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
+
+ if comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
+ # give a tiny cost to shard
+ forward_communication_cost = 100
+ backward_communication_cost = device_mesh.all_gather_cost(comm_size, logical_process_axis)
+
+ if forward_only:
+ cost_dict["forward"] = forward_communication_cost
+ cost_dict["backward"] = 0
+ cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
+ else:
+ cost_dict["forward"] = forward_communication_cost
+ cost_dict["backward"] = backward_communication_cost
+ cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"]
+
+ return cost_dict
diff --git a/colossalai/tensor/distspec.py b/colossalai/tensor/distspec.py
index 0b62cbdda2c5..8dd0d8791537 100644
--- a/colossalai/tensor/distspec.py
+++ b/colossalai/tensor/distspec.py
@@ -11,7 +11,7 @@ class DistPlacementPattern(Enum):
class _DistSpec:
"""_DistSpec
-
+
A class indicates Distributed Specification.
The DistSpec is only works for the tensor parallel process groups.
Because the dist spec of data parallel process group can be automatically deduced.
@@ -39,11 +39,12 @@ def __eq__(self, other: "_DistSpec") -> bool:
return True
def __repr__(self) -> str:
- res_list = ["DistSpec:"]
+ attr_list = []
for attr in dir(self):
if not attr.startswith('__'):
- res_list.append(f'\n\t{attr}: {str(getattr(self, attr))}')
- return ''.join(res_list)
+ attr_list.append(f'{attr}={str(getattr(self, attr))}')
+ attr_str = ", ".join(attr_list)
+ return "DistSpec(" + attr_str + ")"
def ReplicaSpec() -> _DistSpec:
diff --git a/colossalai/tensor/process_group.py b/colossalai/tensor/process_group.py
index e7e565071e58..f108bdc247f5 100644
--- a/colossalai/tensor/process_group.py
+++ b/colossalai/tensor/process_group.py
@@ -1,29 +1,36 @@
-import torch
from typing import List, Optional
-from colossalai.logging import get_dist_logger
+
+import torch
+
from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.logging import get_dist_logger
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
def __init__(self):
# distributed settings
+ # use this dict to record all Pytorch ProcessGroups
self.dict = {}
+ # set a distributed logger
+ self.logger = get_dist_logger('ProcessGroup')
+
+ def log_pg_init(self, rank_list: List[int], backend: str):
+ str_list = ["Pytorch ProcessGroup Init:"]
+ str_list.append(f"backend: {backend}")
+ str_list.append(f"ranks: {rank_list}")
+ self.logger.info("\n\t".join(str_list), ranks=[0])
def get(self, rank_list: List[int], backend: str = 'nccl'):
"""Reuse Pytorch ProcessGroup when such a group is initialized
"""
- rank_tuple = tuple(rank_list)
# we need to convert the passed list to a tuple
# since List is unhashable
- pg_key = (backend, rank_tuple)
-
- if pg_key not in self.dict:
-
- self.logger = get_dist_logger('ProcessGroup')
- self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
- self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
- return self.dict[pg_key]
+ processgroup_key = (backend, tuple(rank_list))
+ if processgroup_key not in self.dict:
+ self.log_pg_init(rank_list=rank_list, backend=backend)
+ self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
+ return self.dict[processgroup_key]
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
@@ -40,7 +47,7 @@ class ProcessGroup:
rank: the global rank of the current process.
ranks: List[int], a list of rank id belongings to this process group.
backend: str, the backend of the process group.
- tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
+ tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. default None means 1.
dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
"""
@@ -54,10 +61,10 @@ def __init__(self,
return
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
- if rank is None:
- self._rank = torch.distributed.get_rank()
- else:
- self._rank = rank
+
+ self._rank = torch.distributed.get_rank()
+ if rank is not None:
+ assert self._rank == rank # make sure that the global rank is correct
if ranks is None:
self._rank_list = list(range(torch.distributed.get_world_size()))
@@ -104,7 +111,7 @@ def __init__(self,
self.is_init = True
def set_cpu_groups(self):
- """set_cpu_groups
+ """set_cpu_groups
Initialize Pytorch process groups for cpu communications.
"""
if self.has_cpu_groups:
@@ -122,7 +129,7 @@ def set_cpu_groups(self):
@property
def has_cpu_groups(self) -> bool:
- """has_cpu_groups
+ """has_cpu_groups
If cpu groups have been initailized.
Returns:
@@ -132,8 +139,9 @@ def has_cpu_groups(self) -> bool:
def __repr__(self):
if self.is_init:
- return "ProcessGroup:\n\tRank: {}, World size: {}, DP degree: {}, TP degree: {}\n\tRanks in group: {}".\
- format(self._rank, self._world_size, self._dp_degree, self._tp_degree, self._rank_list)
+ ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
+ personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
+ return ranks_str + personal_str
else:
return "ProcessGroup not initialized"
@@ -155,7 +163,7 @@ def __eq__(self, obj: 'ProcessGroup') -> bool:
return True
def rank(self) -> int:
- """rank
+ """rank
The current rank in the global process group.
@@ -165,9 +173,9 @@ def rank(self) -> int:
return self._rank
def ranks_in_group(self) -> List[int]:
- """ranks_in_group
+ """ranks_in_group
- a list of rank number in in the global process group.
+ a list of rank number in in the global process group.
Returns:
List[int]: a list of rank number.
@@ -177,7 +185,7 @@ def ranks_in_group(self) -> List[int]:
def world_size(self) -> int:
"""world_size
- The world size of the global process group.
+ The world size of the global process group.
Returns:
int: world size
@@ -185,7 +193,7 @@ def world_size(self) -> int:
return self._world_size
def tp_rank_list(self) -> List[int]:
- """tp_rank_list
+ """tp_rank_list
the rank list in the TP process group containing the current rank.
@@ -195,7 +203,7 @@ def tp_rank_list(self) -> List[int]:
return self._tp_rank_list
def dp_rank_list(self) -> List[int]:
- """dp_rank_list
+ """dp_rank_list
the rank list in the DP process group containing the current rank.
@@ -205,7 +213,7 @@ def dp_rank_list(self) -> List[int]:
return self._dp_rank_list
def tp_local_rank(self) -> int:
- """tp_local_rank
+ """tp_local_rank
The local rank number in the current TP process group.
@@ -268,7 +276,7 @@ def cpu_dp_process_group(self):
"""cpu_dp_process_group
the pytorch CPU DP process group containing the current rank.
-
+
assert failed if cpu process group is not initialized.
Returns:
@@ -281,7 +289,7 @@ def cpu_tp_process_group(self):
"""cpu_tp_process_group
the pytorch CPU TP process group containing the current rank.
-
+
assert failed if cpu process group is not initialized.
Returns:
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index 875b5a93ba4f..3f16bd91e5fe 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -1,22 +1,46 @@
-from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .activation_checkpoint import checkpoint
from .checkpointing import load_checkpoint, save_checkpoint
-from .common import (clip_grad_norm_fp32, conditional_context, copy_tensor_parallel_attributes, count_zeros_fp32,
- ensure_path_exists, free_port, is_dp_rank_0, is_model_parallel_parameter, is_no_pp_or_last_stage,
- is_tp_rank_0, is_using_ddp, is_using_pp, is_using_sequence, multi_tensor_applier,
- param_is_not_tensor_parallel_duplicate, print_rank_0, switch_virtual_pipeline_parallel_rank,
- sync_model_param, disposable)
+from .common import (
+ clip_grad_norm_fp32,
+ conditional_context,
+ copy_tensor_parallel_attributes,
+ count_zeros_fp32,
+ disposable,
+ ensure_path_exists,
+ free_port,
+ is_ddp_ignored,
+ is_dp_rank_0,
+ is_model_parallel_parameter,
+ is_no_pp_or_last_stage,
+ is_tp_rank_0,
+ is_using_ddp,
+ is_using_pp,
+ is_using_sequence,
+ multi_tensor_applier,
+ param_is_not_tensor_parallel_duplicate,
+ print_rank_0,
+ switch_virtual_pipeline_parallel_rank,
+ sync_model_param,
+)
+from .cuda import empty_cache, get_current_device, set_to_cuda, synchronize
from .data_sampler import DataParallelSampler, get_dataloader
-from .memory import (report_memory_usage, colo_device_memory_used, colo_set_process_memory_fraction,
- colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
-from .timer import MultiTimer, Timer
+from .memory import (
+ colo_device_memory_capacity,
+ colo_device_memory_used,
+ colo_get_cpu_memory_capacity,
+ colo_set_cpu_memory_capacity,
+ colo_set_process_memory_fraction,
+ report_memory_usage,
+)
from .tensor_detector import TensorDetector
+from .timer import MultiTimer, Timer
__all__ = [
'checkpoint',
'free_port',
'print_rank_0',
'sync_model_param',
+ 'is_ddp_ignored',
'is_dp_rank_0',
'is_tp_rank_0',
'is_no_pp_or_last_stage',
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index 7575fa292f14..e15981140be1 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -11,7 +11,7 @@
import torch
import torch.distributed as dist
-from torch._six import inf
+from torch import inf
from torch.nn.parameter import Parameter
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_ATTRIBUTES
@@ -50,16 +50,20 @@ def ensure_path_exists(filename: str):
Path(dirpath).mkdir(parents=True, exist_ok=True)
-def free_port():
+def free_port() -> int:
+ """Get a free port on localhost.
+
+ Returns:
+ int: A free port on localhost.
+ """
while True:
+ port = random.randint(20000, 65000)
try:
- sock = socket.socket()
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- port = random.randint(20000, 65000)
- sock.bind(('localhost', port))
- sock.close()
- return port
- except Exception:
+ with socket.socket() as sock:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock.bind(("localhost", port))
+ return port
+ except OSError:
continue
@@ -126,14 +130,18 @@ def is_model_parallel_parameter(p):
return hasattr(p, IS_TENSOR_PARALLEL) and getattr(p, IS_TENSOR_PARALLEL)
+def is_ddp_ignored(p):
+ return getattr(p, '_ddp_to_ignore', False)
+
+
def _calc_l2_norm(grads):
- # we should not
+ # we should not
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
-
+
norm = 0.0
if len(grads) > 0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 93c91e0995ea..87ae413a2a8a 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -32,17 +32,16 @@ def _convert_to_coloparam(param: torch.nn.Parameter,
default_pg: Optional[ProcessGroup] = None,
default_dist_spec: Optional[Any] = None) -> ColoParameter:
- if isinstance(param, ColoParameter):
+ if type(param) is ColoParameter:
return param
# detaching tensor is necessary for optimizers.
requires_grad = param.requires_grad
# param is the global tensor.
-
+
if param.device.type == "meta":
colo_param = ColoParameter(param, requires_grad=requires_grad)
- else:
+ else:
colo_param = ColoParameter(param.to(device=device, dtype=dtype), requires_grad=requires_grad)
-
# if default_shard_plan exists, shard the param during initialization.
# This can reduce the model size after initialization.
@@ -103,7 +102,7 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
name_list = []
for name, param in _named_params_with_replica(module):
- if isinstance(param, ColoTensor):
+ if type(param) is ColoParameter:
continue
split = name.rfind('.')
@@ -129,32 +128,29 @@ def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
delattr(submodule, param_name)
setattr(submodule, param_name, colo_param)
colo_param.shared_param_modules.append(submodule)
-
- meta_param_flag = 0
- meta_buffer_flag = 0
+
+ param_number = 0
+ meta_param_number = 0
+ buffer_number = 0
+ meta_buffer_number = 0
+
for param in module.parameters():
- if param.device.type=="meta":
- meta_param_flag = 1
- if meta_param_flag == 1 and param.device.type!="meta":
- raise ValueError("Meta parameters and valued parameters can not be in the same model")
-
+ param_number += 1
+ meta_param_number += (param.device.type == 'meta')
+
for buffer in module.buffers():
- if buffer.device.type=="meta":
- meta_buffer_flag = 1
- if meta_buffer_flag == 1 and buffer.device.type!="meta":
- raise ValueError("Meta buffers and valued buffers can not be in the same model")
-
- if meta_param_flag==1 and meta_buffer_flag==1:
- pass
- elif meta_buffer_flag==0 and meta_param_flag==1:
- for name, buf in module.named_buffers():
- module._buffers[name] = module._buffers[name].to(device=self._device)
- elif meta_param_flag==0 and meta_buffer_flag==1:
- for name, param in module.named_parameters():
- module._parameters[name] = module._parameters[name].to(device=self._device)
- else:
- module.to(self._device)
-
+ buffer_number += 1
+ meta_buffer_number += (buffer.device.type == 'meta')
+
+ if meta_param_number > 0 and meta_param_number != param_number:
+ raise ValueError("Meta parameters and valued parameters can not be in the same model")
+ if meta_buffer_number > 0 and meta_buffer_number != buffer_number:
+ raise ValueError("Meta buffers and valued buffers can not be in the same model")
+
+ if meta_buffer_number == 0:
+ for buffer in module.buffers():
+ buffer.data = buffer.data.to(device=self._device)
+
def post_process_colo_init_ctx(model: torch.nn.Module,
device: torch.device = torch.device('cpu'),
diff --git a/colossalai/utils/model/experimental.py b/colossalai/utils/model/experimental.py
new file mode 100644
index 000000000000..6427a147a5c0
--- /dev/null
+++ b/colossalai/utils/model/experimental.py
@@ -0,0 +1,576 @@
+from types import MethodType
+from typing import Callable, Optional, Union
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch import Tensor
+from torch.utils._pytree import tree_map
+
+from colossalai.fx.profiler.tensor import MetaTensor
+from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor.layout import Layout
+
+# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
+_NORMAL_FACTORY = [
+ "arange",
+ "empty",
+ "full",
+ "linspace",
+ "logspace",
+ "ones",
+ "rand",
+ "randn",
+ "randint",
+ "randperm",
+ "zeros",
+ "tensor",
+]
+
+# factory function that does not support meta tensor backend
+_NO_META_FACTORY = [
+ "eye",
+]
+
+_EARLY_MATERIALIZED_OPS = ['__getitem__', 'split']
+
+# If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
+# without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
+# These ops cannot be unwrapped using .data
+_CHANGE_META_OPS = ['_cudnn_rnn_flatten_weight', 'requires_grad_', '__get__']
+
+_LEGACY_TENSOR_CONSTRUCTOR = {
+ 'FloatTensor': torch.float,
+ 'DoubleTensor': torch.double,
+ 'HalfTensor': torch.half,
+ 'BFloat16Tensor': torch.bfloat16,
+ 'ByteTensor': torch.uint8,
+ 'CharTensor': torch.int8,
+ 'ShortTensor': torch.short,
+ 'IntTensor': torch.int,
+ 'LongTensor': torch.long,
+ 'BoolTensor': torch.bool,
+}
+
+_EMPTY_DATA = torch.empty(0)
+
+
+class _MyTensor(Tensor):
+ """This class is only for correctness verification.
+ """
+ _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+
+ def __new__(cls, func, *args, concrete_data=None, **kwargs) -> '_MyTensor':
+ cls._pre_op_fn()
+ if concrete_data is not None:
+ # uniform api as LazyTensor
+ data = concrete_data
+ else:
+ data = func(*args, **kwargs)
+ return Tensor._make_subclass(cls, data, require_grad=data.requires_grad)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ cls._pre_op_fn()
+ return super().__torch_function__(func, types, args, kwargs)
+
+
+def _convert_cls(tensor: 'LazyTensor', target: torch.Tensor) -> torch.Tensor:
+ """Convert a lazy tensor's class to target's class, with target's data.
+
+ The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
+ If we create a new tensor and update the module by ``setattr(module, name, param)``, the shared parameters will not be updated. And we have to track all shared parameters and update them manually.
+
+ Args:
+ tensor (LazyTensor): the LazyTensor to be converted
+ target (torch.Tensor): target tensor
+
+ Returns:
+ torch.Tensor: the converted tensor
+ """
+ cls_to_become = nn.Parameter if isinstance(tensor, nn.Parameter) else torch.Tensor
+ tensor.__class__ = cls_to_become
+ tensor.data = target
+ tensor.requires_grad = target.requires_grad
+ # subclass of torch.Tensor does not have tolist() method
+ # overwrite this method after materialization or distribution
+ tensor.tolist = MethodType(torch.Tensor.tolist, target)
+ return tensor
+
+
+class LazyTensor(torch.Tensor):
+ """A naive implementation of LazyTensor (https://arxiv.org/pdf/2102.13267.pdf).
+
+ Usage:
+ 1. Use ``LazyTensor`` instead of ``torch.Tensor``.
+ >>> x = LazyTensor(torch.zeros, 2, 3)
+ >>> x += 1
+ >>> y = x * x
+ >>> y = y.cuda().half()
+ >>> y[0, 0] = 0
+ >>> y = y.materialize() # materialize the tensor
+ >>> print(y)
+ tensor([[0., 1., 1.],
+ [1., 1., 1.]], device='cuda:0', dtype=torch.float16)
+
+ Warnings:
+ 1. Cases that ``LazyTensor`` can't deal with.
+ >>> x = LazyTensor(torch.ones, 2, 3)
+ >>> x[0, 0] = -x[0, 0] # this will cause infinite recursion
+ >>> y = x.clone()
+ >>> x.add_(1) # modifying origin tensor after cloning leads to wrong materialization
+ >>> z = x.tolist()
+ >>> x.zeros_() # modifying origin tensor after cloning tolist is not allowed
+ >>> nn.utils.weight_norm(self.conv, name="weight", dim=2) # applying weight norm on a lazy tensor is not allowed
+
+
+ 2. Cases that ``LazyTensor`` becomes eager (early materialization).
+ >>> b = a[:, 2:] # get a slice of a lazy tensor triggers early materialization
+ >>> chunks = a.split(3) # this also triggers early materialization
+ >>> x.data = torch.rand(2, 3) # directly setting data of a lazy tensor triggers early materialization
+
+ """
+
+ _repr = True
+ _meta_data: Optional[MetaTensor] = None # shape, dtype, device
+ _pre_op_fn: Callable[['LazyTensor'], None] = lambda *args: None
+
+ @staticmethod
+ def __new__(cls, func, *args, meta_data=None, concrete_data=None, **kwargs):
+ if concrete_data is not None:
+ # some ops don't support meta backend and should have concrete data
+ elem = concrete_data
+ else:
+ if meta_data is None:
+ device = kwargs.get('device', 'cpu')
+ elem = func(*args, **{**kwargs, 'device': 'meta'})
+ meta_data = MetaTensor(elem, fake_device=device)
+ elem = meta_data._tensor
+ # As a meta tensor cannot be modified __class__ to torch.Tensor, we should use an empty real tensor here
+ r = torch.Tensor._make_subclass(cls, _EMPTY_DATA, require_grad=elem.requires_grad)
+ r._meta_data = meta_data
+ return r
+
+ def __init__(self, func, *args, meta_data=None, concrete_data=None, **kwargs):
+ self._factory_method = (func, args, kwargs) # (func, args, kwargs)
+ self._op_buffer = [] # (func, args, kwargs, replace)
+ self._materialized_data: Optional[torch.Tensor] = concrete_data # materialized data
+
+ def materialize(self) -> torch.Tensor:
+ """Materialize the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace).
+
+ Returns:
+ torch.Tensor: The materialized tensor (self).
+ """
+ target = self._materialize_data()
+ self.clean()
+ return _convert_cls(self, target)
+
+ def distribute(self, layout: Layout) -> torch.Tensor:
+ """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
+
+ Args:
+ layout (Layout): Distribution layout.
+
+ Returns:
+ torch.Tensor: The distributed tensor (self).
+ """
+ target = self._materialize_data()
+ self.clean()
+ local_tensor = DTensor(target, layout).local_tensor
+ return _convert_cls(self, local_tensor)
+
+ def clean(self) -> None:
+ """Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized.
+ """
+ self._factory_method = None
+ self._op_buffer = None
+ self._materialized_data = None
+ self._meta_data = None
+
+ @staticmethod
+ def _replace_with_materialized(x):
+ if isinstance(x, LazyTensor):
+ return x._materialize_data()
+ return x
+
+ def _materialize_data(self) -> torch.Tensor:
+ # self._materialized_data should be generated after the first call of this function
+ if self._materialized_data is None:
+ # apply factory method
+ func, args, kwargs = self._factory_method
+
+ # apply cached sequence
+ self._pre_op_fn()
+
+ try:
+ init_val = func(*tree_map(self._replace_with_materialized, args),
+ **tree_map(self._replace_with_materialized, kwargs))
+ except TypeError as e:
+ print(f'init fn: {func.__name__}')
+ raise e
+
+ self._materialized_data = self._rerun_ops(init_val)
+ return self._materialized_data
+
+ def _rerun_ops(self, target=None) -> torch.Tensor:
+ """Do lazy execution by rerunning all (stored) related operations.
+
+ Args:
+ target (torc.Tensor, optional): Intial value of the target tensor (self). Defaults to None.
+ """
+
+ def replace(x):
+ if x is self:
+ return target
+ elif isinstance(x, LazyTensor):
+ return x._materialize_data()
+ return x
+
+ packed = None
+
+ for (func, args, kwargs) in self._op_buffer:
+ if func == torch.Tensor.requires_grad_:
+ packed = func, args, kwargs # requires grad should be set at last
+ else:
+ self._pre_op_fn()
+ o = func(*tree_map(replace, args), **tree_map(replace, kwargs))
+ target = o if isinstance(o, torch.Tensor) else target # if func returns non-Tensor, discard the value
+
+ # super-dainiu: set requires_grad after all inplace-ops are done
+ if packed is not None:
+ func, args, kwargs = packed
+ func(*tree_map(replace, args), **tree_map(replace, kwargs))
+
+ return target
+
+ # cache everything with __torch_function__
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func.__name__ in _EARLY_MATERIALIZED_OPS:
+ # These OPs cannot be lazy and related tensors should be early materialized
+ tree_map(cls._replace_with_materialized, args)
+ tree_map(cls._replace_with_materialized, kwargs)
+ is_inplace: bool = (func.__name__.endswith('_') and not (func.__name__.endswith('__'))
+ or func.__name__ == "__setitem__")
+
+ is_change_meta_op: bool = func.__name__ in _CHANGE_META_OPS
+
+ if isinstance(func, torch._C.ScriptMethod):
+ # FIXME(ver217): torch script functions are not verified
+
+ target = None
+
+ def unwrap(x):
+ if isinstance(x, LazyTensor):
+ return x._meta_data
+ return x
+
+ target: LazyTensor = args[0].clone()
+ target._op_buffer.append((func, args, kwargs))
+ target._meta_data = getattr(target._meta_data, func.name)(*tree_map(unwrap, args[1:]),
+ **tree_map(unwrap, kwargs))
+ return target
+ else:
+
+ meta_to_lazy = {}
+
+ def unwrap(x):
+ if isinstance(x, LazyTensor):
+ if x._materialized_data is not None:
+ # for early materialized tensor, use its materialized data directly
+ return x._materialized_data if is_change_meta_op else x._materialized_data.data
+ t = x if is_inplace else x.clone()
+ t._op_buffer.append((func, args, kwargs))
+ meta = x._meta_data if is_change_meta_op else x._meta_data.data
+ meta_to_lazy[meta] = t
+ return meta
+ return x
+
+ def wrap(y, i=None):
+ if isinstance(y, MetaTensor):
+ if y in meta_to_lazy:
+ # inplace op, just return origin lazy tensor
+ return meta_to_lazy[y]
+ else:
+ # out of place op, create new lazy tensor
+ fn = lambda *a, **kw: func(*a, **kw) if i is None else func(*a, **kw)[i]
+ lazy_y = LazyTensor(fn, *args, meta_data=y, **kwargs)
+ return lazy_y
+ elif type(y) is Tensor:
+ # for early materialized tensor
+ return LazyTensor(lambda: None, concrete_data=y)
+ return y
+
+ cls._pre_op_fn()
+ o = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
+ if isinstance(o, (tuple, list)):
+ return type(o)(wrap(y, i=i) for i, y in enumerate(o))
+ return wrap(o)
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ pass # skip
+
+ def clone(self) -> "LazyTensor":
+
+ def factory_fn():
+ return self.materialize().clone()
+
+ target = LazyTensor(factory_fn, meta_data=self._meta_data)
+
+ return target
+
+ def detach(self) -> Tensor:
+ return self
+
+ @property
+ def data(self):
+ return self
+
+ @data.setter
+ def data(self, other: 'LazyTensor'):
+ """This is sightly different from oringinal `data` setter.
+
+ E.g.:
+ >>> a = torch.randn(3, 3) # a is a Tensor
+ >>> b = torch.rand(2, 2)
+ >>> a.data = b
+ >>> b.add_(1) # this will affect a
+ >>> x = torch.randn(3, 3) # x is a LazyTensor
+ >>> y = torch.rand(2, 2) # y is a LazyTensor
+ >>> x.data = y
+ >>> y.add_(1) # this will not affect x
+
+ """
+ if other is self:
+ return
+
+ self._op_buffer.append(other._factory_method)
+
+ def replace(x):
+ if x is other:
+ return self
+ return x
+
+ for func, args, kwargs in other._op_buffer:
+ self._op_buffer.append((func, tree_map(replace, args), tree_map(replace, kwargs)))
+
+ def tolist(self) -> list:
+ # Though self.__class__ is modified to torch.Tensor, in C++ side, it is still a subclass of torch.Tensor
+ # And subclass of torch.Tensor does not have tolist() method
+ t = self._materialize_data()
+ return t.tolist()
+
+ def __hash__(self):
+ return id(self)
+
+
+class LazyInitContext:
+ """Context manager for lazy initialization. Enables initializing the model without allocating real memory.
+
+ Usage:
+ 1. The model is initialized, but no real memory is allocated.
+ >>> ctx = LazyInitContext()
+ >>> with ctx:
+ >>> model = MyModel().cuda()
+
+ 2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
+ >>> with ctx.traceable(model):
+ >>> gm = symbolic_trace(model, meta_args=meta_args)
+ >>> # Solve the execution strategy and apply the strategy to the model
+ >>> strategy = StrategyAndSpec()
+
+ 3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
+ >>> model = ctx.materialize(model)
+
+ 3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
+ >>> model = apply_strategy_to_all_params(model, strategy)
+ >>> model = ctx.distribute(model)
+
+ Warnings:
+ This API is still experimental and further modifications can be made to it.
+ For example:
+ 1. Quantization strategies can be applied before allocating real memory.
+ 2. Lazy initialization seems slower than normal initialization.
+ """
+ _replaced: bool = False
+
+ def __init__(self, tensor_cls: Union[_MyTensor, LazyTensor] = LazyTensor):
+ self.overrides = {}
+ self.tensor_cls = tensor_cls
+
+ def __enter__(self):
+ if LazyInitContext._replaced:
+ raise RuntimeError(f'LazyInitContext is not reentrant')
+ LazyInitContext._replaced = True
+
+ def wrap_factory_method(target):
+ # factory functions (eg. torch.empty())
+ def wrapper(*args, **kwargs):
+ return self.tensor_cls(target, *args, **kwargs)
+
+ return wrapper, target
+
+ def wrap_factory_like_method(orig_target, target):
+ # factory_like functions (eg. torch.empty_like())
+ def wrapper(*args, **kwargs):
+ orig_t = args[0]
+ return self.tensor_cls(orig_target, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs)
+
+ return wrapper, target
+
+ def wrap_legacy_constructor(target, dtype):
+ # legacy constructor (e.g. torch.LongTensor())
+ def wrapper(*args, **kwargs):
+ if len(args) == 1 and isinstance(args[0], torch.Tensor):
+ # (Tensor other)
+ return args[0]
+ elif len(args) == 1:
+ # (object data, *, torch.device device)
+ kwargs = {**kwargs, 'dtype': dtype}
+ replaced, orig = self.overrides['tensor']
+ return replaced(*args, **kwargs)
+ elif _is_int_tuple(args):
+ # (tuple of ints size, *, torch.device device)
+ kwargs = {**kwargs, 'dtype': dtype}
+ replaced, orig = self.overrides['empty']
+ return replaced(*args, **kwargs)
+ else:
+ raise TypeError(
+ f'new() received an invalid combination of arguments - got {tuple(type(x) for x in args)}, but expected one of:\n * (Tensor other)\n * (tuple of ints size, *, torch.device device)\n * (object data, *, torch.device device)'
+ )
+
+ return wrapper, target
+
+ def wrap_no_meta_factory(target):
+ # factory functions which don't support meta tensor backend
+ def wrapper(*args, **kwargs):
+ tensor = target(*args, **kwargs)
+ return self.tensor_cls(lambda: None, concrete_data=tensor)
+
+ return wrapper, target
+
+ self.overrides = {
+ target: wrap_factory_method(getattr(torch, target))
+ for target in _NORMAL_FACTORY
+ if callable(getattr(torch, target, None))
+ }
+
+ self.overrides.update({
+ target + '_like': wrap_factory_like_method(getattr(torch, target), getattr(torch, target + '_like'))
+ for target in _NORMAL_FACTORY
+ if callable(getattr(torch, target + '_like', None))
+ })
+
+ self.overrides.update({
+ target: wrap_legacy_constructor(getattr(torch, target), dtype)
+ for target, dtype in _LEGACY_TENSOR_CONSTRUCTOR.items()
+ if callable(getattr(torch, target, None))
+ })
+
+ self.overrides.update({
+ target: wrap_no_meta_factory(getattr(torch, target))
+ for target in _NO_META_FACTORY
+ if callable(getattr(torch, target, None))
+ })
+
+ for name, (wrapper, orig) in self.overrides.items():
+ setattr(torch, name, wrapper)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ LazyInitContext._replaced = False
+ for name, (wrapper, orig) in self.overrides.items():
+ setattr(torch, name, orig)
+
+ @staticmethod
+ def materialize(module: nn.Module, verbose: bool = False) -> nn.Module:
+ """Initialize all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+
+ Args:
+ module (nn.Module): Target ``nn.Module``
+ verbose (bool): Whether to print lazy initialization rate. Defaults to False.
+ """
+
+ def apply_fn(name: str, p: LazyTensor):
+ p.materialize()
+
+ return _apply_to_lazy_module(module, apply_fn, verbose)
+
+ @staticmethod
+ def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+ """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
+
+ Args:
+ module (nn.Module): Target ``nn.Module``
+ layout_dict (dict): Dict of layout for each parameter/buffer. The key is the parameter/buffer name, and the value is the layout.
+ verbose (bool, optional): Whether to print lazy initialization rate. Defaults to False.
+ """
+
+ def apply_fn(name: str, p: LazyTensor):
+ p.distribute(layout_dict[name])
+
+ return _apply_to_lazy_module(module, apply_fn, verbose)
+
+
+def _apply_to_lazy_module(module: nn.Module,
+ apply_fn: Callable[[str, torch.Tensor], None],
+ verbose: bool = False) -> nn.Module:
+ if verbose:
+ # verbose info
+ param_cnt = 0
+ param_lazy_cnt = 0
+ buf_cnt = 0
+ buf_lazy_cnt = 0
+ total_numel = 0
+ non_lazy_numel = 0
+
+ for name, p in module.named_parameters():
+ if verbose:
+ param_cnt += 1
+ total_numel += p.numel()
+ if getattr(p, '_materialized_data', False) is None:
+ # if no _materialized_data attr, the tensor is not lazy
+ param_lazy_cnt += 1
+ else:
+ non_lazy_numel += p.numel()
+ if isinstance(p, LazyTensor):
+ apply_fn(name, p)
+
+ for name, buf in module.named_buffers():
+ if verbose:
+ buf_cnt += 1
+ total_numel += buf.numel()
+ if getattr(buf, "_materialized_data", False) is None:
+ # if no _materialized_data attr, the tensor is not lazy
+ buf_lazy_cnt += 1
+ else:
+ non_lazy_numel += buf.numel()
+ if isinstance(buf, LazyTensor):
+ apply_fn(name, buf)
+
+ if verbose:
+ non_lazy_numel_ratio = non_lazy_numel / total_numel * 100 if non_lazy_numel != 0 else 0
+ _print_rank_0(f'Param lazy rate: {param_lazy_cnt}/{param_cnt}')
+ _print_rank_0(f'Buffer lazy rate: {buf_lazy_cnt}/{buf_cnt}')
+ _print_rank_0(
+ f'Non lazy numel: {non_lazy_numel} ({non_lazy_numel/1024**2:.3f} M), ratio: {non_lazy_numel_ratio}%')
+
+ return module
+
+
+def _print_rank_0(*args, **kwargs):
+ if not dist.is_initialized() or dist.get_rank() == 0:
+ print(*args, **kwargs)
+
+
+def _is_int_tuple(args) -> bool:
+ if not isinstance(args, tuple):
+ return False
+ for x in args:
+ if not isinstance(x, int):
+ return False
+ return True
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index 75bb18df66c1..f49607376439 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -1,7 +1,12 @@
-import torch
+# This code has been adapted from the DeepSpeed library.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
import functools
from typing import Optional
+import torch
+
def substitute_init_recursively(cls, func, visited: set):
for subcls in cls.__subclasses__():
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 572ddd9e4e3f..b40b69962cf7 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,46 +1,45 @@
import contextlib
import functools
-from typing import Optional
from contextlib import AbstractContextManager
+from dataclasses import dataclass
+from typing import Optional
import torch
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
+from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-class ZeroContextConfig(object):
+@dataclass
+class ZeroContextConfig:
"""The configuration used to control zero context initialization.
Args:
target_device (torch.device): The device where param data are after exiting the context.
- replicated (bool, optional): Whether the param is replicated across data parallel group.
+ is_replicated (bool, optional): Whether the param is replicated across data parallel group.
Some parameters are not replicated, e.g. parameters in MOE experts.
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
"""
- def __init__(self, target_device: torch.device, replicated: bool = True, shard_param: bool = False):
- super().__init__()
+ target_device: torch.device
+ is_replicated: bool = True
+ shard_param: bool = False
- if shard_param:
- assert replicated, "Non-replicated parameters can't be sharded."
+ def __post_init__(self):
+ if self.shard_param:
+ assert self.is_replicated, "Non-replicated parameters can't be sharded."
- # replicated no-shard parameters should locate in cuda, since we will broadcast them soon
- if replicated and not shard_param:
- assert target_device.type == 'cuda', "Replicated no-shard paramters should locate in cuda."
-
- self.target_device = target_device
- self.is_replicated: bool = replicated
- self.shard_param: bool = shard_param
+ if self.is_replicated and not self.shard_param:
+ assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -74,7 +73,7 @@ def __init__(self,
self.seed = seed
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
- self.config = ZeroContextConfig(target_device=target_device, replicated=True, shard_param=shard_param)
+ self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
ZeroContextMgr().current_context = self
@@ -124,7 +123,7 @@ def calc_fanin_fanout(tensor: torch.Tensor):
return fan_in, fan_out
def _pre_context_exec(self):
- """
+ """
The Callback function when entering the context
"""
self.logger = get_dist_logger("ZeroInitContext")
@@ -248,7 +247,7 @@ def hijack_context_config(self, **kwargs):
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
- replicated=is_replicated,
+ is_replicated=is_replicated,
shard_param=False)
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index ae3a619980ac..12e8f65d4a35 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import functools
import itertools
from collections import OrderedDict
@@ -493,6 +494,7 @@ def _colo_load_from_state_dict(self,
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
+ shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None.
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py
index 9a839a5705c3..9ca2fdf5aa06 100644
--- a/colossalai/zero/sharded_optim/_utils.py
+++ b/colossalai/zero/sharded_optim/_utils.py
@@ -1,12 +1,12 @@
import math
+from typing import Optional
import torch
import torch.distributed as dist
-from torch._six import inf
+from torch import inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter
@@ -101,7 +101,11 @@ def split_half_float_double(tensor_list):
return buckets
-def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
+def reduce_tensor_dp_group(tensor: torch.Tensor,
+ dtype: Optional[torch.dtype] = None,
+ dst_local_rank: Optional[int] = None,
+ dst_global_rank: Optional[int] = None,
+ group: Optional[dist.ProcessGroup] = None):
"""
Reduce the tensor in the data parallel process group
@@ -114,7 +118,7 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
:type tensor: torch.Tensor
:type dtype: torch.dtype, optional
:type dst_rank: int, optional
- :type parallel_mode: ParallelMode, optional
+ :type pg: ProcessGroup, optional
"""
# use the original dtype
if dtype is None:
@@ -126,25 +130,22 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
else:
tensor_to_reduce = tensor
- world_size = gpc.get_world_size(parallel_mode)
- group = gpc.get_group(parallel_mode)
+ world_size = dist.get_world_size(group=group)
tensor_to_reduce.div_(world_size)
# if rank is None, all reduce will be used
# else, reduce is used
- use_all_reduce = dst_rank is None
+ use_all_reduce = dst_local_rank is None
if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
else:
- ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
- global_rank = ranks_in_group[dst_rank]
- dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
+ dist.reduce(tensor=tensor_to_reduce, dst=dst_global_rank, group=group)
# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
- local_rank = gpc.get_local_rank(parallel_mode)
- if use_all_reduce or dst_rank == local_rank:
+ local_rank = dist.get_rank(group=group)
+ if use_all_reduce or dst_local_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
@@ -222,7 +223,10 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
for g, p in zip(gradients, params):
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
- if is_model_parallel_parameter(p) or mp_rank == 0:
+ tp_param_flag = False
+ if is_model_parallel_parameter(p) or (isinstance(p, ColoParameter) and not p.is_replicate()):
+ tp_param_flag = True
+ if tp_param_flag or mp_rank == 0:
param_norm = g.data.double().norm(2)
total_norm += param_norm.item()**2
@@ -231,7 +235,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2):
torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if mp_group is not None:
- dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM)
+ dist.all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=mp_group)
total_norm = total_norm_cuda[0].item()**(1. / norm_type)
diff --git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/sharded_optim/bookkeeping/base_store.py
index d4436acaa4bf..2ebd122464f4 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/base_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/base_store.py
@@ -1,12 +1,12 @@
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
class BaseStore:
- def __init__(self, dp_parallel_mode=ParallelMode.DATA):
- self._world_size = gpc.get_world_size(dp_parallel_mode)
- self._local_rank = gpc.get_local_rank(dp_parallel_mode)
+ def __init__(self, torch_pg: ProcessGroup):
+ self._world_size = dist.get_world_size(group=torch_pg)
+ self._local_rank = dist.get_rank(group=torch_pg)
@property
def world_size(self):
diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
index 0f2b1bb88b58..ec322a78bf81 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
@@ -1,14 +1,12 @@
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
+from torch.distributed import ProcessGroup
from .base_store import BaseStore
class BucketStore(BaseStore):
- def __init__(self, dp_parallel_mode):
- super().__init__(dp_parallel_mode)
- self._grads = dict()
+ def __init__(self, torch_pg: ProcessGroup):
+ super().__init__(torch_pg)
self._params = dict()
self._num_elements_in_bucket = dict()
@@ -20,25 +18,24 @@ def num_elements_in_bucket(self, reduce_rank: int = None):
def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
self._num_elements_in_bucket[reduce_rank] += num_elements
- def add_grad(self, tensor, reduce_rank: int = None):
- self._grads[reduce_rank].append(tensor)
-
def add_param(self, tensor, reduce_rank: int = None):
self._params[reduce_rank].append(tensor)
def reset(self):
keys = [None] + list(range(self._world_size))
- self._grads = {rank: [] for rank in keys}
self._params = {rank: [] for rank in keys}
self._num_elements_in_bucket = {rank: 0 for rank in keys}
def reset_by_rank(self, reduce_rank=None):
- self._grads[reduce_rank] = []
self._params[reduce_rank] = []
self._num_elements_in_bucket[reduce_rank] = 0
def get_grad(self, reduce_rank: int = None):
- return self._grads[reduce_rank]
+ param_list = self.get_param(reduce_rank)
+ for param in param_list:
+ # the param must have grad for reduction
+ assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+ return [param.grad for param in param_list]
def get_param(self, reduce_rank: int = None):
return self._params[reduce_rank]
diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
index 8a9128a18964..942d7186e55f 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
@@ -15,7 +15,7 @@ def __init__(self, *args):
# for backward reduction hooks
self._grad_acc_objs = []
- def add_accumulate_grad_object(self, obj):
+ def append_accumulate_grad_object(self, obj):
"""
Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
be attached successfully.
@@ -36,10 +36,12 @@ def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
:return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
:rtype: List[torch.Tensor]
"""
+ if group_id not in self._averaged_gradients:
+ self._averaged_gradients[group_id] = []
return self._averaged_gradients[group_id]
- def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
+ def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
"""
Append an average gradient to the list of averaged gradients of a parameter group
@@ -55,6 +57,20 @@ def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
else:
self._averaged_gradients[group_id] = [tensor]
+ def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
+ """
+ Add an average gradient to the list of averaged gradients of a parameter group
+
+ :param group_id: The index of a parameter group
+ :param tensor_idx: The index of a tensor in the list of averaged gradients
+ :param tensor: A :class:`torch.Tensor` object
+ :type group_id: int
+ :type tensor_idx: int
+ :type tensor: torch.Tensor
+
+ """
+ self._averaged_gradients[group_id][tensor_idx].add_(tensor)
+
def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
Reset the bookkeeping data structure for averaged gradients to an empty list
@@ -64,3 +80,9 @@ def reset_average_gradients_by_group(self, group_id: int) -> None:
"""
self._averaged_gradients[group_id] = []
+
+ def reset_all_average_gradients(self) -> None:
+ """
+ Reset the bookkeeping data structure for averaged gradients to an empty list
+ """
+ self._averaged_gradients = dict()
diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
index 09ebaaf9938c..cbf708b3471f 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
@@ -1,14 +1,15 @@
from typing import List
from torch import Tensor
+from torch.distributed import ProcessGroup
from .base_store import BaseStore
class ParameterStore(BaseStore):
- def __init__(self, dp_paralle_mode):
- super().__init__(dp_paralle_mode)
+ def __init__(self, torch_pg: ProcessGroup):
+ super().__init__(torch_pg)
# param partitioning data structures
self._fp16_param_to_rank = dict()
self._rank_groupid_to_fp16_param_list = dict()
diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
index c437ac54939c..49fb8b54b7d2 100644
--- a/colossalai/zero/sharded_optim/low_level_optim.py
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -1,5 +1,6 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from functools import partial
-from itertools import groupby
+from typing import Optional
import torch
import torch.distributed as dist
@@ -10,15 +11,15 @@
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils.cuda import get_current_device
from ._utils import (
calculate_global_norm_from_list,
compute_norm,
flatten,
- get_grad_accumulate_object,
has_inf_or_nan,
- reduce_tensor,
+ reduce_tensor_dp_group,
release_param_grad,
split_half_float_double,
sync_param,
@@ -33,35 +34,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def __init__(
self,
optimizer: Optimizer,
-
- # grad scaler config
- initial_scale=2**16,
- min_scale=1,
- growth_factor=2,
- backoff_factor=0.5,
- growth_interval=2000,
- hysteresis=2,
+ initial_scale: int = 2**16, # grad scaler config
+ min_scale: int = 1,
+ growth_factor: float = 2.,
+ backoff_factor: float = .5,
+ growth_interval: int = 2000,
+ hysteresis: int = 2,
max_scale: int = 2**24,
-
- # grad clipping
- clip_grad_norm=0.0,
- verbose=False,
-
- # communication
- reduce_bucket_size=1024 * 1024,
- communication_dtype=None,
- overlap_communication=False,
-
- # stage 2
- partition_grad=False,
- dp_parallel_mode=ParallelMode.DATA,
- mp_parallel_mode=ParallelMode.MODEL,
-
- # cpu offload
- cpu_offload=False,
-
- # forced dtype
- forced_dtype=None):
+ clip_grad_norm: float = 0.0, # grad clipping
+ verbose: bool = False,
+ reduce_bucket_size: int = 1024 * 1024, # communication
+ communication_dtype: Optional[torch.dtype] = None,
+ overlap_communication: bool = False,
+ partition_grad: bool = False, # stage 2 flag
+ cpu_offload: bool = False, # cpu offload
+ forced_dtype: Optional[torch.dtype] = None):
# TODO: add support for
# 1. fp16 master weights
@@ -76,21 +63,32 @@ def __init__(
# stage 2
self._partition_grads = partition_grad
- # cpu_offload
self._cpu_offload = cpu_offload
- # get process groups
- self._dp_parallel_mode = dp_parallel_mode
- self._mp_parallel_mode = mp_parallel_mode
- self._local_rank = gpc.get_local_rank(dp_parallel_mode)
- self._world_size = gpc.get_world_size(dp_parallel_mode)
-
- self._dp_group = gpc.get_group(dp_parallel_mode)
- if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
- self._mp_group = gpc.get_group(mp_parallel_mode)
+ colo_pg = self._search_colo_process_group()
+ if isinstance(colo_pg, ProcessGroup):
+ self._local_rank = colo_pg.dp_local_rank()
+ self._world_size = colo_pg.dp_world_size()
+ self._dp_global_ranks = colo_pg.get_ranks_in_dp()
+ self._dp_torch_group = colo_pg.dp_process_group()
+ self._mp_torch_group = None
+ if colo_pg.tp_world_size() > 1:
+ self._mp_torch_group = colo_pg.tp_process_group()
+ elif colo_pg is None:
+ dp_parallel_mode = ParallelMode.DATA
+ mp_parallel_mode = ParallelMode.MODEL
+
+ self._dp_parallel_mode = dp_parallel_mode
+ self._mp_parallel_mode = mp_parallel_mode
+ self._local_rank = gpc.get_local_rank(dp_parallel_mode)
+ self._world_size = gpc.get_world_size(dp_parallel_mode)
+ self._dp_global_ranks = gpc.get_ranks_in_group(dp_parallel_mode)
+ self._dp_torch_group = gpc.get_group(dp_parallel_mode)
+ self._mp_torch_group = None
+ if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
+ self._mp_torch_group = gpc.get_group(mp_parallel_mode)
else:
- self._mp_group = None
-
+ raise NotImplementedError
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
self._fp32_flat_param_groups_of_current_rank = dict()
@@ -126,15 +124,18 @@ def __init__(
# ParameterStore will manage the tensor buffers used for zero
# it will not manage the tensors used by mixed precision training
- self._param_store = ParameterStore(self._dp_parallel_mode)
- self._grad_store = GradientStore(self._dp_parallel_mode)
- self._bucket_store = BucketStore(self._dp_parallel_mode)
+ self._param_store = ParameterStore(self._dp_torch_group)
+ self._grad_store = GradientStore(self._dp_torch_group)
+ self._bucket_store = BucketStore(self._dp_torch_group)
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
# and add buffers to parameter store for future access
for group_id, param_group in enumerate(self.optim.param_groups):
- group_params = param_group['params']
+ group_params = list()
+ for param in param_group['params']:
+ if param.requires_grad:
+ group_params.append(param)
# add the fp16 params to fp16_param_groups for bookkeeping
self._fp16_param_groups[group_id] = group_params
@@ -209,6 +210,30 @@ def loss_scale(self):
def num_param_groups(self):
return len(self._fp16_param_groups)
+ def _sanity_checks(self):
+ assert torch.cuda.is_available(), 'CUDA is required'
+ for param_group in self.optim.param_groups:
+ group_params = param_group['params']
+ for param in group_params:
+ assert param.dtype == self._dtype, \
+ f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+
+ def _search_colo_process_group(self):
+ colo_flag = False
+ colo_pg = None
+ for param_group in self.optim.param_groups:
+ group_params = param_group['params']
+ for param in group_params:
+ if isinstance(param, ColoParameter):
+ colo_flag = True
+ if colo_pg is None:
+ colo_pg = param.get_process_group()
+ else:
+ assert colo_pg == param.get_process_group(), "All parameters should be in a same process group"
+ elif colo_flag:
+ raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.")
+ return colo_pg
+
def _partition_param_list(self, param_list):
params_per_rank = [[] for _ in range(self._world_size)]
numel_per_rank = [0 for _ in range(self._world_size)]
@@ -223,22 +248,16 @@ def _partition_param_list(self, param_list):
numel_per_rank[rank_to_go] += param.numel()
if self._verbose:
- self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
- ranks=[0],
- parallel_mode=self._dp_parallel_mode)
+ self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0])
return params_per_rank
- def _sanity_checks(self):
- assert torch.cuda.is_available(), 'CUDA is required'
- for param_group in self.optim.param_groups:
- group_params = param_group['params']
- for param in group_params:
- assert param.dtype == self._dtype, \
- f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`"
+ ###########################
+ # Backward Reduction Hook #
+ ###########################
- ###########################################################
- # Backward Reduction Hook
- ###########################################################
+ def _grad_handler(self, param, grad, reduce_rank):
+ self._add_to_reduction_bucket(param, reduce_rank)
+ return grad
def _attach_reduction_hook(self):
# we iterate over the fp16 params
@@ -256,53 +275,61 @@ def _attach_reduction_hook(self):
else:
reduce_rank = None
- def _define_and_attach(param, reduce_rank):
- # get the AccumulateGrad object of the param itself
- accum_grad_obj = get_grad_accumulate_object(param)
- self._grad_store.add_accumulate_grad_object(accum_grad_obj)
+ param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank))
- reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
- param=param,
- reduce_rank=reduce_rank)
+ def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank):
+ if self._overlap_communication:
+ torch.cuda.synchronize()
+ self._param_store.clear_grads_of_previous_reduced_params()
+ stream = self._comm_stream
+ else:
+ stream = torch.cuda.current_stream()
- # define hook
- # NOT IMPORTANT BUT GOOD TO KNOW:
- # args here is not grad, but allow_unreacable and accumulate_grad
- def reduce_grad_hook(*args):
- reduction_func()
+ with torch.cuda.stream(stream):
+ flat = bucket.flatten()
+ reduce_global_rank = None
+ if reduce_rank is not None:
+ reduce_global_rank = self._dp_global_ranks[reduce_rank]
+ reduced_flat = reduce_tensor_dp_group(tensor=flat,
+ dtype=self._communication_dtype,
+ dst_local_rank=reduce_rank,
+ dst_global_rank=reduce_global_rank,
+ group=self._dp_torch_group)
- accum_grad_obj.register_hook(reduce_grad_hook)
+ # update the reduced tensor
+ if reduce_rank is None or reduce_rank == self._local_rank:
+ bucket.unflatten_and_copy(reduced_flat)
- _define_and_attach(param, reduce_rank)
+ def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank):
+ param_bucket = TensorBucket(size=bucket_size)
- def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
- param_size = param.numel()
+ for tensor in tensor_list:
+ param_bucket.add_to_bucket(tensor, allow_oversize=True)
- # check if the bucket is full
- # if full, will reduce the grads already in the bucket
- # after reduction, the bucket will be empty
- if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
- self._reduce_grads_in_bucket(reduce_rank)
+ if param_bucket.is_full_or_oversized():
+ self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
+ param_bucket.empty()
- # the param must not be reduced to ensure correctness
- is_param_reduced = self._param_store.is_param_reduced(param)
- if is_param_reduced:
- msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
- + 'duplicate reduction will lead to arithmetic incorrectness'
- raise RuntimeError(msg)
+ if not param_bucket.is_empty():
+ self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank)
- # the param must have grad for reduction
- assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'
+ def _reduce_grads(self, reduce_rank, grads, bucket_size):
+ grad_buckets_by_dtype = split_half_float_double(grads)
- self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
- self._bucket_store.add_grad(param.grad, reduce_rank)
- self._bucket_store.add_param(param, reduce_rank)
+ for tensor_list in grad_buckets_by_dtype:
+ self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list,
+ bucket_size=bucket_size,
+ reduce_rank=reduce_rank)
+
+ #######################
+ # Reduction Functions #
+ #######################
- def _reduce_grads_in_bucket(self, reduce_rank=None):
+ def _run_reduction(self, reduce_rank=None):
# reduce grads
- self._reduce_grads_by_rank(reduce_rank=reduce_rank,
- grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
- bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
+ self._reduce_grads(reduce_rank=reduce_rank,
+ grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
+ bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))
# use communication stream if overlapping
# communication with computation
@@ -339,52 +366,30 @@ def _reduce_grads_in_bucket(self, reduce_rank=None):
self._bucket_store.reset_by_rank(reduce_rank)
- def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
- grad_buckets_by_dtype = split_half_float_double(grads)
-
- for tensor_list in grad_buckets_by_dtype:
- self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)
-
- ##############################
- # Reduction Utility Function #
- ##############################
- def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
- param_bucket = TensorBucket(size=bucket_size)
-
- for tensor in tensor_list:
- param_bucket.add_to_bucket(tensor, allow_oversize=True)
-
- if param_bucket.is_full_or_oversized():
- self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
- param_bucket.empty()
-
- if not param_bucket.is_empty():
- self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+ def _add_to_reduction_bucket(self, param, reduce_rank=None):
+ param_size = param.numel()
- def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
- if self._overlap_communication:
- torch.cuda.synchronize()
- self._param_store.clear_grads_of_previous_reduced_params()
- stream = self._comm_stream
- else:
- stream = torch.cuda.current_stream()
+ # check if the bucket is full
+ # if full, will reduce the grads already in the bucket
+ # after reduction, the bucket will be empty
+ if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
+ self._run_reduction(reduce_rank)
- with torch.cuda.stream(stream):
- flat = bucket.flatten()
- reduced_flat = reduce_tensor(tensor=flat,
- dtype=self._communication_dtype,
- dst_rank=reduce_rank,
- parallel_mode=self._dp_parallel_mode)
+ # the param must not be reduced to ensure correctness
+ is_param_reduced = self._param_store.is_param_reduced(param)
+ if is_param_reduced:
+ msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
+ + 'duplicate reduction will lead to arithmetic incorrectness'
+ raise RuntimeError(msg)
- # update the reduced tensor
- if reduce_rank is None or reduce_rank == self._local_rank:
- bucket.unflatten_and_copy(reduced_flat)
+ self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
+ self._bucket_store.add_param(param, reduce_rank)
################################
# torch.optim.Optimizer methods
################################
- def backward(self, loss, retain_graph=False):
+ def backward(self, loss, retain_graph=False, sync_grad=True):
loss = self.loss_scale * loss
loss.backward(retain_graph=retain_graph)
@@ -400,6 +405,10 @@ def backward(self, loss, retain_graph=False):
torch.cuda.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
+ # gradient synchronization
+ if sync_grad:
+ self._sync_grad()
+
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@@ -408,7 +417,7 @@ def zero_grad(self, set_to_none=True):
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
- for group_id, param_group in self._fp16_param_groups.items():
+ for _, param_group in self._fp16_param_groups.items():
for param in param_group:
if set_to_none:
param.grad = None
@@ -430,7 +439,7 @@ def step(self, closure=None):
# update loss scale if overflow occurs
if found_inf:
- self._grad_store._averaged_gradients = dict()
+ self._grad_store.reset_all_average_gradients()
self.zero_grad()
return
@@ -440,11 +449,11 @@ def step(self, closure=None):
for group_id in range(self.num_param_groups):
# compute norm
- norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
+ norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id),
params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
rank=self._local_rank),
- dp_group=self._dp_group,
- mp_group=self._mp_group)
+ dp_group=self._dp_torch_group,
+ mp_group=self._mp_torch_group)
norm_groups.append(norm_group)
# create flat gradient for the flat fp32 params
@@ -461,8 +470,7 @@ def step(self, closure=None):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
- self._grad_store._averaged_gradients[group_id] = []
- self._grad_store._averaged_gradients[group_id] = []
+ self._grad_store.reset_average_gradients_by_group(group_id)
# unscale and clip grads
global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
@@ -482,9 +490,10 @@ def step(self, closure=None):
# broadcast the updated model weights
handles = []
for group_id in range(self.num_param_groups):
- for rank in range(self._world_size):
- fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
- handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
+ for index in range(self._world_size):
+ rank = self._dp_global_ranks[index]
+ fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=index, group_id=group_id)
+ handle = dist.broadcast(fp16_param, src=rank, group=self._dp_torch_group, async_op=True)
handles.append(handle)
for handle in handles:
@@ -506,11 +515,11 @@ def _check_overflow(self):
break
# all-reduce across dp group
- dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)
+ dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
# all-reduce over model parallel group
- if self._mp_group:
- dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)
+ if self._mp_torch_group:
+ dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
if self._found_overflow.item() > 0:
return True
@@ -534,27 +543,25 @@ def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# Gradient Synchronization #
############################
- def sync_grad(self):
+ def _sync_grad(self):
# update param already reduced flag
reduction_states = self._param_store.get_param_reduction_states()
- for tensor, state in reduction_states.items():
+ for tensor, _ in reduction_states.items():
reduction_states[tensor] = False
# accumulate gradient
- avg_gradients = self._grad_store._averaged_gradients
for group_id in range(self.num_param_groups):
param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
- if group_id not in avg_gradients:
- avg_gradients[group_id] = []
+ avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id)
param_idx = 0
for param in param_group:
if param.grad is not None:
- if len(avg_gradients[group_id]) == param_idx:
- avg_gradients[group_id].append(param.grad)
+ if len(avg_gradients_group) == param_idx:
+ self._grad_store.append_average_gradient_by_group(group_id, param.grad)
else:
- avg_gradients[group_id][param_idx].add_(param.grad)
+ self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad)
param_idx += 1
# the gradients needed are stored in the avg_gradients buffer
@@ -569,11 +576,11 @@ def _reduce_grad_stage1(self):
param_group = self._fp16_param_groups[group_id]
for param in param_group:
if param.grad is not None:
- self._reduce_and_remove_grads_by_bucket(param)
+ self._add_to_reduction_bucket(param)
# we need to reduce the gradients
# left in the communication bucket
- self._reduce_grads_in_bucket()
+ self._run_reduction()
def _reduce_grad_stage2(self):
# when partition_grads is True, reduction hooks
@@ -581,4 +588,4 @@ def _reduce_grad_stage2(self):
# only need to reduce the gradients
# left in the communication bucket
for reduce_rank in range(self._world_size):
- self._reduce_grads_in_bucket(reduce_rank)
+ self._run_reduction(reduce_rank)
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 401ff988df4a..43a0b7d76107 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -1,3 +1,4 @@
+# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from enum import Enum
from os import stat
from typing import Dict, Optional, Tuple
@@ -5,20 +6,21 @@
import torch
import torch.distributed as dist
import torch.nn as nn
+from torch import Tensor
+from torch.distributed import ProcessGroup
+from torch.nn.parameter import Parameter
+from torch.optim import Optimizer
+
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
+from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
+from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
-from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
-from torch import Tensor
-from torch.distributed import ProcessGroup
-from torch.nn.parameter import Parameter
-from torch.optim import Optimizer
-from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
-from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
class OptimState(Enum):
@@ -36,9 +38,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
- which is detected by a runtime memory tracer.
+ which is detected by a runtime memory tracer.
- We place as many OS chunks in the margin space as possible.
+ We place as many OS chunks in the margin space as possible.
The size of margin space can be controlled by ``gpu_margin_mem_ratio``.
If it is set as ``0.0``, it is the same as classical ZeRO optimizer.
@@ -54,8 +56,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
optimizer (Optimizer): An Optimizer instance.
- gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
- which will be used when using hybrid CPU optimizer.
+ gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
+ which will be used when using hybrid CPU optimizer.
This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/utils/gemini_hook.py
index 35569c7172b3..bddc307a0504 100644
--- a/colossalai/zero/utils/gemini_hook.py
+++ b/colossalai/zero/utils/gemini_hook.py
@@ -8,6 +8,7 @@
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.utils import is_ddp_ignored
class TrainingPhase(Enum):
@@ -24,7 +25,7 @@ def __init__(self, gemini_manager: GeminiManager) -> None:
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
- params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+ params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -37,7 +38,7 @@ def pre_op(self, params):
self._gemini_manager.record_model_data_volume()
def post_op(self, params):
- params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+ params = [p for p in params if not is_ddp_ignored(p)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index bcb7c0fffbb3..49ff9b344268 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,17 +1,23 @@
FROM hpcaitech/cuda-conda:11.3
+# metainformation
+LABEL org.opencontainers.image.source = "https://github.com/hpcaitech/ColossalAI"
+LABEL org.opencontainers.image.licenses = "Apache License 2.0"
+LABEL org.opencontainers.image.base.name = "docker.io/library/hpcaitech/cuda-conda:11.3"
+
# install torch
-RUN conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
+RUN conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
# install apex
RUN git clone https://github.com/NVIDIA/apex && \
cd apex && \
+ pip install packaging && \
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" ./
# install colossalai
RUN git clone https://github.com/hpcaitech/ColossalAI.git \
&& cd ./ColossalAI \
- && pip install -v --no-cache-dir .
+ && CUDA_EXT=1 pip install -v --no-cache-dir .
# install titans
RUN pip install --no-cache-dir titans
@@ -21,4 +27,4 @@ RUN conda install cmake && \
git clone https://github.com/hpcaitech/TensorNVMe.git && \
cd TensorNVMe && \
pip install -r requirements.txt && \
- pip install -v --no-cache-dir .
\ No newline at end of file
+ pip install -v --no-cache-dir .
diff --git a/docs/Makefile b/docs/Makefile
deleted file mode 100644
index 9f43a48d6420..000000000000
--- a/docs/Makefile
+++ /dev/null
@@ -1,26 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line, and also
-# from the environment for the first two.
-SPHINXOPTS ?=
-SPHINXBUILD ?= sphinx-build
-SOURCEDIR = .
-BUILDDIR = .build
-SPHINXAPIDOC ?= sphinx-apidoc
-SPHINX_APIDOC_OPTIONS = members
-SPHINX_APIDOC_TEMPLATEDIR = _templates/apidoc
-
-# Put it first so that "make" without argument is like "make help".
-help:
- @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile apidoc
-
-apidoc:
- @SPHINX_APIDOC_OPTIONS=$(SPHINX_APIDOC_OPTIONS) $(SPHINXAPIDOC) -f -T -e -M -d 2 -t $(SPHINX_APIDOC_TEMPLATEDIR) -o ./colossalai ../colossalai
-# @$(SPHINXAPIDOC) -f -o ./model_zoo ../model_zoo
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
- @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/README-zh-Hans.md b/docs/README-zh-Hans.md
similarity index 68%
rename from README-zh-Hans.md
rename to docs/README-zh-Hans.md
index 8edcff28bf04..81c45abfd833 100644
--- a/README-zh-Hans.md
+++ b/docs/README-zh-Hans.md
@@ -3,15 +3,16 @@
[](https://www.colossalai.org/)
- Colossal-AI: 一个面向大模型时代的通用深度学习系统
+ Colossal-AI: 让AI大模型更低成本、方便易用、高效扩展
-
## 新闻
-* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0)
+* [2023/03] [AWS and Google Fund Colossal-AI with Startup Cloud Programs](https://www.hpc-ai.tech/blog/aws-and-google-fund-colossal-ai-with-startup-cloud-programs)
+* [2023/02] [Open source solution replicates ChatGPT training process! Ready to go with only 1.6GB GPU memory](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt)
+* [2023/01] [Hardware Savings Up to 46 Times for AIGC and Automatic Parallelism](https://medium.com/pytorch/latest-colossal-ai-boasts-novel-automatic-parallelism-and-offers-savings-up-to-46x-for-stable-1453b48f3f02)
* [2022/11] [Diffusion Pretraining and Hardware Fine-Tuning Can Be Almost 7X Cheaper](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper)
* [2022/10] [Use a Laptop to Analyze 90% of Proteins, With a Single-GPU Inference Sequence Exceeding 10,000](https://www.hpc-ai.tech/blog/use-a-laptop-to-analyze-90-of-proteins-with-a-single-gpu-inference-sequence-exceeding)
-* [2022/10] [Embedding Training With 1% GPU Memory and 100 Times Less Budget for Super-Large Recommendation Model](https://www.hpc-ai.tech/blog/embedding-training-with-1-gpu-memory-and-10-times-less-budget-an-open-source-solution-for)
* [2022/09] [HPC-AI Tech Completes $6 Million Seed and Angel Round Fundraising](https://www.hpc-ai.tech/blog/hpc-ai-tech-completes-6-million-seed-and-angel-round-fundraising-led-by-bluerun-ventures-in-the)
@@ -35,7 +37,7 @@
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 000000000000..f520608d552c
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,112 @@
+# 📕 Documentation
+
+## 🔗 Table of Contents
+
+- [📕 Documentation](#-documentation)
+ - [🔗 Table of Contents](#-table-of-contents)
+ - [📝 Overview](#-overview)
+ - [🗺 Module Structure](#-module-structure)
+ - [🧱 Our Documentation System](#-our-documentation-system)
+ - [🎊 Contribution](#-contribution)
+ - [🖊 Adding a New Documentation](#-adding-a-new-documentation)
+ - [🧹 Doc Testing](#-doc-testing)
+ - [💉 Auto Documentation](#-auto-documentation)
+
+## 📝 Overview
+
+We evaluated various existing solutions for documentation in the community and discussed their advantages and disadvantages in the [issue #2651](https://github.com/hpcaitech/ColossalAI/issues/2651). Therefore, we propose to build a more modern and robust documentation system by integrating the Sphinx [autodoc](https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html) function and the [Docusaurus](https://docusaurus.io/) framework.
+
+## 🗺 Module Structure
+
+```text
+- docs
+ - source
+ - en
+ - zh-Hans
+ - sidebars.json
+ - versions.json
+ - requirements-doc-test.txt
+```
+
+The documentation module structure is shown above:
+1. source: This folder contains multi-language documentation files.
+2. `sidebars.json`: The `sidebars.json` defines the table of content for the tutorials. You need to update this file when a new doc is added/deleted.
+3. `versions.json`: The `versions.json` in the **main branch** in the **latest commit** will be used to control the versions to be displayed on our website
+
+## 🧱 Our Documentation System
+
+We believe that the combination of the existing systems can yield several advantages such as simplicity, usability and maintainability:
+1. Support [Markdown](https://www.markdownguide.org/). We believe is a more popular language for writing documentation compared to [RST](https://docutils.sourceforge.io/rst.html).
+2. Support Autodoc. It can automatically generate documentation from the docstrings in the source code provided by [Sphinx](https://www.sphinx-doc.org/en/master/).
+3. Support elegant and modern UI, which is provided by [Docusaurus](https://docusaurus.io/).
+4. Support MDX for more flexible and powerful documentation, which is provided by [Docusaurus](https://docusaurus.io/).
+5. Support hosting blogs/project home page/other pages besides the documentation, which is provided by [Docusaurus](https://docusaurus.io/).
+
+Therefore, we have built the [ColossalAI-Documentation](https://github.com/hpcaitech/ColossalAI-Documentation) repository to integrate the features above.
+
+## 🎊 Contribution
+
+You can contribute to the documentation by directly setting up a Pull Request towards the `docs/source` folder. There are several guidelines for documentation contribution.
+
+1. The documentation is written in Markdown. You can refer to the [Markdown Guide](https://www.markdownguide.org/) for the syntax.
+2. You must ensure that the documentation exists for all languages. You can refer to the [Adding a New Documentation](#-adding-a-new-documentation) for more details.
+3. You must provide a test command for your documentation, please see [Doc Testing](#-doc-testing) for more details.
+4. You can embed your docstring in your markdown, please see [Auto Documentation](#-auto-documentation) for more details.
+
+### 🖊 Adding a New Documentation
+
+You can add a Markdown file to the `docs/source` folder`. You need to ensure that multi-language is supported in your PR.
+Let's assume that you want to add a file called `your_doc.md`, your file structure will look like this.
+
+```text
+- docs
+ - source
+ - en
+ - your_doc.md # written in English
+ - zh-Hans
+ - your_doc.md # written in Chinese
+ - sidebars.json # add your documentation file name here
+```
+
+Meanwhile, you need to ensure the `sidebars.json` is updated such that it contains your documentation file. Our CI will check whether documentation exists for all languages and can be used to build the website successfully.
+
+### 🧹 Doc Testing
+
+Every documentation is tested to ensure it works well. You need to add the following line to the **bottom of your file** and replace `$command` with the actual command. Do note that the markdown will be converted into a Python file. Assuming you have a `demo.md` file, the test file generated will be `demo.py`. Therefore, you should use `demo.py` in your command, e.g. `python demo.py`.
+
+```markdown
+
+```
+
+Meanwhile, only code labeled as a Python code block will be considered for testing.
+
+```markdown
+ ```python
+ print("hello world")
+ ```
+```
+
+Lastly, if you want to skip some code, you just need to add the following annotations to tell `docer` to discard the wrapped code for testing.
+
+```markdown
+
+
+ ```python
+ print("hello world")
+ ```
+
+
+```
+
+If you have any dependency required, please add it to `requriements-doc-test.txt` for pip and `conda-doc-test-deps.yml` for Conda.
+
+
+### 💉 Auto Documentation
+
+Lastly, you may want to include the API documentation for a class/function in your documentation for reference.
+We support `autodoc` to extract the docstring and transform it into a Web element for an elegant display.
+You just need to add `{{ autodoc: }}` in your markdown as a single line. An example is given below and you can see the outcome in [this PR](https://github.com/hpcaitech/ColossalAI-Documentation/pull/175).
+
+```markdown
+{{ autodoc:colossalai.amp.apex_amp.convert_to_apex_amp }}
+```
diff --git a/docs/REFERENCE.md b/docs/REFERENCE.md
new file mode 100644
index 000000000000..2681198191cb
--- /dev/null
+++ b/docs/REFERENCE.md
@@ -0,0 +1,38 @@
+# References
+
+The Colossal-AI project aims to provide a wide array of parallelism techniques for the machine learning community in the big-model era. This project is inspired by quite a few reserach works, some are conducted by some of our developers and the others are research projects open-sourced by other organizations. We would like to credit these amazing projects below in the IEEE citation format.
+
+## By Our Team
+
+- Q. Xu, S. Li, C. Gong, and Y. You, ‘An Efficient 2D Method for Training Super-Large Deep Learning Models’. arXiv, 2021.
+
+- Z. Bian, Q. Xu, B. Wang, and Y. You, ‘Maximizing Parallelism in Distributed Training for Huge Neural Networks’. arXiv, 2021.
+
+- S. Li, F. Xue, C. Baranwal, Y. Li, and Y. You, ‘Sequence Parallelism: Long Sequence Training from System Perspective’. arXiv, 2021.
+
+- S. Li et al., ‘Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training’. arXiv, 2021.
+
+- B. Wang, Q. Xu, Z. Bian, and Y. You, ‘Tesseract: Parallelize the Tensor Parallelism Efficiently’, in Proceedings of the 51th International Conference on Parallel Processing, 2022.
+
+- J. Fang et al., ‘A Frequency-aware Software Cache for Large Recommendation System Embeddings’. arXiv, 2022.
+
+- J. Fang et al., ‘Parallel Training of Pre-Trained Models via Chunk-Based Dynamic Memory Management’, IEEE Transactions on Parallel and Distributed Systems, vol. 34, no. 1, pp. 304–315, 2023.
+
+- Y. Liu, S. Li, J. Fang, Y. Shao, B. Yao, and Y. You, ‘Colossal-Auto: Unified Automation of Parallelization and Activation Checkpoint for Large-scale Models’. arXiv, 2023.
+
+
+## By Other Organizations
+
+- M. Shoeybi, M. Patwary, R. Puri, P. LeGresley, J. Casper, and B. Catanzaro, ‘Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism’. arXiv, 2019.
+
+- S. Rajbhandari, J. Rasley, O. Ruwase, and Y. He, ‘ZeRO: Memory Optimizations toward Training Trillion Parameter Models’, in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, 2020.
+
+- J. Rasley, S. Rajbhandari, O. Ruwase, and Y. He, ‘DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters’, in Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, Virtual Event, CA, USA, 2020, pp. 3505–3506.
+
+- D. Narayanan et al., ‘Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM’, in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, St. Louis, Missouri, 2021.
+
+- Jie Ren, Samyam Rajbhandari, Reza Yazdani Aminabadi, Olatunji Ruwase, Shuangyan Yang, Minjia Zhang, Dong Li, Yuxiong He. 2021. ZeRO-Offload: Democratizing Billion-Scale Model Training. arXiv:2101.06840 and USENIX ATC 2021.
+
+- S. Rajbhandari, O. Ruwase, J. Rasley, S. Smith, and Y. He, ‘ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning’. in Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis, St. Louis, Missouri, 2021.
+
+- L. Zheng et al., ‘Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning’, in 16th USENIX Symposium on Operating Systems Design and Implementation (OSDI 22), 2022, pp. 559–578.
diff --git a/docs/_static/css/rtd_theme.css b/docs/_static/css/rtd_theme.css
deleted file mode 100644
index caf42dc5aaab..000000000000
--- a/docs/_static/css/rtd_theme.css
+++ /dev/null
@@ -1,3 +0,0 @@
-.wy-nav-content {
- max-width: 80%;
-}
\ No newline at end of file
diff --git a/docs/_templates/apidoc/module.rst_t b/docs/_templates/apidoc/module.rst_t
deleted file mode 100644
index d9a50e6b9752..000000000000
--- a/docs/_templates/apidoc/module.rst_t
+++ /dev/null
@@ -1,9 +0,0 @@
-{%- if show_headings %}
-{{- basename | e | heading }}
-
-{% endif -%}
-.. automodule:: {{ qualname }}
-{%- for option in automodule_options %}
- :{{ option }}:
-{%- endfor %}
-
diff --git a/docs/_templates/apidoc/package.rst_t b/docs/_templates/apidoc/package.rst_t
deleted file mode 100644
index 83742b3f7c66..000000000000
--- a/docs/_templates/apidoc/package.rst_t
+++ /dev/null
@@ -1,52 +0,0 @@
-{%- macro automodule(modname, options) -%}
-.. automodule:: {{ modname }}
-{%- for option in options %}
- :{{ option }}:
-{%- endfor %}
-{%- endmacro %}
-
-{%- macro toctree(docnames) -%}
-.. toctree::
- :maxdepth: {{ maxdepth }}
-{% for docname in docnames %}
- {{ docname }}
-{%- endfor %}
-{%- endmacro %}
-
-{%- if is_namespace %}
-{{- pkgname | e | heading }}
-{% else %}
-{{- pkgname | e | heading }}
-{% endif %}
-
-{%- if is_namespace %}
-.. py:module:: {{ pkgname }}
-{% endif %}
-
-{%- if modulefirst and not is_namespace %}
-{{ automodule(pkgname, automodule_options) }}
-{% endif %}
-
-{%- if subpackages %}
-{{ toctree(subpackages) }}
-{% endif %}
-
-{%- if submodules %}
-{% if separatemodules %}
-{{ toctree(submodules) }}
-{% else %}
-{%- for submodule in submodules %}
-{% if show_headings %}
-{{- submodule | e | heading(2) }}
-{% endif %}
-{{ automodule(submodule, automodule_options) }}
-{% endfor %}
-{%- endif %}
-{%- endif %}
-
-{%- if not modulefirst and not is_namespace %}
-Module contents
----------------
-
-{{ automodule(pkgname, automodule_options) }}
-{% endif %}
diff --git a/docs/_templates/apidoc/toc.rst_t b/docs/_templates/apidoc/toc.rst_t
deleted file mode 100644
index f0877eeb2f85..000000000000
--- a/docs/_templates/apidoc/toc.rst_t
+++ /dev/null
@@ -1,8 +0,0 @@
-{{ header | heading }}
-
-.. toctree::
- :maxdepth: {{ maxdepth }}
-{% for docname in docnames %}
- {{ docname }}
-{%- endfor %}
-
diff --git a/docs/colossalai/colossalai.amp.amp_type.rst b/docs/colossalai/colossalai.amp.amp_type.rst
deleted file mode 100644
index 067af7d8c51a..000000000000
--- a/docs/colossalai/colossalai.amp.amp_type.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.amp.amp\_type
-========================
-
-.. automodule:: colossalai.amp.amp_type
- :members:
diff --git a/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst b/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst
deleted file mode 100644
index cba7e00625a4..000000000000
--- a/docs/colossalai/colossalai.amp.apex_amp.apex_amp.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.amp.apex\_amp.apex\_amp
-==================================
-
-.. automodule:: colossalai.amp.apex_amp.apex_amp
- :members:
diff --git a/docs/colossalai/colossalai.amp.apex_amp.rst b/docs/colossalai/colossalai.amp.apex_amp.rst
deleted file mode 100644
index 7116a538b4c1..000000000000
--- a/docs/colossalai/colossalai.amp.apex_amp.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.amp.apex\_amp
-========================
-
-.. automodule:: colossalai.amp.apex_amp
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.apex_amp.apex_amp
diff --git a/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst b/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst
deleted file mode 100644
index 12d477825659..000000000000
--- a/docs/colossalai/colossalai.amp.naive_amp.grad_scaler.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-colossalai.amp.naive\_amp.grad\_scaler
-======================================
-
-.. automodule:: colossalai.amp.naive_amp.grad_scaler
- :members:
-
-
-
diff --git a/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst b/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst
deleted file mode 100644
index e20f22b2e386..000000000000
--- a/docs/colossalai/colossalai.amp.naive_amp.naive_amp.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.amp.naive\_amp.naive\_amp
-====================================
-
-.. automodule:: colossalai.amp.naive_amp.naive_amp
- :members:
diff --git a/docs/colossalai/colossalai.amp.naive_amp.rst b/docs/colossalai/colossalai.amp.naive_amp.rst
deleted file mode 100644
index fd364c05331c..000000000000
--- a/docs/colossalai/colossalai.amp.naive_amp.rst
+++ /dev/null
@@ -1,16 +0,0 @@
-colossalai.amp.naive\_amp
-=========================
-
-.. automodule:: colossalai.amp.naive_amp
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.naive_amp.grad_scaler
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.naive_amp.naive_amp
diff --git a/docs/colossalai/colossalai.amp.rst b/docs/colossalai/colossalai.amp.rst
deleted file mode 100644
index 5ef4f36c13ac..000000000000
--- a/docs/colossalai/colossalai.amp.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-colossalai.amp
-==============
-
-.. automodule:: colossalai.amp
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.apex_amp
- colossalai.amp.naive_amp
- colossalai.amp.torch_amp
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.amp_type
diff --git a/docs/colossalai/colossalai.amp.torch_amp.rst b/docs/colossalai/colossalai.amp.torch_amp.rst
deleted file mode 100644
index f10095f136e0..000000000000
--- a/docs/colossalai/colossalai.amp.torch_amp.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.amp.torch\_amp
-=========================
-
-.. automodule:: colossalai.amp.torch_amp
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp.torch_amp.torch_amp
diff --git a/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst b/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst
deleted file mode 100644
index 5f1549cb8d48..000000000000
--- a/docs/colossalai/colossalai.amp.torch_amp.torch_amp.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.amp.torch\_amp.torch\_amp
-====================================
-
-.. automodule:: colossalai.amp.torch_amp.torch_amp
- :members:
diff --git a/docs/colossalai/colossalai.builder.builder.rst b/docs/colossalai/colossalai.builder.builder.rst
deleted file mode 100644
index 85da78ab9e3d..000000000000
--- a/docs/colossalai/colossalai.builder.builder.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.builder.builder
-==========================
-
-.. automodule:: colossalai.builder.builder
- :members:
diff --git a/docs/colossalai/colossalai.builder.rst b/docs/colossalai/colossalai.builder.rst
deleted file mode 100644
index 61163d7c1ea1..000000000000
--- a/docs/colossalai/colossalai.builder.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.builder
-==================
-
-.. automodule:: colossalai.builder
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.builder.builder
diff --git a/docs/colossalai/colossalai.cli.benchmark.benchmark.rst b/docs/colossalai/colossalai.cli.benchmark.benchmark.rst
deleted file mode 100644
index 94a4170c8590..000000000000
--- a/docs/colossalai/colossalai.cli.benchmark.benchmark.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.benchmark.benchmark
-==================================
-
-.. automodule:: colossalai.cli.benchmark.benchmark
- :members:
diff --git a/docs/colossalai/colossalai.cli.benchmark.models.rst b/docs/colossalai/colossalai.cli.benchmark.models.rst
deleted file mode 100644
index 4e6290288d59..000000000000
--- a/docs/colossalai/colossalai.cli.benchmark.models.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.benchmark.models
-===============================
-
-.. automodule:: colossalai.cli.benchmark.models
- :members:
diff --git a/docs/colossalai/colossalai.cli.benchmark.rst b/docs/colossalai/colossalai.cli.benchmark.rst
deleted file mode 100644
index 80fb43dde04b..000000000000
--- a/docs/colossalai/colossalai.cli.benchmark.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.cli.benchmark
-========================
-
-.. automodule:: colossalai.cli.benchmark
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.cli.benchmark.benchmark
- colossalai.cli.benchmark.models
- colossalai.cli.benchmark.utils
diff --git a/docs/colossalai/colossalai.cli.benchmark.utils.rst b/docs/colossalai/colossalai.cli.benchmark.utils.rst
deleted file mode 100644
index 12fbaf2270ec..000000000000
--- a/docs/colossalai/colossalai.cli.benchmark.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.benchmark.utils
-==============================
-
-.. automodule:: colossalai.cli.benchmark.utils
- :members:
diff --git a/docs/colossalai/colossalai.cli.check.check_installation.rst b/docs/colossalai/colossalai.cli.check.check_installation.rst
deleted file mode 100644
index 95b2d02ca371..000000000000
--- a/docs/colossalai/colossalai.cli.check.check_installation.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.check.check\_installation
-========================================
-
-.. automodule:: colossalai.cli.check.check_installation
- :members:
diff --git a/docs/colossalai/colossalai.cli.check.rst b/docs/colossalai/colossalai.cli.check.rst
deleted file mode 100644
index 262ae7ad31ba..000000000000
--- a/docs/colossalai/colossalai.cli.check.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.cli.check
-====================
-
-.. automodule:: colossalai.cli.check
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.cli.check.check_installation
diff --git a/docs/colossalai/colossalai.cli.cli.rst b/docs/colossalai/colossalai.cli.cli.rst
deleted file mode 100644
index 8f83973d5e0c..000000000000
--- a/docs/colossalai/colossalai.cli.cli.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.cli
-==================
-
-.. automodule:: colossalai.cli.cli
- :members:
diff --git a/docs/colossalai/colossalai.cli.launcher.hostinfo.rst b/docs/colossalai/colossalai.cli.launcher.hostinfo.rst
deleted file mode 100644
index 5bcd9dd8cc4c..000000000000
--- a/docs/colossalai/colossalai.cli.launcher.hostinfo.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.launcher.hostinfo
-================================
-
-.. automodule:: colossalai.cli.launcher.hostinfo
- :members:
diff --git a/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst b/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst
deleted file mode 100644
index 223b0deac1f1..000000000000
--- a/docs/colossalai/colossalai.cli.launcher.multinode_runner.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.launcher.multinode\_runner
-=========================================
-
-.. automodule:: colossalai.cli.launcher.multinode_runner
- :members:
diff --git a/docs/colossalai/colossalai.cli.launcher.rst b/docs/colossalai/colossalai.cli.launcher.rst
deleted file mode 100644
index 38bef61c790d..000000000000
--- a/docs/colossalai/colossalai.cli.launcher.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.cli.launcher
-=======================
-
-.. automodule:: colossalai.cli.launcher
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.cli.launcher.hostinfo
- colossalai.cli.launcher.multinode_runner
- colossalai.cli.launcher.run
diff --git a/docs/colossalai/colossalai.cli.launcher.run.rst b/docs/colossalai/colossalai.cli.launcher.run.rst
deleted file mode 100644
index 8506fb9e3165..000000000000
--- a/docs/colossalai/colossalai.cli.launcher.run.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.cli.launcher.run
-===========================
-
-.. automodule:: colossalai.cli.launcher.run
- :members:
diff --git a/docs/colossalai/colossalai.cli.rst b/docs/colossalai/colossalai.cli.rst
deleted file mode 100644
index 8cc0dcb04aed..000000000000
--- a/docs/colossalai/colossalai.cli.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-colossalai.cli
-==============
-
-.. automodule:: colossalai.cli
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.cli.benchmark
- colossalai.cli.check
- colossalai.cli.launcher
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.cli.cli
diff --git a/docs/colossalai/colossalai.communication.collective.rst b/docs/colossalai/colossalai.communication.collective.rst
deleted file mode 100644
index 5015edf98901..000000000000
--- a/docs/colossalai/colossalai.communication.collective.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.communication.collective
-===================================
-
-.. automodule:: colossalai.communication.collective
- :members:
diff --git a/docs/colossalai/colossalai.communication.p2p.rst b/docs/colossalai/colossalai.communication.p2p.rst
deleted file mode 100644
index 79135bb8630f..000000000000
--- a/docs/colossalai/colossalai.communication.p2p.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.communication.p2p
-============================
-
-.. automodule:: colossalai.communication.p2p
- :members:
diff --git a/docs/colossalai/colossalai.communication.ring.rst b/docs/colossalai/colossalai.communication.ring.rst
deleted file mode 100644
index c218d4bed350..000000000000
--- a/docs/colossalai/colossalai.communication.ring.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.communication.ring
-=============================
-
-.. automodule:: colossalai.communication.ring
- :members:
diff --git a/docs/colossalai/colossalai.communication.rst b/docs/colossalai/colossalai.communication.rst
deleted file mode 100644
index 5086fa663ec7..000000000000
--- a/docs/colossalai/colossalai.communication.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-colossalai.communication
-========================
-
-.. automodule:: colossalai.communication
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.communication.collective
- colossalai.communication.p2p
- colossalai.communication.ring
- colossalai.communication.utils
diff --git a/docs/colossalai/colossalai.communication.utils.rst b/docs/colossalai/colossalai.communication.utils.rst
deleted file mode 100644
index 19a36cc9ff6f..000000000000
--- a/docs/colossalai/colossalai.communication.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.communication.utils
-==============================
-
-.. automodule:: colossalai.communication.utils
- :members:
diff --git a/docs/colossalai/colossalai.constants.rst b/docs/colossalai/colossalai.constants.rst
deleted file mode 100644
index 330b3e8668ec..000000000000
--- a/docs/colossalai/colossalai.constants.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.constants
-====================
-
-.. automodule:: colossalai.constants
- :members:
diff --git a/docs/colossalai/colossalai.context.config.rst b/docs/colossalai/colossalai.context.config.rst
deleted file mode 100644
index 2fb1b99d3e7a..000000000000
--- a/docs/colossalai/colossalai.context.config.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.config
-=========================
-
-.. automodule:: colossalai.context.config
- :members:
diff --git a/docs/colossalai/colossalai.context.moe_context.rst b/docs/colossalai/colossalai.context.moe_context.rst
deleted file mode 100644
index 9027d19ff023..000000000000
--- a/docs/colossalai/colossalai.context.moe_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.moe\_context
-===============================
-
-.. automodule:: colossalai.context.moe_context
- :members:
diff --git a/docs/colossalai/colossalai.context.parallel_context.rst b/docs/colossalai/colossalai.context.parallel_context.rst
deleted file mode 100644
index d1c82c517845..000000000000
--- a/docs/colossalai/colossalai.context.parallel_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.parallel\_context
-====================================
-
-.. automodule:: colossalai.context.parallel_context
- :members:
diff --git a/docs/colossalai/colossalai.context.parallel_mode.rst b/docs/colossalai/colossalai.context.parallel_mode.rst
deleted file mode 100644
index f7ac137493fb..000000000000
--- a/docs/colossalai/colossalai.context.parallel_mode.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.parallel\_mode
-=================================
-
-.. automodule:: colossalai.context.parallel_mode
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst
deleted file mode 100644
index 88cbf3ebadb3..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_1d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_1d
-==============================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_1d
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst
deleted file mode 100644
index d99a2e1c3177..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_2d
-==============================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_2d
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst
deleted file mode 100644
index 73d80e4431bb..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_2p5d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_2p5d
-================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_2p5d
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst
deleted file mode 100644
index 5cfba5ce0870..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_3d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_3d
-==============================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_3d
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst
deleted file mode 100644
index 55ad05f32b14..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_data.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_data
-================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_data
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst
deleted file mode 100644
index 8f2d79369915..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_model.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_model
-=================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_model
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst
deleted file mode 100644
index 466d5143a02b..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_pipeline.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_pipeline
-====================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_pipeline
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst
deleted file mode 100644
index dab71cc3c391..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_sequence.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_sequence
-====================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_sequence
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst b/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst
deleted file mode 100644
index 0c2d8d1e9daa..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.initializer_tensor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.initializer\_tensor
-==================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.initializer_tensor
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst b/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst
deleted file mode 100644
index 3f98723c170b..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.process_group_initializer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.process\_group\_initializer.process\_group\_initializer
-==========================================================================
-
-.. automodule:: colossalai.context.process_group_initializer.process_group_initializer
- :members:
diff --git a/docs/colossalai/colossalai.context.process_group_initializer.rst b/docs/colossalai/colossalai.context.process_group_initializer.rst
deleted file mode 100644
index 519337e9c71d..000000000000
--- a/docs/colossalai/colossalai.context.process_group_initializer.rst
+++ /dev/null
@@ -1,20 +0,0 @@
-colossalai.context.process\_group\_initializer
-==============================================
-
-.. automodule:: colossalai.context.process_group_initializer
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.context.process_group_initializer.initializer_1d
- colossalai.context.process_group_initializer.initializer_2d
- colossalai.context.process_group_initializer.initializer_2p5d
- colossalai.context.process_group_initializer.initializer_3d
- colossalai.context.process_group_initializer.initializer_data
- colossalai.context.process_group_initializer.initializer_model
- colossalai.context.process_group_initializer.initializer_pipeline
- colossalai.context.process_group_initializer.initializer_sequence
- colossalai.context.process_group_initializer.initializer_tensor
- colossalai.context.process_group_initializer.process_group_initializer
diff --git a/docs/colossalai/colossalai.context.random.rst b/docs/colossalai/colossalai.context.random.rst
deleted file mode 100644
index 8d4b9c56af3c..000000000000
--- a/docs/colossalai/colossalai.context.random.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.context.random
-=========================
-
-.. automodule:: colossalai.context.random
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.context.random.seed_manager
diff --git a/docs/colossalai/colossalai.context.random.seed_manager.rst b/docs/colossalai/colossalai.context.random.seed_manager.rst
deleted file mode 100644
index b71f35c2750c..000000000000
--- a/docs/colossalai/colossalai.context.random.seed_manager.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.random.seed\_manager
-=======================================
-
-.. automodule:: colossalai.context.random.seed_manager
- :members:
diff --git a/docs/colossalai/colossalai.context.rst b/docs/colossalai/colossalai.context.rst
deleted file mode 100644
index 102a9e02eaa4..000000000000
--- a/docs/colossalai/colossalai.context.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-colossalai.context
-==================
-
-.. automodule:: colossalai.context
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.context.process_group_initializer
- colossalai.context.random
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.context.config
- colossalai.context.moe_context
- colossalai.context.parallel_context
- colossalai.context.parallel_mode
- colossalai.context.singleton_meta
diff --git a/docs/colossalai/colossalai.context.singleton_meta.rst b/docs/colossalai/colossalai.context.singleton_meta.rst
deleted file mode 100644
index ae4ceb314f32..000000000000
--- a/docs/colossalai/colossalai.context.singleton_meta.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.context.singleton\_meta
-==================================
-
-.. automodule:: colossalai.context.singleton_meta
- :members:
diff --git a/docs/colossalai/colossalai.core.rst b/docs/colossalai/colossalai.core.rst
deleted file mode 100644
index d9ddb76ed72a..000000000000
--- a/docs/colossalai/colossalai.core.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.core
-===============
-
-.. automodule:: colossalai.core
- :members:
diff --git a/docs/colossalai/colossalai.engine.gradient_accumulation.rst b/docs/colossalai/colossalai.engine.gradient_accumulation.rst
deleted file mode 100644
index 75fc0e9a24eb..000000000000
--- a/docs/colossalai/colossalai.engine.gradient_accumulation.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.engine.gradient\_accumulation
-========================================
-
-.. automodule:: colossalai.engine.gradient_accumulation
- :members:
diff --git a/docs/colossalai/colossalai.engine.gradient_handler.rst b/docs/colossalai/colossalai.engine.gradient_handler.rst
deleted file mode 100644
index 27eb2b56a29f..000000000000
--- a/docs/colossalai/colossalai.engine.gradient_handler.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.engine.gradient\_handler
-===================================
-
-.. automodule:: colossalai.engine.gradient_handler
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.engine.gradient_handler.utils
diff --git a/docs/colossalai/colossalai.engine.gradient_handler.utils.rst b/docs/colossalai/colossalai.engine.gradient_handler.utils.rst
deleted file mode 100644
index c8997e135b60..000000000000
--- a/docs/colossalai/colossalai.engine.gradient_handler.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.engine.gradient\_handler.utils
-=========================================
-
-.. automodule:: colossalai.engine.gradient_handler.utils
- :members:
diff --git a/docs/colossalai/colossalai.engine.rst b/docs/colossalai/colossalai.engine.rst
deleted file mode 100644
index 3d194b70695e..000000000000
--- a/docs/colossalai/colossalai.engine.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.engine
-=================
-
-.. automodule:: colossalai.engine
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.engine.gradient_accumulation
- colossalai.engine.gradient_handler
- colossalai.engine.schedule
diff --git a/docs/colossalai/colossalai.engine.schedule.rst b/docs/colossalai/colossalai.engine.schedule.rst
deleted file mode 100644
index 2909373f0002..000000000000
--- a/docs/colossalai/colossalai.engine.schedule.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.engine.schedule
-==========================
-
-.. automodule:: colossalai.engine.schedule
- :members:
diff --git a/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst b/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst
deleted file mode 100644
index 6799fdc658cd..000000000000
--- a/docs/colossalai/colossalai.fx.passes.adding_split_node_pass.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.passes.adding\_split\_node\_pass
-==============================================
-
-.. automodule:: colossalai.fx.passes.adding_split_node_pass
- :members:
diff --git a/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst b/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst
deleted file mode 100644
index 4e51732ce83d..000000000000
--- a/docs/colossalai/colossalai.fx.passes.meta_info_prop.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.passes.meta\_info\_prop
-=====================================
-
-.. automodule:: colossalai.fx.passes.meta_info_prop
- :members:
diff --git a/docs/colossalai/colossalai.fx.passes.rst b/docs/colossalai/colossalai.fx.passes.rst
deleted file mode 100644
index fac10b768034..000000000000
--- a/docs/colossalai/colossalai.fx.passes.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-colossalai.fx.passes
-====================
-
-.. automodule:: colossalai.fx.passes
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.fx.passes.adding_split_node_pass
- colossalai.fx.passes.meta_info_prop
- colossalai.fx.passes.shard_1d_pass
- colossalai.fx.passes.split_module
- colossalai.fx.passes.utils
diff --git a/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst b/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst
deleted file mode 100644
index 0942e96d46dc..000000000000
--- a/docs/colossalai/colossalai.fx.passes.shard_1d_pass.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.passes.shard\_1d\_pass
-====================================
-
-.. automodule:: colossalai.fx.passes.shard_1d_pass
- :members:
diff --git a/docs/colossalai/colossalai.fx.passes.split_module.rst b/docs/colossalai/colossalai.fx.passes.split_module.rst
deleted file mode 100644
index 9e5e58259254..000000000000
--- a/docs/colossalai/colossalai.fx.passes.split_module.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.passes.split\_module
-==================================
-
-.. automodule:: colossalai.fx.passes.split_module
- :members:
diff --git a/docs/colossalai/colossalai.fx.passes.utils.rst b/docs/colossalai/colossalai.fx.passes.utils.rst
deleted file mode 100644
index 4afd9256322b..000000000000
--- a/docs/colossalai/colossalai.fx.passes.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.passes.utils
-==========================
-
-.. automodule:: colossalai.fx.passes.utils
- :members:
diff --git a/docs/colossalai/colossalai.fx.proxy.rst b/docs/colossalai/colossalai.fx.proxy.rst
deleted file mode 100644
index 4b92da41c794..000000000000
--- a/docs/colossalai/colossalai.fx.proxy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.proxy
-===================
-
-.. automodule:: colossalai.fx.proxy
- :members:
diff --git a/docs/colossalai/colossalai.fx.rst b/docs/colossalai/colossalai.fx.rst
deleted file mode 100644
index 778d642c3a11..000000000000
--- a/docs/colossalai/colossalai.fx.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-colossalai.fx
-=============
-
-.. automodule:: colossalai.fx
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.fx.passes
- colossalai.fx.tracer
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.fx.proxy
diff --git a/docs/colossalai/colossalai.fx.tracer.rst b/docs/colossalai/colossalai.fx.tracer.rst
deleted file mode 100644
index d2f743d67d55..000000000000
--- a/docs/colossalai/colossalai.fx.tracer.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.fx.tracer
-====================
-
-.. automodule:: colossalai.fx.tracer
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.fx.tracer.tracer
diff --git a/docs/colossalai/colossalai.fx.tracer.tracer.rst b/docs/colossalai/colossalai.fx.tracer.tracer.rst
deleted file mode 100644
index 83b98bafd825..000000000000
--- a/docs/colossalai/colossalai.fx.tracer.tracer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.fx.tracer.tracer
-===========================
-
-.. automodule:: colossalai.fx.tracer.tracer
- :members:
diff --git a/docs/colossalai/colossalai.gemini.chunk.rst b/docs/colossalai/colossalai.gemini.chunk.rst
deleted file mode 100644
index 9fe1c2b415d6..000000000000
--- a/docs/colossalai/colossalai.gemini.chunk.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.chunk
-=======================
-
-.. automodule:: colossalai.gemini.chunk
- :members:
diff --git a/docs/colossalai/colossalai.gemini.chunk_mgr.rst b/docs/colossalai/colossalai.gemini.chunk_mgr.rst
deleted file mode 100644
index acb554faf319..000000000000
--- a/docs/colossalai/colossalai.gemini.chunk_mgr.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.chunk\_mgr
-============================
-
-.. automodule:: colossalai.gemini.chunk_mgr
- :members:
diff --git a/docs/colossalai/colossalai.gemini.gemini_context.rst b/docs/colossalai/colossalai.gemini.gemini_context.rst
deleted file mode 100644
index be4884062253..000000000000
--- a/docs/colossalai/colossalai.gemini.gemini_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.gemini\_context
-=================================
-
-.. automodule:: colossalai.gemini.gemini_context
- :members:
diff --git a/docs/colossalai/colossalai.gemini.gemini_mgr.rst b/docs/colossalai/colossalai.gemini.gemini_mgr.rst
deleted file mode 100644
index 5d7f944f7a56..000000000000
--- a/docs/colossalai/colossalai.gemini.gemini_mgr.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.gemini\_mgr
-=============================
-
-.. automodule:: colossalai.gemini.gemini_mgr
- :members:
diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst b/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst
deleted file mode 100644
index e8088a609f34..000000000000
--- a/docs/colossalai/colossalai.gemini.memory_tracer.memory_monitor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.memory\_tracer.memory\_monitor
-================================================
-
-.. automodule:: colossalai.gemini.memory_tracer.memory_monitor
- :members:
diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst b/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst
deleted file mode 100644
index e2682220c27b..000000000000
--- a/docs/colossalai/colossalai.gemini.memory_tracer.memstats_collector.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.memory\_tracer.memstats\_collector
-====================================================
-
-.. automodule:: colossalai.gemini.memory_tracer.memstats_collector
- :members:
diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst b/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst
deleted file mode 100644
index ccdfe6682c3f..000000000000
--- a/docs/colossalai/colossalai.gemini.memory_tracer.model_data_memtracer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.memory\_tracer.model\_data\_memtracer
-=======================================================
-
-.. automodule:: colossalai.gemini.memory_tracer.model_data_memtracer
- :members:
diff --git a/docs/colossalai/colossalai.gemini.memory_tracer.rst b/docs/colossalai/colossalai.gemini.memory_tracer.rst
deleted file mode 100644
index f3d9c4d76dd8..000000000000
--- a/docs/colossalai/colossalai.gemini.memory_tracer.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.gemini.memory\_tracer
-================================
-
-.. automodule:: colossalai.gemini.memory_tracer
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.gemini.memory_tracer.memory_monitor
- colossalai.gemini.memory_tracer.memstats_collector
- colossalai.gemini.memory_tracer.model_data_memtracer
diff --git a/docs/colossalai/colossalai.gemini.ophooks.rst b/docs/colossalai/colossalai.gemini.ophooks.rst
deleted file mode 100644
index af87ab568ac0..000000000000
--- a/docs/colossalai/colossalai.gemini.ophooks.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.gemini.ophooks
-=========================
-
-.. automodule:: colossalai.gemini.ophooks
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.gemini.ophooks.utils
diff --git a/docs/colossalai/colossalai.gemini.ophooks.utils.rst b/docs/colossalai/colossalai.gemini.ophooks.utils.rst
deleted file mode 100644
index 5c5917047f44..000000000000
--- a/docs/colossalai/colossalai.gemini.ophooks.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.ophooks.utils
-===============================
-
-.. automodule:: colossalai.gemini.ophooks.utils
- :members:
diff --git a/docs/colossalai/colossalai.gemini.paramhooks.rst b/docs/colossalai/colossalai.gemini.paramhooks.rst
deleted file mode 100644
index 28a823d4e69c..000000000000
--- a/docs/colossalai/colossalai.gemini.paramhooks.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.paramhooks
-============================
-
-.. automodule:: colossalai.gemini.paramhooks
- :members:
diff --git a/docs/colossalai/colossalai.gemini.placement_policy.rst b/docs/colossalai/colossalai.gemini.placement_policy.rst
deleted file mode 100644
index 9de0ed52371b..000000000000
--- a/docs/colossalai/colossalai.gemini.placement_policy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.placement\_policy
-===================================
-
-.. automodule:: colossalai.gemini.placement_policy
- :members:
diff --git a/docs/colossalai/colossalai.gemini.rst b/docs/colossalai/colossalai.gemini.rst
deleted file mode 100644
index 4f6efe386521..000000000000
--- a/docs/colossalai/colossalai.gemini.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-colossalai.gemini
-=================
-
-.. automodule:: colossalai.gemini
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.gemini.memory_tracer
- colossalai.gemini.ophooks
- colossalai.gemini.paramhooks
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.gemini.chunk
- colossalai.gemini.chunk_mgr
- colossalai.gemini.gemini_context
- colossalai.gemini.gemini_mgr
- colossalai.gemini.placement_policy
- colossalai.gemini.stateful_tensor
- colossalai.gemini.stateful_tensor_container
- colossalai.gemini.stateful_tensor_mgr
- colossalai.gemini.tensor_placement_policy
- colossalai.gemini.tensor_utils
diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor.rst b/docs/colossalai/colossalai.gemini.stateful_tensor.rst
deleted file mode 100644
index 02d526d1b4c8..000000000000
--- a/docs/colossalai/colossalai.gemini.stateful_tensor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.stateful\_tensor
-==================================
-
-.. automodule:: colossalai.gemini.stateful_tensor
- :members:
diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst b/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst
deleted file mode 100644
index be56c2aa8ed2..000000000000
--- a/docs/colossalai/colossalai.gemini.stateful_tensor_container.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.stateful\_tensor\_container
-=============================================
-
-.. automodule:: colossalai.gemini.stateful_tensor_container
- :members:
diff --git a/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst b/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst
deleted file mode 100644
index 3456192bd735..000000000000
--- a/docs/colossalai/colossalai.gemini.stateful_tensor_mgr.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.stateful\_tensor\_mgr
-=======================================
-
-.. automodule:: colossalai.gemini.stateful_tensor_mgr
- :members:
diff --git a/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst b/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst
deleted file mode 100644
index 81dcac339048..000000000000
--- a/docs/colossalai/colossalai.gemini.tensor_placement_policy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.tensor\_placement\_policy
-===========================================
-
-.. automodule:: colossalai.gemini.tensor_placement_policy
- :members:
diff --git a/docs/colossalai/colossalai.gemini.tensor_utils.rst b/docs/colossalai/colossalai.gemini.tensor_utils.rst
deleted file mode 100644
index 385baf4b50bb..000000000000
--- a/docs/colossalai/colossalai.gemini.tensor_utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.gemini.tensor\_utils
-===============================
-
-.. automodule:: colossalai.gemini.tensor_utils
- :members:
diff --git a/docs/colossalai/colossalai.global_variables.rst b/docs/colossalai/colossalai.global_variables.rst
deleted file mode 100644
index 1900c88351ff..000000000000
--- a/docs/colossalai/colossalai.global_variables.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.global\_variables
-============================
-
-.. automodule:: colossalai.global_variables
- :members:
diff --git a/docs/colossalai/colossalai.initialize.rst b/docs/colossalai/colossalai.initialize.rst
deleted file mode 100644
index d3f65076a795..000000000000
--- a/docs/colossalai/colossalai.initialize.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.initialize
-=====================
-
-.. automodule:: colossalai.initialize
- :members:
diff --git a/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst b/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst
deleted file mode 100644
index b8bff51bef34..000000000000
--- a/docs/colossalai/colossalai.kernel.cuda_native.layer_norm.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.cuda\_native.layer\_norm
-==========================================
-
-.. automodule:: colossalai.kernel.cuda_native.layer_norm
- :members:
diff --git a/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst b/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst
deleted file mode 100644
index de7577d195cd..000000000000
--- a/docs/colossalai/colossalai.kernel.cuda_native.multihead_attention.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.cuda\_native.multihead\_attention
-===================================================
-
-.. automodule:: colossalai.kernel.cuda_native.multihead_attention
- :members:
diff --git a/docs/colossalai/colossalai.kernel.cuda_native.rst b/docs/colossalai/colossalai.kernel.cuda_native.rst
deleted file mode 100644
index d88e4cfdb761..000000000000
--- a/docs/colossalai/colossalai.kernel.cuda_native.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.kernel.cuda\_native
-==============================
-
-.. automodule:: colossalai.kernel.cuda_native
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.kernel.cuda_native.layer_norm
- colossalai.kernel.cuda_native.multihead_attention
- colossalai.kernel.cuda_native.scaled_softmax
diff --git a/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst b/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst
deleted file mode 100644
index 474fcd3349bd..000000000000
--- a/docs/colossalai/colossalai.kernel.cuda_native.scaled_softmax.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.cuda\_native.scaled\_softmax
-==============================================
-
-.. automodule:: colossalai.kernel.cuda_native.scaled_softmax
- :members:
diff --git a/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst b/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst
deleted file mode 100644
index d61550928bc8..000000000000
--- a/docs/colossalai/colossalai.kernel.jit.bias_dropout_add.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.jit.bias\_dropout\_add
-========================================
-
-.. automodule:: colossalai.kernel.jit.bias_dropout_add
- :members:
diff --git a/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst b/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst
deleted file mode 100644
index 7db184b4ce3b..000000000000
--- a/docs/colossalai/colossalai.kernel.jit.bias_gelu.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.jit.bias\_gelu
-================================
-
-.. automodule:: colossalai.kernel.jit.bias_gelu
- :members:
diff --git a/docs/colossalai/colossalai.kernel.jit.option.rst b/docs/colossalai/colossalai.kernel.jit.option.rst
deleted file mode 100644
index 15ebfc83aa77..000000000000
--- a/docs/colossalai/colossalai.kernel.jit.option.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.kernel.jit.option
-============================
-
-.. automodule:: colossalai.kernel.jit.option
- :members:
diff --git a/docs/colossalai/colossalai.kernel.jit.rst b/docs/colossalai/colossalai.kernel.jit.rst
deleted file mode 100644
index 8b2f728d34d5..000000000000
--- a/docs/colossalai/colossalai.kernel.jit.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.kernel.jit
-=====================
-
-.. automodule:: colossalai.kernel.jit
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.kernel.jit.bias_dropout_add
- colossalai.kernel.jit.bias_gelu
- colossalai.kernel.jit.option
diff --git a/docs/colossalai/colossalai.kernel.rst b/docs/colossalai/colossalai.kernel.rst
deleted file mode 100644
index dcbac8c1de76..000000000000
--- a/docs/colossalai/colossalai.kernel.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.kernel
-=================
-
-.. automodule:: colossalai.kernel
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.kernel.cuda_native
- colossalai.kernel.jit
diff --git a/docs/colossalai/colossalai.logging.logger.rst b/docs/colossalai/colossalai.logging.logger.rst
deleted file mode 100644
index 047deb8a1d19..000000000000
--- a/docs/colossalai/colossalai.logging.logger.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.logging.logger
-=========================
-
-.. automodule:: colossalai.logging.logger
- :members:
diff --git a/docs/colossalai/colossalai.logging.rst b/docs/colossalai/colossalai.logging.rst
deleted file mode 100644
index bc593fc81bf4..000000000000
--- a/docs/colossalai/colossalai.logging.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.logging
-==================
-
-.. automodule:: colossalai.logging
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.logging.logger
diff --git a/docs/colossalai/colossalai.nn.graph.graph_node.rst b/docs/colossalai/colossalai.nn.graph.graph_node.rst
deleted file mode 100644
index 335ecfe620fe..000000000000
--- a/docs/colossalai/colossalai.nn.graph.graph_node.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.graph.graph\_node
-===============================
-
-.. automodule:: colossalai.nn.graph.graph_node
- :members:
diff --git a/docs/colossalai/colossalai.nn.graph.rst b/docs/colossalai/colossalai.nn.graph.rst
deleted file mode 100644
index 4510b3374f2a..000000000000
--- a/docs/colossalai/colossalai.nn.graph.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.nn.graph
-===================
-
-.. automodule:: colossalai.nn.graph
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.graph.graph_node
- colossalai.nn.graph.utils
diff --git a/docs/colossalai/colossalai.nn.graph.utils.rst b/docs/colossalai/colossalai.nn.graph.utils.rst
deleted file mode 100644
index 866a93cd9201..000000000000
--- a/docs/colossalai/colossalai.nn.graph.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.graph.utils
-=========================
-
-.. automodule:: colossalai.nn.graph.utils
- :members:
diff --git a/docs/colossalai/colossalai.nn.init.rst b/docs/colossalai/colossalai.nn.init.rst
deleted file mode 100644
index d0ab993126d5..000000000000
--- a/docs/colossalai/colossalai.nn.init.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.init
-==================
-
-.. automodule:: colossalai.nn.init
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.base_layer.rst b/docs/colossalai/colossalai.nn.layer.base_layer.rst
deleted file mode 100644
index c2a22f04d3f3..000000000000
--- a/docs/colossalai/colossalai.nn.layer.base_layer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.base\_layer
-===============================
-
-.. automodule:: colossalai.nn.layer.base_layer
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst
deleted file mode 100644
index ec1dfd395f17..000000000000
--- a/docs/colossalai/colossalai.nn.layer.colossalai_layer.dropout.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.colossalai\_layer.dropout
-=============================================
-
-.. automodule:: colossalai.nn.layer.colossalai_layer.dropout
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst
deleted file mode 100644
index 8438b3a07787..000000000000
--- a/docs/colossalai/colossalai.nn.layer.colossalai_layer.embedding.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.colossalai\_layer.embedding
-===============================================
-
-.. automodule:: colossalai.nn.layer.colossalai_layer.embedding
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst
deleted file mode 100644
index 3213282549ea..000000000000
--- a/docs/colossalai/colossalai.nn.layer.colossalai_layer.linear.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.colossalai\_layer.linear
-============================================
-
-.. automodule:: colossalai.nn.layer.colossalai_layer.linear
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst
deleted file mode 100644
index f94dd27b86e4..000000000000
--- a/docs/colossalai/colossalai.nn.layer.colossalai_layer.normalization.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.colossalai\_layer.normalization
-===================================================
-
-.. automodule:: colossalai.nn.layer.colossalai_layer.normalization
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst b/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst
deleted file mode 100644
index 0f685e6c2dc3..000000000000
--- a/docs/colossalai/colossalai.nn.layer.colossalai_layer.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-colossalai.nn.layer.colossalai\_layer
-=====================================
-
-.. automodule:: colossalai.nn.layer.colossalai_layer
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.colossalai_layer.dropout
- colossalai.nn.layer.colossalai_layer.embedding
- colossalai.nn.layer.colossalai_layer.linear
- colossalai.nn.layer.colossalai_layer.normalization
diff --git a/docs/colossalai/colossalai.nn.layer.moe.experts.rst b/docs/colossalai/colossalai.nn.layer.moe.experts.rst
deleted file mode 100644
index c05e763d5723..000000000000
--- a/docs/colossalai/colossalai.nn.layer.moe.experts.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.moe.experts
-===============================
-
-.. automodule:: colossalai.nn.layer.moe.experts
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.moe.layers.rst b/docs/colossalai/colossalai.nn.layer.moe.layers.rst
deleted file mode 100644
index d109d47b8174..000000000000
--- a/docs/colossalai/colossalai.nn.layer.moe.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.moe.layers
-==============================
-
-.. automodule:: colossalai.nn.layer.moe.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.moe.rst b/docs/colossalai/colossalai.nn.layer.moe.rst
deleted file mode 100644
index f3106b98d405..000000000000
--- a/docs/colossalai/colossalai.nn.layer.moe.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.nn.layer.moe
-=======================
-
-.. automodule:: colossalai.nn.layer.moe
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.moe.experts
- colossalai.nn.layer.moe.layers
- colossalai.nn.layer.moe.utils
diff --git a/docs/colossalai/colossalai.nn.layer.moe.utils.rst b/docs/colossalai/colossalai.nn.layer.moe.utils.rst
deleted file mode 100644
index fc085d136bb4..000000000000
--- a/docs/colossalai/colossalai.nn.layer.moe.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.moe.utils
-=============================
-
-.. automodule:: colossalai.nn.layer.moe.utils
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst
deleted file mode 100644
index 380f6bf8d134..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_1d.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.parallel\_1d.layers
-=======================================
-
-.. automodule:: colossalai.nn.layer.parallel_1d.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_1d.rst b/docs/colossalai/colossalai.nn.layer.parallel_1d.rst
deleted file mode 100644
index 3a8ed6206721..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_1d.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.parallel\_1d
-================================
-
-.. automodule:: colossalai.nn.layer.parallel_1d
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.parallel_1d.layers
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst
deleted file mode 100644
index b64d402bdf3e..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_2d.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.parallel\_2d.layers
-=======================================
-
-.. automodule:: colossalai.nn.layer.parallel_2d.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2d.rst b/docs/colossalai/colossalai.nn.layer.parallel_2d.rst
deleted file mode 100644
index f5ad41a1b450..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_2d.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.parallel\_2d
-================================
-
-.. automodule:: colossalai.nn.layer.parallel_2d
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.parallel_2d.layers
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst
deleted file mode 100644
index ebc99d56ccdc..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.parallel\_2p5d.layers
-=========================================
-
-.. automodule:: colossalai.nn.layer.parallel_2p5d.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst b/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst
deleted file mode 100644
index 5869bdee9928..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_2p5d.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.parallel\_2p5d
-==================================
-
-.. automodule:: colossalai.nn.layer.parallel_2p5d
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.parallel_2p5d.layers
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst
deleted file mode 100644
index a1702f1fcf62..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_3d.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.parallel\_3d.layers
-=======================================
-
-.. automodule:: colossalai.nn.layer.parallel_3d.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_3d.rst b/docs/colossalai/colossalai.nn.layer.parallel_3d.rst
deleted file mode 100644
index bb55a63e507d..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_3d.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.parallel\_3d
-================================
-
-.. automodule:: colossalai.nn.layer.parallel_3d
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.parallel_3d.layers
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst b/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst
deleted file mode 100644
index 54929d2e7169..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_sequence.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.parallel\_sequence.layers
-=============================================
-
-.. automodule:: colossalai.nn.layer.parallel_sequence.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst b/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst
deleted file mode 100644
index 24e8941d4ec4..000000000000
--- a/docs/colossalai/colossalai.nn.layer.parallel_sequence.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.parallel\_sequence
-======================================
-
-.. automodule:: colossalai.nn.layer.parallel_sequence
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.parallel_sequence.layers
diff --git a/docs/colossalai/colossalai.nn.layer.rst b/docs/colossalai/colossalai.nn.layer.rst
deleted file mode 100644
index 32a93128f2a4..000000000000
--- a/docs/colossalai/colossalai.nn.layer.rst
+++ /dev/null
@@ -1,25 +0,0 @@
-colossalai.nn.layer
-===================
-
-.. automodule:: colossalai.nn.layer
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.colossalai_layer
- colossalai.nn.layer.moe
- colossalai.nn.layer.parallel_1d
- colossalai.nn.layer.parallel_2d
- colossalai.nn.layer.parallel_2p5d
- colossalai.nn.layer.parallel_3d
- colossalai.nn.layer.parallel_sequence
- colossalai.nn.layer.utils
- colossalai.nn.layer.vanilla
- colossalai.nn.layer.wrapper
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.base_layer
diff --git a/docs/colossalai/colossalai.nn.layer.utils.common.rst b/docs/colossalai/colossalai.nn.layer.utils.common.rst
deleted file mode 100644
index 6a552830f8f5..000000000000
--- a/docs/colossalai/colossalai.nn.layer.utils.common.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.utils.common
-================================
-
-.. automodule:: colossalai.nn.layer.utils.common
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.utils.rst b/docs/colossalai/colossalai.nn.layer.utils.rst
deleted file mode 100644
index 16c3d718286a..000000000000
--- a/docs/colossalai/colossalai.nn.layer.utils.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.utils
-=========================
-
-.. automodule:: colossalai.nn.layer.utils
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.utils.common
diff --git a/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst b/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst
deleted file mode 100644
index f993b1f50e5b..000000000000
--- a/docs/colossalai/colossalai.nn.layer.vanilla.layers.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.vanilla.layers
-==================================
-
-.. automodule:: colossalai.nn.layer.vanilla.layers
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.vanilla.rst b/docs/colossalai/colossalai.nn.layer.vanilla.rst
deleted file mode 100644
index fe1ea5c6c53e..000000000000
--- a/docs/colossalai/colossalai.nn.layer.vanilla.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.vanilla
-===========================
-
-.. automodule:: colossalai.nn.layer.vanilla
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.vanilla.layers
diff --git a/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst b/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst
deleted file mode 100644
index e5648873d34b..000000000000
--- a/docs/colossalai/colossalai.nn.layer.wrapper.pipeline_wrapper.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.layer.wrapper.pipeline\_wrapper
-=============================================
-
-.. automodule:: colossalai.nn.layer.wrapper.pipeline_wrapper
- :members:
diff --git a/docs/colossalai/colossalai.nn.layer.wrapper.rst b/docs/colossalai/colossalai.nn.layer.wrapper.rst
deleted file mode 100644
index 761bf843af36..000000000000
--- a/docs/colossalai/colossalai.nn.layer.wrapper.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.nn.layer.wrapper
-===========================
-
-.. automodule:: colossalai.nn.layer.wrapper
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.layer.wrapper.pipeline_wrapper
diff --git a/docs/colossalai/colossalai.nn.loss.loss_1d.rst b/docs/colossalai/colossalai.nn.loss.loss_1d.rst
deleted file mode 100644
index d9ac2e67d317..000000000000
--- a/docs/colossalai/colossalai.nn.loss.loss_1d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.loss.loss\_1d
-===========================
-
-.. automodule:: colossalai.nn.loss.loss_1d
- :members:
diff --git a/docs/colossalai/colossalai.nn.loss.loss_2d.rst b/docs/colossalai/colossalai.nn.loss.loss_2d.rst
deleted file mode 100644
index 14d1585e3e0f..000000000000
--- a/docs/colossalai/colossalai.nn.loss.loss_2d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.loss.loss\_2d
-===========================
-
-.. automodule:: colossalai.nn.loss.loss_2d
- :members:
diff --git a/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst b/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst
deleted file mode 100644
index fc3714da3630..000000000000
--- a/docs/colossalai/colossalai.nn.loss.loss_2p5d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.loss.loss\_2p5d
-=============================
-
-.. automodule:: colossalai.nn.loss.loss_2p5d
- :members:
diff --git a/docs/colossalai/colossalai.nn.loss.loss_3d.rst b/docs/colossalai/colossalai.nn.loss.loss_3d.rst
deleted file mode 100644
index a593324fb4f1..000000000000
--- a/docs/colossalai/colossalai.nn.loss.loss_3d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.loss.loss\_3d
-===========================
-
-.. automodule:: colossalai.nn.loss.loss_3d
- :members:
diff --git a/docs/colossalai/colossalai.nn.loss.loss_moe.rst b/docs/colossalai/colossalai.nn.loss.loss_moe.rst
deleted file mode 100644
index ef2851ace83a..000000000000
--- a/docs/colossalai/colossalai.nn.loss.loss_moe.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.loss.loss\_moe
-============================
-
-.. automodule:: colossalai.nn.loss.loss_moe
- :members:
diff --git a/docs/colossalai/colossalai.nn.loss.rst b/docs/colossalai/colossalai.nn.loss.rst
deleted file mode 100644
index 5df7d1ae3770..000000000000
--- a/docs/colossalai/colossalai.nn.loss.rst
+++ /dev/null
@@ -1,15 +0,0 @@
-colossalai.nn.loss
-==================
-
-.. automodule:: colossalai.nn.loss
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.loss.loss_1d
- colossalai.nn.loss.loss_2d
- colossalai.nn.loss.loss_2p5d
- colossalai.nn.loss.loss_3d
- colossalai.nn.loss.loss_moe
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst b/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst
deleted file mode 100644
index a7c636ad3a36..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.cosine.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.cosine
-==================================
-
-.. automodule:: colossalai.nn.lr_scheduler.cosine
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst b/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst
deleted file mode 100644
index 2a86c4b2a20c..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.delayed.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.delayed
-===================================
-
-.. automodule:: colossalai.nn.lr_scheduler.delayed
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst b/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst
deleted file mode 100644
index 5e917edc2faf..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.linear.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.linear
-==================================
-
-.. automodule:: colossalai.nn.lr_scheduler.linear
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst b/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst
deleted file mode 100644
index 4248a6386375..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.multistep.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.multistep
-=====================================
-
-.. automodule:: colossalai.nn.lr_scheduler.multistep
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst b/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst
deleted file mode 100644
index 7f2fd47586fe..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.onecycle.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.onecycle
-====================================
-
-.. automodule:: colossalai.nn.lr_scheduler.onecycle
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst b/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst
deleted file mode 100644
index c1618812aa0c..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.poly.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.poly
-================================
-
-.. automodule:: colossalai.nn.lr_scheduler.poly
- :members:
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.rst b/docs/colossalai/colossalai.nn.lr_scheduler.rst
deleted file mode 100644
index 427a3ee4529e..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-colossalai.nn.lr\_scheduler
-===========================
-
-.. automodule:: colossalai.nn.lr_scheduler
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.lr_scheduler.cosine
- colossalai.nn.lr_scheduler.delayed
- colossalai.nn.lr_scheduler.linear
- colossalai.nn.lr_scheduler.multistep
- colossalai.nn.lr_scheduler.onecycle
- colossalai.nn.lr_scheduler.poly
- colossalai.nn.lr_scheduler.torch
diff --git a/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst b/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst
deleted file mode 100644
index f8d552bf1d62..000000000000
--- a/docs/colossalai/colossalai.nn.lr_scheduler.torch.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.lr\_scheduler.torch
-=================================
-
-.. automodule:: colossalai.nn.lr_scheduler.torch
- :members:
diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst
deleted file mode 100644
index 63bcb8349763..000000000000
--- a/docs/colossalai/colossalai.nn.metric.accuracy_2d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.metric.accuracy\_2d
-=================================
-
-.. automodule:: colossalai.nn.metric.accuracy_2d
- :members:
diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst
deleted file mode 100644
index dd4358fbff72..000000000000
--- a/docs/colossalai/colossalai.nn.metric.accuracy_2p5d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.metric.accuracy\_2p5d
-===================================
-
-.. automodule:: colossalai.nn.metric.accuracy_2p5d
- :members:
diff --git a/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst b/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst
deleted file mode 100644
index 95143444b945..000000000000
--- a/docs/colossalai/colossalai.nn.metric.accuracy_3d.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.metric.accuracy\_3d
-=================================
-
-.. automodule:: colossalai.nn.metric.accuracy_3d
- :members:
diff --git a/docs/colossalai/colossalai.nn.metric.rst b/docs/colossalai/colossalai.nn.metric.rst
deleted file mode 100644
index 28f5568eb846..000000000000
--- a/docs/colossalai/colossalai.nn.metric.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.nn.metric
-====================
-
-.. automodule:: colossalai.nn.metric
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.metric.accuracy_2d
- colossalai.nn.metric.accuracy_2p5d
- colossalai.nn.metric.accuracy_3d
diff --git a/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst b/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst
deleted file mode 100644
index 35515c374f33..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.colossalai_optimizer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.colossalai\_optimizer
-=============================================
-
-.. automodule:: colossalai.nn.optimizer.colossalai_optimizer
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst b/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst
deleted file mode 100644
index 224dfab43ed0..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.cpu_adam.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.cpu\_adam
-=================================
-
-.. automodule:: colossalai.nn.optimizer.cpu_adam
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst b/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst
deleted file mode 100644
index 60af624cb6c1..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.fused_adam.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.fused\_adam
-===================================
-
-.. automodule:: colossalai.nn.optimizer.fused_adam
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst b/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst
deleted file mode 100644
index 66c0fa4ca1c7..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.fused_lamb.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.fused\_lamb
-===================================
-
-.. automodule:: colossalai.nn.optimizer.fused_lamb
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst b/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst
deleted file mode 100644
index 2ecc77c33d88..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.fused_sgd.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.fused\_sgd
-==================================
-
-.. automodule:: colossalai.nn.optimizer.fused_sgd
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst b/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst
deleted file mode 100644
index 20508d664701..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.hybrid\_adam
-====================================
-
-.. automodule:: colossalai.nn.optimizer.hybrid_adam
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.lamb.rst b/docs/colossalai/colossalai.nn.optimizer.lamb.rst
deleted file mode 100644
index 57199ea36951..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.lamb.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.lamb
-============================
-
-.. automodule:: colossalai.nn.optimizer.lamb
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.lars.rst b/docs/colossalai/colossalai.nn.optimizer.lars.rst
deleted file mode 100644
index f935950f8b5a..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.lars.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.lars
-============================
-
-.. automodule:: colossalai.nn.optimizer.lars
- :members:
diff --git a/docs/colossalai/colossalai.nn.optimizer.rst b/docs/colossalai/colossalai.nn.optimizer.rst
deleted file mode 100644
index ede9cc496967..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.rst
+++ /dev/null
@@ -1,19 +0,0 @@
-colossalai.nn.optimizer
-=======================
-
-.. automodule:: colossalai.nn.optimizer
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.optimizer.colossalai_optimizer
- colossalai.nn.optimizer.cpu_adam
- colossalai.nn.optimizer.fused_adam
- colossalai.nn.optimizer.fused_lamb
- colossalai.nn.optimizer.fused_sgd
- colossalai.nn.optimizer.hybrid_adam
- colossalai.nn.optimizer.lamb
- colossalai.nn.optimizer.lars
- colossalai.nn.optimizer.utils
diff --git a/docs/colossalai/colossalai.nn.optimizer.utils.rst b/docs/colossalai/colossalai.nn.optimizer.utils.rst
deleted file mode 100644
index 9b2bc2f016c4..000000000000
--- a/docs/colossalai/colossalai.nn.optimizer.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.optimizer.utils
-=============================
-
-.. automodule:: colossalai.nn.optimizer.utils
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.data_parallel.rst b/docs/colossalai/colossalai.nn.parallel.data_parallel.rst
deleted file mode 100644
index ba987c2ee2f3..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.data_parallel.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.data\_parallel
-=====================================
-
-.. automodule:: colossalai.nn.parallel.data_parallel
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst b/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst
deleted file mode 100644
index c80fff6d543a..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.layers.colo_module.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.layers.colo\_module
-==========================================
-
-.. automodule:: colossalai.nn.parallel.layers.colo_module
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst b/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst
deleted file mode 100644
index 1e7ecc50f478..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.layers.embedding.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.layers.embedding
-=======================================
-
-.. automodule:: colossalai.nn.parallel.layers.embedding
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.layers.linear.rst b/docs/colossalai/colossalai.nn.parallel.layers.linear.rst
deleted file mode 100644
index bbc5e32570e7..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.layers.linear.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.layers.linear
-====================================
-
-.. automodule:: colossalai.nn.parallel.layers.linear
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst b/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst
deleted file mode 100644
index 5190ab40345a..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.layers.module_utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.layers.module\_utils
-===========================================
-
-.. automodule:: colossalai.nn.parallel.layers.module_utils
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.layers.rst b/docs/colossalai/colossalai.nn.parallel.layers.rst
deleted file mode 100644
index 782a206e88d5..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.layers.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-colossalai.nn.parallel.layers
-=============================
-
-.. automodule:: colossalai.nn.parallel.layers
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.parallel.layers.colo_module
- colossalai.nn.parallel.layers.embedding
- colossalai.nn.parallel.layers.linear
- colossalai.nn.parallel.layers.module_utils
diff --git a/docs/colossalai/colossalai.nn.parallel.reducer.rst b/docs/colossalai/colossalai.nn.parallel.reducer.rst
deleted file mode 100644
index d80841f6916e..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.reducer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.nn.parallel.reducer
-==============================
-
-.. automodule:: colossalai.nn.parallel.reducer
- :members:
diff --git a/docs/colossalai/colossalai.nn.parallel.rst b/docs/colossalai/colossalai.nn.parallel.rst
deleted file mode 100644
index 19e9d1eef19b..000000000000
--- a/docs/colossalai/colossalai.nn.parallel.rst
+++ /dev/null
@@ -1,17 +0,0 @@
-colossalai.nn.parallel
-======================
-
-.. automodule:: colossalai.nn.parallel
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.parallel.layers
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.parallel.data_parallel
- colossalai.nn.parallel.reducer
diff --git a/docs/colossalai/colossalai.nn.rst b/docs/colossalai/colossalai.nn.rst
deleted file mode 100644
index 7e683952f3db..000000000000
--- a/docs/colossalai/colossalai.nn.rst
+++ /dev/null
@@ -1,22 +0,0 @@
-colossalai.nn
-=============
-
-.. automodule:: colossalai.nn
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.graph
- colossalai.nn.layer
- colossalai.nn.loss
- colossalai.nn.lr_scheduler
- colossalai.nn.metric
- colossalai.nn.optimizer
- colossalai.nn.parallel
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.nn.init
diff --git a/docs/colossalai/colossalai.pipeline.layer_sepc.rst b/docs/colossalai/colossalai.pipeline.layer_sepc.rst
deleted file mode 100644
index 156660b5c00f..000000000000
--- a/docs/colossalai/colossalai.pipeline.layer_sepc.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.pipeline.layer\_sepc
-===============================
-
-.. automodule:: colossalai.pipeline.layer_spec
- :members:
diff --git a/docs/colossalai/colossalai.pipeline.pipelinable.rst b/docs/colossalai/colossalai.pipeline.pipelinable.rst
deleted file mode 100644
index 5c2b02ba63e2..000000000000
--- a/docs/colossalai/colossalai.pipeline.pipelinable.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.pipeline.pipelinable
-===============================
-
-.. automodule:: colossalai.pipeline.pipelinable
- :members:
diff --git a/docs/colossalai/colossalai.pipeline.rst b/docs/colossalai/colossalai.pipeline.rst
deleted file mode 100644
index 6f7652d492e0..000000000000
--- a/docs/colossalai/colossalai.pipeline.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.pipeline
-===================
-
-.. automodule:: colossalai.pipeline
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.pipeline.layer_spec
- colossalai.pipeline.pipelinable
- colossalai.pipeline.utils
diff --git a/docs/colossalai/colossalai.pipeline.utils.rst b/docs/colossalai/colossalai.pipeline.utils.rst
deleted file mode 100644
index a33bf42cfc2b..000000000000
--- a/docs/colossalai/colossalai.pipeline.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.pipeline.utils
-=========================
-
-.. automodule:: colossalai.pipeline.utils
- :members:
diff --git a/docs/colossalai/colossalai.registry.registry.rst b/docs/colossalai/colossalai.registry.registry.rst
deleted file mode 100644
index e942d7969b60..000000000000
--- a/docs/colossalai/colossalai.registry.registry.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.registry.registry
-============================
-
-.. automodule:: colossalai.registry.registry
- :members:
diff --git a/docs/colossalai/colossalai.registry.rst b/docs/colossalai/colossalai.registry.rst
deleted file mode 100644
index 0f294f6d15a7..000000000000
--- a/docs/colossalai/colossalai.registry.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.registry
-===================
-
-.. automodule:: colossalai.registry
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.registry.registry
diff --git a/docs/colossalai/colossalai.rst b/docs/colossalai/colossalai.rst
deleted file mode 100644
index 921f15a97f00..000000000000
--- a/docs/colossalai/colossalai.rst
+++ /dev/null
@@ -1,36 +0,0 @@
-colossalai
-==========
-
-.. automodule:: colossalai
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.amp
- colossalai.builder
- colossalai.cli
- colossalai.communication
- colossalai.context
- colossalai.engine
- colossalai.fx
- colossalai.gemini
- colossalai.kernel
- colossalai.logging
- colossalai.nn
- colossalai.pipeline
- colossalai.registry
- colossalai.tensor
- colossalai.testing
- colossalai.trainer
- colossalai.utils
- colossalai.zero
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.constants
- colossalai.core
- colossalai.global_variables
- colossalai.initialize
diff --git a/docs/colossalai/colossalai.tensor.colo_parameter.rst b/docs/colossalai/colossalai.tensor.colo_parameter.rst
deleted file mode 100644
index 9b65029dbbe4..000000000000
--- a/docs/colossalai/colossalai.tensor.colo_parameter.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.colo\_parameter
-=================================
-
-.. automodule:: colossalai.tensor.colo_parameter
- :members:
diff --git a/docs/colossalai/colossalai.tensor.colo_tensor.rst b/docs/colossalai/colossalai.tensor.colo_tensor.rst
deleted file mode 100644
index 9161ac22f665..000000000000
--- a/docs/colossalai/colossalai.tensor.colo_tensor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.colo\_tensor
-==============================
-
-.. automodule:: colossalai.tensor.colo_tensor
- :members:
diff --git a/docs/colossalai/colossalai.tensor.compute_spec.rst b/docs/colossalai/colossalai.tensor.compute_spec.rst
deleted file mode 100644
index e2d7235d99c4..000000000000
--- a/docs/colossalai/colossalai.tensor.compute_spec.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.compute\_spec
-===============================
-
-.. automodule:: colossalai.tensor.compute_spec
- :members:
diff --git a/docs/colossalai/colossalai.tensor.const.rst b/docs/colossalai/colossalai.tensor.const.rst
deleted file mode 100644
index a22a2789349b..000000000000
--- a/docs/colossalai/colossalai.tensor.const.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.const
-=======================
-
-.. automodule:: colossalai.tensor.const
- :members:
diff --git a/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst b/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst
deleted file mode 100644
index 043cf22604a3..000000000000
--- a/docs/colossalai/colossalai.tensor.dist_spec_mgr.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.dist\_spec\_mgr
-=================================
-
-.. automodule:: colossalai.tensor.dist_spec_mgr
- :members:
diff --git a/docs/colossalai/colossalai.tensor.distspec.rst b/docs/colossalai/colossalai.tensor.distspec.rst
deleted file mode 100644
index 2b4b0e5fa266..000000000000
--- a/docs/colossalai/colossalai.tensor.distspec.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.distspec
-==========================
-
-.. automodule:: colossalai.tensor.distspec
- :members:
diff --git a/docs/colossalai/colossalai.tensor.op_wrapper.rst b/docs/colossalai/colossalai.tensor.op_wrapper.rst
deleted file mode 100644
index a246e0a6a548..000000000000
--- a/docs/colossalai/colossalai.tensor.op_wrapper.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.op\_wrapper
-=============================
-
-.. automodule:: colossalai.tensor.op_wrapper
- :members:
diff --git a/docs/colossalai/colossalai.tensor.param_op_hook.rst b/docs/colossalai/colossalai.tensor.param_op_hook.rst
deleted file mode 100644
index 475ada452bb2..000000000000
--- a/docs/colossalai/colossalai.tensor.param_op_hook.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.param\_op\_hook
-=================================
-
-.. automodule:: colossalai.tensor.param_op_hook
- :members:
diff --git a/docs/colossalai/colossalai.tensor.process_group.rst b/docs/colossalai/colossalai.tensor.process_group.rst
deleted file mode 100644
index b71409e3bd11..000000000000
--- a/docs/colossalai/colossalai.tensor.process_group.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.process\_group
-================================
-
-.. automodule:: colossalai.tensor.process_group
- :members:
diff --git a/docs/colossalai/colossalai.tensor.rst b/docs/colossalai/colossalai.tensor.rst
deleted file mode 100644
index 68e06552b873..000000000000
--- a/docs/colossalai/colossalai.tensor.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-colossalai.tensor
-=================
-
-.. automodule:: colossalai.tensor
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.tensor.colo_parameter
- colossalai.tensor.colo_tensor
- colossalai.tensor.compute_spec
- colossalai.tensor.const
- colossalai.tensor.dist_spec_mgr
- colossalai.tensor.distspec
- colossalai.tensor.op_wrapper
- colossalai.tensor.param_op_hook
- colossalai.tensor.process_group
- colossalai.tensor.tensor_spec
- colossalai.tensor.utils
diff --git a/docs/colossalai/colossalai.tensor.tensor_spec.rst b/docs/colossalai/colossalai.tensor.tensor_spec.rst
deleted file mode 100644
index 7125b9cbc28d..000000000000
--- a/docs/colossalai/colossalai.tensor.tensor_spec.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.tensor\_spec
-==============================
-
-.. automodule:: colossalai.tensor.tensor_spec
- :members:
diff --git a/docs/colossalai/colossalai.tensor.utils.rst b/docs/colossalai/colossalai.tensor.utils.rst
deleted file mode 100644
index 5d9bd1b03038..000000000000
--- a/docs/colossalai/colossalai.tensor.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.tensor.utils
-=======================
-
-.. automodule:: colossalai.tensor.utils
- :members:
diff --git a/docs/colossalai/colossalai.testing.comparison.rst b/docs/colossalai/colossalai.testing.comparison.rst
deleted file mode 100644
index bcfdf0598856..000000000000
--- a/docs/colossalai/colossalai.testing.comparison.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.testing.comparison
-=============================
-
-.. automodule:: colossalai.testing.comparison
- :members:
diff --git a/docs/colossalai/colossalai.testing.rst b/docs/colossalai/colossalai.testing.rst
deleted file mode 100644
index 1127aa52c1ad..000000000000
--- a/docs/colossalai/colossalai.testing.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.testing
-==================
-
-.. automodule:: colossalai.testing
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.testing.comparison
- colossalai.testing.utils
diff --git a/docs/colossalai/colossalai.testing.utils.rst b/docs/colossalai/colossalai.testing.utils.rst
deleted file mode 100644
index d8c2edcce71c..000000000000
--- a/docs/colossalai/colossalai.testing.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.testing.utils
-========================
-
-.. automodule:: colossalai.testing.utils
- :members:
diff --git a/docs/colossalai/colossalai.trainer.hooks.rst b/docs/colossalai/colossalai.trainer.hooks.rst
deleted file mode 100644
index 84cc6797b831..000000000000
--- a/docs/colossalai/colossalai.trainer.hooks.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.trainer.hooks
-========================
-
-.. automodule:: colossalai.trainer.hooks
- :members:
diff --git a/docs/colossalai/colossalai.trainer.rst b/docs/colossalai/colossalai.trainer.rst
deleted file mode 100644
index abc636e62373..000000000000
--- a/docs/colossalai/colossalai.trainer.rst
+++ /dev/null
@@ -1,10 +0,0 @@
-colossalai.trainer
-==================
-
-.. automodule:: colossalai.trainer
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.trainer.hooks
diff --git a/docs/colossalai/colossalai.utils.activation_checkpoint.rst b/docs/colossalai/colossalai.utils.activation_checkpoint.rst
deleted file mode 100644
index 671b5fe9e9c4..000000000000
--- a/docs/colossalai/colossalai.utils.activation_checkpoint.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.activation\_checkpoint
-=======================================
-
-.. automodule:: colossalai.utils.activation_checkpoint
- :members:
diff --git a/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst b/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst
deleted file mode 100644
index 237ad380b301..000000000000
--- a/docs/colossalai/colossalai.utils.checkpoint.module_checkpoint.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.checkpoint.module\_checkpoint
-==============================================
-
-.. automodule:: colossalai.utils.checkpoint.module_checkpoint
- :members:
diff --git a/docs/colossalai/colossalai.utils.checkpoint.rst b/docs/colossalai/colossalai.utils.checkpoint.rst
deleted file mode 100644
index 220c270f09b9..000000000000
--- a/docs/colossalai/colossalai.utils.checkpoint.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.utils.checkpoint
-===========================
-
-.. automodule:: colossalai.utils.checkpoint
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.checkpoint.module_checkpoint
- colossalai.utils.checkpoint.utils
diff --git a/docs/colossalai/colossalai.utils.checkpoint.utils.rst b/docs/colossalai/colossalai.utils.checkpoint.utils.rst
deleted file mode 100644
index 7fdeefd539fe..000000000000
--- a/docs/colossalai/colossalai.utils.checkpoint.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.checkpoint.utils
-=================================
-
-.. automodule:: colossalai.utils.checkpoint.utils
- :members:
diff --git a/docs/colossalai/colossalai.utils.checkpointing.rst b/docs/colossalai/colossalai.utils.checkpointing.rst
deleted file mode 100644
index 534a581d5364..000000000000
--- a/docs/colossalai/colossalai.utils.checkpointing.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.checkpointing
-==============================
-
-.. automodule:: colossalai.utils.checkpointing
- :members:
diff --git a/docs/colossalai/colossalai.utils.common.rst b/docs/colossalai/colossalai.utils.common.rst
deleted file mode 100644
index cb9f9c14ef4f..000000000000
--- a/docs/colossalai/colossalai.utils.common.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.common
-=======================
-
-.. automodule:: colossalai.utils.common
- :members:
diff --git a/docs/colossalai/colossalai.utils.cuda.rst b/docs/colossalai/colossalai.utils.cuda.rst
deleted file mode 100644
index ec428c5ef6ea..000000000000
--- a/docs/colossalai/colossalai.utils.cuda.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.cuda
-=====================
-
-.. automodule:: colossalai.utils.cuda
- :members:
diff --git a/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst
deleted file mode 100644
index 199e8fcf83c3..000000000000
--- a/docs/colossalai/colossalai.utils.data_sampler.base_sampler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.data\_sampler.base\_sampler
-============================================
-
-.. automodule:: colossalai.utils.data_sampler.base_sampler
- :members:
diff --git a/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst
deleted file mode 100644
index 85e1b121c682..000000000000
--- a/docs/colossalai/colossalai.utils.data_sampler.data_parallel_sampler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.data\_sampler.data\_parallel\_sampler
-======================================================
-
-.. automodule:: colossalai.utils.data_sampler.data_parallel_sampler
- :members:
diff --git a/docs/colossalai/colossalai.utils.data_sampler.rst b/docs/colossalai/colossalai.utils.data_sampler.rst
deleted file mode 100644
index 61dde070bad4..000000000000
--- a/docs/colossalai/colossalai.utils.data_sampler.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.utils.data\_sampler
-==============================
-
-.. automodule:: colossalai.utils.data_sampler
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.data_sampler.base_sampler
- colossalai.utils.data_sampler.data_parallel_sampler
diff --git a/docs/colossalai/colossalai.utils.memory.rst b/docs/colossalai/colossalai.utils.memory.rst
deleted file mode 100644
index 67c5d60022dd..000000000000
--- a/docs/colossalai/colossalai.utils.memory.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.memory
-=======================
-
-.. automodule:: colossalai.utils.memory
- :members:
diff --git a/docs/colossalai/colossalai.utils.model.colo_init_context.rst b/docs/colossalai/colossalai.utils.model.colo_init_context.rst
deleted file mode 100644
index 33ee44915083..000000000000
--- a/docs/colossalai/colossalai.utils.model.colo_init_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.model.colo\_init\_context
-==========================================
-
-.. automodule:: colossalai.utils.model.colo_init_context
- :members:
diff --git a/docs/colossalai/colossalai.utils.model.lazy_init_context.rst b/docs/colossalai/colossalai.utils.model.lazy_init_context.rst
deleted file mode 100644
index 27c9a32c6a7d..000000000000
--- a/docs/colossalai/colossalai.utils.model.lazy_init_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.model.lazy\_init\_context
-==========================================
-
-.. automodule:: colossalai.utils.model.lazy_init_context
- :members:
diff --git a/docs/colossalai/colossalai.utils.model.rst b/docs/colossalai/colossalai.utils.model.rst
deleted file mode 100644
index 9adfd1450a47..000000000000
--- a/docs/colossalai/colossalai.utils.model.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.utils.model
-======================
-
-.. automodule:: colossalai.utils.model
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.model.colo_init_context
- colossalai.utils.model.lazy_init_context
- colossalai.utils.model.utils
diff --git a/docs/colossalai/colossalai.utils.model.utils.rst b/docs/colossalai/colossalai.utils.model.utils.rst
deleted file mode 100644
index 211106662dc3..000000000000
--- a/docs/colossalai/colossalai.utils.model.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.model.utils
-============================
-
-.. automodule:: colossalai.utils.model.utils
- :members:
diff --git a/docs/colossalai/colossalai.utils.moe.rst b/docs/colossalai/colossalai.utils.moe.rst
deleted file mode 100644
index b66ccdc8ec2d..000000000000
--- a/docs/colossalai/colossalai.utils.moe.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.moe
-====================
-
-.. automodule:: colossalai.utils.moe
- :members:
diff --git a/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst b/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst
deleted file mode 100644
index 493b9530e0f6..000000000000
--- a/docs/colossalai/colossalai.utils.multi_tensor_apply.multi_tensor_apply.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.multi\_tensor\_apply.multi\_tensor\_apply
-==========================================================
-
-.. automodule:: colossalai.utils.multi_tensor_apply.multi_tensor_apply
- :members:
diff --git a/docs/colossalai/colossalai.utils.multi_tensor_apply.rst b/docs/colossalai/colossalai.utils.multi_tensor_apply.rst
deleted file mode 100644
index d5749cfa8801..000000000000
--- a/docs/colossalai/colossalai.utils.multi_tensor_apply.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.utils.multi\_tensor\_apply
-=====================================
-
-.. automodule:: colossalai.utils.multi_tensor_apply
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.multi_tensor_apply.multi_tensor_apply
diff --git a/docs/colossalai/colossalai.utils.profiler.extention.rst b/docs/colossalai/colossalai.utils.profiler.extention.rst
deleted file mode 100644
index 5c87692611a0..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.extention.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.extention
-===================================
-
-.. automodule:: colossalai.utils.profiler.extention
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst
deleted file mode 100644
index 4329a3d60da3..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.legacy.comm_profiler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.legacy.comm\_profiler
-===============================================
-
-.. automodule:: colossalai.utils.profiler.legacy.comm_profiler
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst
deleted file mode 100644
index 35c665c71d3b..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.legacy.mem_profiler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.legacy.mem\_profiler
-==============================================
-
-.. automodule:: colossalai.utils.profiler.legacy.mem_profiler
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst b/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst
deleted file mode 100644
index 7aa82b8f7a4f..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.legacy.pcie_profiler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.legacy.pcie\_profiler
-===============================================
-
-.. automodule:: colossalai.utils.profiler.legacy.pcie_profiler
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst b/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst
deleted file mode 100644
index 93af82b2fabb..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.legacy.prof_utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.legacy.prof\_utils
-============================================
-
-.. automodule:: colossalai.utils.profiler.legacy.prof_utils
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.legacy.rst b/docs/colossalai/colossalai.utils.profiler.legacy.rst
deleted file mode 100644
index 37fcebde5a43..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.legacy.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-colossalai.utils.profiler.legacy
-================================
-
-.. automodule:: colossalai.utils.profiler.legacy
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.profiler.legacy.comm_profiler
- colossalai.utils.profiler.legacy.mem_profiler
- colossalai.utils.profiler.legacy.pcie_profiler
- colossalai.utils.profiler.legacy.prof_utils
diff --git a/docs/colossalai/colossalai.utils.profiler.profiler.rst b/docs/colossalai/colossalai.utils.profiler.profiler.rst
deleted file mode 100644
index d35522837801..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.profiler.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.profiler
-==================================
-
-.. automodule:: colossalai.utils.profiler.profiler
- :members:
diff --git a/docs/colossalai/colossalai.utils.profiler.rst b/docs/colossalai/colossalai.utils.profiler.rst
deleted file mode 100644
index 15681fcf2d82..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.rst
+++ /dev/null
@@ -1,18 +0,0 @@
-colossalai.utils.profiler
-=========================
-
-.. automodule:: colossalai.utils.profiler
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.profiler.legacy
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.profiler.extention
- colossalai.utils.profiler.profiler
- colossalai.utils.profiler.stateful_tensor_mem_extention
diff --git a/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst b/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst
deleted file mode 100644
index 72a3fcceca18..000000000000
--- a/docs/colossalai/colossalai.utils.profiler.stateful_tensor_mem_extention.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.profiler.stateful\_tensor\_mem\_extention
-==========================================================
-
-.. automodule:: colossalai.utils.profiler.stateful_tensor_mem_extention
- :members:
diff --git a/docs/colossalai/colossalai.utils.rst b/docs/colossalai/colossalai.utils.rst
deleted file mode 100644
index 8b232a12c245..000000000000
--- a/docs/colossalai/colossalai.utils.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-colossalai.utils
-================
-
-.. automodule:: colossalai.utils
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.checkpoint
- colossalai.utils.data_sampler
- colossalai.utils.model
- colossalai.utils.multi_tensor_apply
- colossalai.utils.profiler
- colossalai.utils.tensor_detector
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.activation_checkpoint
- colossalai.utils.checkpointing
- colossalai.utils.common
- colossalai.utils.cuda
- colossalai.utils.memory
- colossalai.utils.moe
- colossalai.utils.timer
diff --git a/docs/colossalai/colossalai.utils.tensor_detector.rst b/docs/colossalai/colossalai.utils.tensor_detector.rst
deleted file mode 100644
index 807d67e3ad1e..000000000000
--- a/docs/colossalai/colossalai.utils.tensor_detector.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.utils.tensor\_detector
-=================================
-
-.. automodule:: colossalai.utils.tensor_detector
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.utils.tensor_detector.tensor_detector
diff --git a/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst b/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst
deleted file mode 100644
index 991cea3438b3..000000000000
--- a/docs/colossalai/colossalai.utils.tensor_detector.tensor_detector.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.tensor\_detector.tensor\_detector
-==================================================
-
-.. automodule:: colossalai.utils.tensor_detector.tensor_detector
- :members:
diff --git a/docs/colossalai/colossalai.utils.timer.rst b/docs/colossalai/colossalai.utils.timer.rst
deleted file mode 100644
index 2014c85f548f..000000000000
--- a/docs/colossalai/colossalai.utils.timer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.utils.timer
-======================
-
-.. automodule:: colossalai.utils.timer
- :members:
diff --git a/docs/colossalai/colossalai.zero.init_ctx.init_context.rst b/docs/colossalai/colossalai.zero.init_ctx.init_context.rst
deleted file mode 100644
index 1694074e83bf..000000000000
--- a/docs/colossalai/colossalai.zero.init_ctx.init_context.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.init\_ctx.init\_context
-=======================================
-
-.. automodule:: colossalai.zero.init_ctx.init_context
- :members:
diff --git a/docs/colossalai/colossalai.zero.init_ctx.rst b/docs/colossalai/colossalai.zero.init_ctx.rst
deleted file mode 100644
index 88cf471df9d3..000000000000
--- a/docs/colossalai/colossalai.zero.init_ctx.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.zero.init\_ctx
-=========================
-
-.. automodule:: colossalai.zero.init_ctx
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.init_ctx.init_context
diff --git a/docs/colossalai/colossalai.zero.rst b/docs/colossalai/colossalai.zero.rst
deleted file mode 100644
index 3bcaffd28d05..000000000000
--- a/docs/colossalai/colossalai.zero.rst
+++ /dev/null
@@ -1,21 +0,0 @@
-colossalai.zero
-===============
-
-.. automodule:: colossalai.zero
- :members:
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.init_ctx
- colossalai.zero.shard_utils
- colossalai.zero.sharded_model
- colossalai.zero.sharded_optim
- colossalai.zero.sharded_param
- colossalai.zero.utils
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.zero_optimizer
diff --git a/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst
deleted file mode 100644
index d5b59e06a517..000000000000
--- a/docs/colossalai/colossalai.zero.shard_utils.base_shard_strategy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.shard\_utils.base\_shard\_strategy
-==================================================
-
-.. automodule:: colossalai.zero.shard_utils.base_shard_strategy
- :members:
diff --git a/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst
deleted file mode 100644
index 952c5bbddf09..000000000000
--- a/docs/colossalai/colossalai.zero.shard_utils.bucket_tensor_shard_strategy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.shard\_utils.bucket\_tensor\_shard\_strategy
-============================================================
-
-.. automodule:: colossalai.zero.shard_utils.bucket_tensor_shard_strategy
- :members:
diff --git a/docs/colossalai/colossalai.zero.shard_utils.commons.rst b/docs/colossalai/colossalai.zero.shard_utils.commons.rst
deleted file mode 100644
index aa6682d79ff2..000000000000
--- a/docs/colossalai/colossalai.zero.shard_utils.commons.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.shard\_utils.commons
-====================================
-
-.. automodule:: colossalai.zero.shard_utils.commons
- :members:
diff --git a/docs/colossalai/colossalai.zero.shard_utils.rst b/docs/colossalai/colossalai.zero.shard_utils.rst
deleted file mode 100644
index 580bfdab7d85..000000000000
--- a/docs/colossalai/colossalai.zero.shard_utils.rst
+++ /dev/null
@@ -1,14 +0,0 @@
-colossalai.zero.shard\_utils
-============================
-
-.. automodule:: colossalai.zero.shard_utils
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.shard_utils.base_shard_strategy
- colossalai.zero.shard_utils.bucket_tensor_shard_strategy
- colossalai.zero.shard_utils.commons
- colossalai.zero.shard_utils.tensor_shard_strategy
diff --git a/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst b/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst
deleted file mode 100644
index 571b7bd7a588..000000000000
--- a/docs/colossalai/colossalai.zero.shard_utils.tensor_shard_strategy.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.shard\_utils.tensor\_shard\_strategy
-====================================================
-
-.. automodule:: colossalai.zero.shard_utils.tensor_shard_strategy
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst b/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst
deleted file mode 100644
index cf861ee70aa0..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_model.reduce_scatter.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_model.reduce\_scatter
-==============================================
-
-.. automodule:: colossalai.zero.sharded_model.reduce_scatter
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_model.rst b/docs/colossalai/colossalai.zero.sharded_model.rst
deleted file mode 100644
index fb3f5a8456d0..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_model.rst
+++ /dev/null
@@ -1,13 +0,0 @@
-colossalai.zero.sharded\_model
-==============================
-
-.. automodule:: colossalai.zero.sharded_model
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.sharded_model.reduce_scatter
- colossalai.zero.sharded_model.sharded_model_v2
- colossalai.zero.sharded_model.utils
diff --git a/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst b/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst
deleted file mode 100644
index a0e191377914..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_model.sharded_model_v2.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_model.sharded\_model\_v2
-=================================================
-
-.. automodule:: colossalai.zero.sharded_model.sharded_model_v2
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_model.utils.rst b/docs/colossalai/colossalai.zero.sharded_model.utils.rst
deleted file mode 100644
index 5e376774296f..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_model.utils.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_model.utils
-====================================
-
-.. automodule:: colossalai.zero.sharded_model.utils
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_optim.rst b/docs/colossalai/colossalai.zero.sharded_optim.rst
deleted file mode 100644
index db3dfdddbab4..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_optim.rst
+++ /dev/null
@@ -1,11 +0,0 @@
-colossalai.zero.sharded\_optim
-==============================
-
-.. automodule:: colossalai.zero.sharded_optim
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.sharded_optim.sharded_optim_v2
diff --git a/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst b/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst
deleted file mode 100644
index 01fbe0c4c031..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_optim.sharded_optim_v2.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_optim.sharded\_optim\_v2
-=================================================
-
-.. automodule:: colossalai.zero.sharded_optim.sharded_optim_v2
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_param.rst b/docs/colossalai/colossalai.zero.sharded_param.rst
deleted file mode 100644
index 02e0fc6c29eb..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_param.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.zero.sharded\_param
-==============================
-
-.. automodule:: colossalai.zero.sharded_param
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.sharded_param.sharded_param
- colossalai.zero.sharded_param.sharded_tensor
diff --git a/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst b/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst
deleted file mode 100644
index efa2f0de379c..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_param.sharded_param.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_param.sharded\_param
-=============================================
-
-.. automodule:: colossalai.zero.sharded_param.sharded_param
- :members:
diff --git a/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst b/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst
deleted file mode 100644
index 930c28de4542..000000000000
--- a/docs/colossalai/colossalai.zero.sharded_param.sharded_tensor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.sharded\_param.sharded\_tensor
-==============================================
-
-.. automodule:: colossalai.zero.sharded_param.sharded_tensor
- :members:
diff --git a/docs/colossalai/colossalai.zero.utils.rst b/docs/colossalai/colossalai.zero.utils.rst
deleted file mode 100644
index 50ee9071e7d5..000000000000
--- a/docs/colossalai/colossalai.zero.utils.rst
+++ /dev/null
@@ -1,12 +0,0 @@
-colossalai.zero.utils
-=====================
-
-.. automodule:: colossalai.zero.utils
- :members:
-
-
-.. toctree::
- :maxdepth: 2
-
- colossalai.zero.utils.zero_hook
- colossalai.zero.utils.gemini_hook
diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook.rst b/docs/colossalai/colossalai.zero.utils.zero_hook.rst
deleted file mode 100644
index 424f466dd4f5..000000000000
--- a/docs/colossalai/colossalai.zero.utils.zero_hook.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.utils.zero\_hook
-================================
-
-.. automodule:: colossalai.zero.utils.zero_hook
- :members:
diff --git a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst b/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst
deleted file mode 100644
index e6d6673af131..000000000000
--- a/docs/colossalai/colossalai.zero.utils.zero_hook_v2.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.utils.zero\_hook\_v2
-====================================
-
-.. automodule:: colossalai.zero.utils.gemini_hook
- :members:
diff --git a/docs/colossalai/colossalai.zero.zero_optimizer.rst b/docs/colossalai/colossalai.zero.zero_optimizer.rst
deleted file mode 100644
index b945b081c866..000000000000
--- a/docs/colossalai/colossalai.zero.zero_optimizer.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-colossalai.zero.zero\_optimizer
-===============================
-
-.. automodule:: colossalai.zero.zero_optimizer
- :members:
diff --git a/docs/conda-doc-test-deps.yml b/docs/conda-doc-test-deps.yml
new file mode 100644
index 000000000000..74a232214adc
--- /dev/null
+++ b/docs/conda-doc-test-deps.yml
@@ -0,0 +1,2 @@
+dependencies:
+ - cmake
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index 893644f709d4..000000000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-import datetime
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-import os
-import sys
-
-sys.path.insert(0, os.path.abspath('..'))
-
-# -- Project information -----------------------------------------------------
-
-project = 'Colossal-AI'
-copyright = f'{datetime.datetime.now().year}, HPC-AI Tech'
-author = 'HPC-AI Technology Inc.'
-
-# The full version, including alpha/beta/rc tags
-release = '0.0.1'
-
-
-# -- General configuration ---------------------------------------------------
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.linkcode',
- 'myst_parser',
-]
-
-# Disable docstring inheritance
-autodoc_inherit_docstrings = False
-
-# Disable displaying type annotations, these can be very verbose
-autodoc_typehints = 'none'
-
-# Enable overriding of function signatures in the first line of the docstring.
-autodoc_docstring_signature = True
-autodoc_default_options = {
- 'member-order': 'bysource',
-}
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['.build', 'Thumbs.db', '.DS_Store']
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = 'sphinx_rtd_theme'
-html_show_sourcelink = False
-html_theme_options = {
- 'navigation_depth': 3,
-}
-
-html_context = {
- 'display_github': False,
- 'github_user': 'hpcaitech',
- 'github_repo': 'ColossalAI',
- # 'github_version': 'master/docs/',
-}
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
-
-html_css_files = [
- 'css/rtd_theme.css',
-]
-
-# -- Extension configuration -------------------------------------------------
-source_suffix = ['.rst', '.md', '.MD']
-
-import inspect
-import colossalai
-def linkcode_resolve(domain, info):
- """
- Determine the URL corresponding to Python object
- """
- if domain != 'py':
- return None
-
- modname = info['module']
- fullname = info['fullname']
-
- submod = sys.modules.get(modname)
- if submod is None:
- return None
-
- obj = submod
- for part in fullname.split('.'):
- try:
- obj = getattr(obj, part)
- except Exception:
- return None
-
- try:
- fn = inspect.getsourcefile(obj)
- except Exception:
- fn = None
- if not fn:
- return None
-
- try:
- source, lineno = inspect.findsource(obj)
- except Exception:
- lineno = None
-
- if lineno:
- linespec = "#L%d" % (lineno + 1)
- else:
- linespec = ""
-
- fn = os.path.relpath(fn, start=os.path.dirname(colossalai.__file__))
-
- github = "https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/{}{}"
- return github.format(fn, linespec)
diff --git a/docs/index.rst b/docs/index.rst
deleted file mode 100644
index f275f7829403..000000000000
--- a/docs/index.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-.. Colossal-AI documentation master file, created by
- sphinx-quickstart on Mon Oct 11 17:05:05 2021.
- You can adapt this file completely to your liking, but it should at least
- contain the root `toctree` directive.
-
-Colossal-AI API documentation
-======================================
-
-.. toctree::
- :maxdepth: 2
- :caption: API REFERENCE
-
- colossalai/colossalai
-
-.. toctree::
- :maxdepth: 2
- :caption: Useful links for Colossal-AI
-
- links/Colossalai examples
- links/Colossalai benchmarks
- links/Colossalai tutorial
-
-
-Indices and tables
---------------------
-
-* :ref:`genindex`
diff --git a/docs/links/Colossalai Homepage.rst b/docs/links/Colossalai Homepage.rst
deleted file mode 100644
index 38e223bd22c9..000000000000
--- a/docs/links/Colossalai Homepage.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-Colossal-AI Github Homepage
-==================================
-
-*If you are looking for the Git homepage of Colossal-AI, please check*
-`Colossal-AI Tutorial `_
-*for our source code.*
\ No newline at end of file
diff --git a/docs/links/Colossalai benchmarks.rst b/docs/links/Colossalai benchmarks.rst
deleted file mode 100644
index 1835670a5f2a..000000000000
--- a/docs/links/Colossalai benchmarks.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-Colossal-AI Benchmarks
-==================================
-
-*If you are interested in the performance or the features of Colossal-AI, please check*
-`Colossal-AI Benchmark `_.
-*to get more details about our performance on CIFAR10, ImageNet1K or GPT2 ZeRO.*
\ No newline at end of file
diff --git a/docs/links/Colossalai examples.rst b/docs/links/Colossalai examples.rst
deleted file mode 100644
index c375f007a3ff..000000000000
--- a/docs/links/Colossalai examples.rst
+++ /dev/null
@@ -1,6 +0,0 @@
-Colossal-AI Examples
-==================================
-
-*If you are looking for the example code of using Colossal-AI in CV or NLP, please check*
-`Colossal-AI Example `_
-*to get more details about using colossalai in Resnet, Moe, Vit, Bert and GPT*
\ No newline at end of file
diff --git a/docs/links/Colossalai tutorial.rst b/docs/links/Colossalai tutorial.rst
deleted file mode 100644
index a4ab7f5b906b..000000000000
--- a/docs/links/Colossalai tutorial.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-Colossal-AI Tutorial
-==================================
-
-*If you are looking for the tutorial of using Colossal-AI, please check*
-`Colossal-AI Tutorial `_
-*to get more details about getting started, using TP (tensor parallel), PP (pipeline parallel)
-and training with colossalai trainer or engine.*
\ No newline at end of file
diff --git a/docs/make.bat b/docs/make.bat
deleted file mode 100644
index cf73214110f2..000000000000
--- a/docs/make.bat
+++ /dev/null
@@ -1,35 +0,0 @@
-@ECHO OFF
-
-pushd %~dp0
-
-REM Command file for Sphinx documentation
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set SOURCEDIR=.
-set BUILDDIR=.build
-
-if "%1" == "" goto help
-
-%SPHINXBUILD% >NUL 2>NUL
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.If you don't have Sphinx installed, grab it from
- echo.https://www.sphinx-doc.org/
- exit /b 1
-)
-
-%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-goto end
-
-:help
-%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
-
-:end
-popd
diff --git a/docs/requirements-doc-test.txt b/docs/requirements-doc-test.txt
new file mode 100644
index 000000000000..6a6bb3bee9b0
--- /dev/null
+++ b/docs/requirements-doc-test.txt
@@ -0,0 +1,6 @@
+colossalai
+torch
+packaging
+tensornvme
+psutil
+transformers
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index 2b3b1a25bca4..000000000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-tensorboard
-apex
-sphinx
-sphinx-rtd-theme
-myst-parser
diff --git a/docs/sidebars.json b/docs/sidebars.json
new file mode 100644
index 000000000000..44287c17eadf
--- /dev/null
+++ b/docs/sidebars.json
@@ -0,0 +1,80 @@
+{
+ "tutorialSidebar": [
+ {
+ "type": "category",
+ "label": "Get started",
+ "collapsed": true,
+ "items": [
+ "get_started/installation",
+ "get_started/run_demo",
+ "get_started/reading_roadmap"
+ ]
+ },
+ {
+ "type": "category",
+ "label": "Concepts",
+ "collapsed": true,
+ "items": [
+ "concepts/distributed_training",
+ "concepts/paradigms_of_parallelism",
+ "concepts/colossalai_overview"
+ ]
+ },
+ {
+ "type": "category",
+ "label": "Basics",
+ "collapsed": true,
+ "items": [
+ "basics/command_line_tool",
+ "basics/define_your_config",
+ "basics/launch_colossalai",
+ "basics/initialize_features",
+ "basics/engine_trainer",
+ "basics/configure_parallelization",
+ "basics/model_checkpoint",
+ "basics/colotensor_concept"
+ ]
+ },
+ {
+ "type": "category",
+ "label": "Features",
+ "collapsed": true,
+ "items": [
+ "features/mixed_precision_training",
+ "features/gradient_accumulation",
+ "features/gradient_clipping",
+ "features/gradient_handler",
+ "features/zero_with_chunk",
+ {
+ "type": "category",
+ "label": "Tensor Parallel",
+ "collapsed": true,
+ "items": [
+ "features/1D_tensor_parallel",
+ "features/2D_tensor_parallel",
+ "features/2p5D_tensor_parallel",
+ "features/3D_tensor_parallel"
+ ]
+ },
+ "features/pipeline_parallel",
+ "features/nvme_offload"
+ ]
+ },
+ {
+ "type": "category",
+ "label": "Advanced Tutorials",
+ "collapsed": true,
+ "items": [
+ "advanced_tutorials/train_vit_using_pipeline_parallelism",
+ "advanced_tutorials/train_vit_with_hybrid_parallelism",
+ "advanced_tutorials/train_gpt_using_hybrid_parallelism",
+ "advanced_tutorials/define_your_own_parallel_model",
+ "advanced_tutorials/add_your_parallel",
+ "advanced_tutorials/meet_gemini",
+ "advanced_tutorials/parallelize_your_training_like_Megatron",
+ "advanced_tutorials/integrate_mixture_of_experts_into_your_model",
+ "advanced_tutorials/opt_service"
+ ]
+ }
+ ]
+}
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py b/docs/source/en/Colossal-Auto/feature/auto_checkpoint.md
similarity index 100%
rename from examples/tutorial/stable_diffusion/ldm/modules/encoders/__init__.py
rename to docs/source/en/Colossal-Auto/feature/auto_checkpoint.md
diff --git a/docs/source/en/Colossal-Auto/feature/device_mesh.md b/docs/source/en/Colossal-Auto/feature/device_mesh.md
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/docs/source/en/Colossal-Auto/feature/layout_converting_management.md b/docs/source/en/Colossal-Auto/feature/layout_converting_management.md
new file mode 100644
index 000000000000..2082a33d8a39
--- /dev/null
+++ b/docs/source/en/Colossal-Auto/feature/layout_converting_management.md
@@ -0,0 +1,13 @@
+When a tensor is required to have different sharding specs in upstream and downstream operators, we need to perform layout conversion processing, which can also be called redistribution. There are currently two mainstream methods, enumeration conversion, and dimension-by-dimension conversion. enumeration conversion is to enumerate all possible situations, and then find the corresponding conversion scheme in the table when conversion is required. However, it has a big problem. That is, as the dimension of the device mesh increases, the scale of this problem is so inflated that it cannot be solved by enumerating tables. Dimension-by-dimension conversion is for a sharding spec of an N-D tensor, X0X1...Xn-1, sharding spec is converted from 0 to n-1 dimension by dimension, so that no matter how many dimensions the device mesh and tensor have, with only one-time Scanning, a feasible conversion operation sequence is generated, the problem is that the conversion efficiency will be very poor.
+
+Therefore, we propose a novel algorithm, using heuristic search, to solve the conversion problem of sharding spec, which can be described as:
+1. Generate all one-step transform sharding specs from source spec
+2. In the one-step transform sharding specs, according to the similarity function, select a sharding spec with the "least difference" as the subsequent source sharding spec, and record the sharding spec in the transform path. If a sharding spec of the one-step transforms is the same as the target sharding spec, the algorithm ends.
+3. Repeat 1, 2 until the end of the algorithm
+
+
+| Source/target sharding spec pairs |All gather | Shard | All to All | One step transform | Best sharding spec |Transform path|
+| :-: | :-: | :-: | :-: | :-: | :-: |:-: |
+| $S_{01}RR, RS_{01}R$ | $S_0RR$ | - | $S_0RS_1, S_0S_1R$ | $S_0RR, S_0RS_1, S_0S_1R$ | $S_0RR$ | $S_0RR$
+| $S_0RR, RS_{01}RR$ | $RRR$ | $S_0S_1R, S_0RS_1$ | $RS_0R, RRS_0$ | $RRR$, $S_0S_1R$, $S_0RS_1$, $RS_0R$, $RRS_0$ | $RS_0R$ | $S_0RR$ -> $RS_0R$
+| $RS_0R, RS_{01}RR$ | $RRR$ | $RS_{01}R, S_1S_0R, RS_0S_1$ | $S_0RR, RRS_0$ | $RRR$, $RS_{01}R$, $S_1S_0R$, $RS_0S_1$, $S_0RR$, $RRS_0$ | $RS_{01}R$ | $S_0RR$ -> $RS_0R$ -> $RS_{01}R$
diff --git a/docs/source/en/Colossal-Auto/feature/tracer.md b/docs/source/en/Colossal-Auto/feature/tracer.md
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/docs/source/en/Colossal-Auto/get_started/installation.md b/docs/source/en/Colossal-Auto/get_started/installation.md
new file mode 100644
index 000000000000..d2a532bfa7b0
--- /dev/null
+++ b/docs/source/en/Colossal-Auto/get_started/installation.md
@@ -0,0 +1,27 @@
+# Setup
+
+## Announcement
+
+Our auto-parallel feature is a alpha version. It is still under development. We will keep updating it and make it more stable. If you encounter any problem, please feel free to raise an issue.
+
+## Requirements
+
+We need some extra dependencies to support auto-parallel. Please install them before using auto-parallel.
+
+### Install PyTorch
+
+We only support PyTorch 1.12 now, other versions are not tested. We will support more versions in the future.
+
+```bash
+#conda
+conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
+#pip
+pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+### Install pulp and coin-or-cbc
+
+```bash
+pip install pulp
+conda install -c conda-forge coin-or-cbc
+```
diff --git a/docs/source/en/Colossal-Auto/get_started/introduction.md b/docs/source/en/Colossal-Auto/get_started/introduction.md
new file mode 100644
index 000000000000..a2606dd2bf9f
--- /dev/null
+++ b/docs/source/en/Colossal-Auto/get_started/introduction.md
@@ -0,0 +1,44 @@
+# Introduction
+
+In recent years, the deployment of large-scale machine learning models has become increasingly important. However, distributed training systems often require **manual parallelization plans**, which can be complex and require expert knowledge in system engineering and configuration. This can be a challenge for most AI developers without the necessary skills. The need for manual parallelization can make deploying large-scale machine learning models difficult and expensive.
+
+**Colossal-Auto** simplifies the process of deploying large-scale machine learning models for AI developers. Compared to other solutions that require manual configuration of complex parallel policies and model modification, Colossal-Auto only requires one line of code from the user, along with cluster information and model configurations, to enable distributed training. Technically, It seamlessly **integrates with popular AI model frameworks like Hugging Face and Timm.**
+
+
+
+## Overview
+
+
+
+
+
+
+## Usage
+
+```python
+# wrap the model using auto_engine
+model = autoparallelize(model, meta_input_samples)
+# normal training loop
+...
+```
+
+
+## Graph Tracing
+
+Colossal-Auto is **the first auto-parallelism system** that uses static graph analysis based on the PyTorch framework. Obtaining a static execution plan for PyTorch, a dynamic graph framework, has long been an area of research in the field of machine learning systems. Colossal-Auto uses ColoTracer, a forked version of the torch.FX Tracer, to guide the search for an optimal parallelization strategy. The meta-information of each tensor, such as tensor shape, dims, dtype, etc., is computed and recorded during the tracing process. This approach has the advantage of better generalization, as it is not tied to specific models or configurations.
+
+
+
+## Fine-grained Parallelism Search
+We investigate and research a number of current automatic parallel systems( Tofu , Flexflow , Alpa ) and some auto activation checkpoint algorithms( Rotor , Sublinear ). Inspired from these advanced systems, we build Colossal-Auto which is an automatic parallel system upon PyTorch framework. Colossal-Auto searches for strategies in regard to each operand with the goal of achieving the fastest runtime while meeting memory budget constraints. It ultimately determines the actual training time strategy, including the tensor split strategy for each tensor, the type of communication operators to be inserted between different computing nodes, whether to replace operators, etc. The tensor, data, and hybrid parallelism such as column and row split used by NVIDIA in Megatron-LM and other parallelism systems are all subsets of strategies that can be searched by Colossal-AI. In addition to these parallelisms that can be manually specified, Colossal-AI can specify a unique parallelism method for each operation and, potentially finding a better parallelism strategy than what human experts could provide.
+
+
+
+## Distributed Tensor and Shape-Consistency System
+
+The Colossal-AI system uses a device-mesh, similar to PyTorch's latest DTensor release, to manage its cluster. Colossal-AI uses a sharding-spec to annotate the storage status of each tensor and facilitate their distribution across the cluster. The system also employs a shape-consistency manager to automatically transform tensors between different sharding-specs, allowing for seamless slicing and dicing of tensors, while the shape-consistency manager ensures that the output of upstream operands is consistently stored in the cluster, regardless of how the input of downstream operands is stored. This makes Colossal-AI highly versatile and easy to use without users worrying about the storage status of tensors when performing operations on them.
+
+Here are some key advantages of Colossal-AI compared to PyTorch DTensor:
+Colossal-AI's device-mesh uses cluster performance metrics and profiling results to estimate the time consumption of different communication operators. This helps Colossal-AI optimize communication between nodes and improve overall system efficiency.
+Colossal-AI's shape-consistency manager uses a greedy search algorithm to find relatively efficient ways to transform tensors between different sharding-specs, rather than simply transforming dimensions one by one. This can lead to more efficient and effective transformations.
+The integration of all-to-all operations in Colossal-AI increases the scalability of the system by enabling more efficient communication between nodes. This is especially useful for large-scale machine learning tasks that require the transfer of large amounts of data between nodes.
diff --git a/docs/source/en/Colossal-Auto/get_started/run_demo.md b/docs/source/en/Colossal-Auto/get_started/run_demo.md
new file mode 100644
index 000000000000..6f7a82966f20
--- /dev/null
+++ b/docs/source/en/Colossal-Auto/get_started/run_demo.md
@@ -0,0 +1,13 @@
+# Quick Demo
+
+Colossal-Auto simplifies the process of deploying large-scale machine learning models for AI developers. Compared to other solutions that require manual configuration of complex parallel policies and model modification, Colossal-Auto only requires one line of code from the user, along with cluster information and model configurations, to enable distributed training. Quick demos showing how to use Colossal-Auto are given below.
+
+### 1. Basic usage
+
+Colossal-Auto can be used to find a hybrid SPMD parallel strategy includes data, tensor(i.e., 1D, 2D, sequencial) for each operation. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/experiments/auto_parallel).
+Detailed instructions can be found in its `README.md`.
+
+### 2. Integration with activation checkpoint
+
+Colossal-Auto's automatic search function for activation checkpointing finds the most efficient checkpoint within a given memory budget, rather than just aiming for maximum memory compression. To avoid a lengthy search process for an optimal activation checkpoint, Colossal-Auto has implemented a two-stage search process. This allows the system to find a feasible distributed training solution in a reasonable amount of time while still benefiting from activation checkpointing for memory management. The integration of activation checkpointing in Colossal-AI improves the efficiency and effectiveness of large model training. You can follow the [Resnet example](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel).
+Detailed instructions can be found in its `README.md`.
diff --git a/docs/source/en/advanced_tutorials/add_your_parallel.md b/docs/source/en/advanced_tutorials/add_your_parallel.md
new file mode 100644
index 000000000000..be7284a7ab64
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/add_your_parallel.md
@@ -0,0 +1,124 @@
+# Add Your Own Parallel Mode
+
+Author: Shenggui Li, Yongbin Li
+
+**Prerequisite:**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+
+## Introduction
+
+To enable researchers and engineers to extend our system to other novel large-scale distributed training algorithm
+with less effort, we have decoupled various components in the training lifecycle. You can implement your own
+parallelism by simply inheriting from the base class.
+
+The main components are:
+
+1. `ProcessGroupInitializer`
+2. `GradientHandler`
+3. `Schedule`
+
+**This currently requires some code to the source code, thus we recommend that you install from source with the `-e` flag.
+`-e` flag makes the installation editable, thus, your code change will be reflected in your Python runtime.
+We will work on this to avoid change to source code in future releases.**
+
+
+## Process Group Initializer
+
+Parallelism is often managed by process groups where processes involved in the same parallel algorithm are placed in the same
+process group. For different parallel algorithms, different process groups need to be created. Colossal-AI provides a
+global context for users to easily manage their process groups. If you wish to add new process group, you can easily
+define a new class and set it in your configuration file. To define your own way of creating process groups, you can
+follow the steps below to create a new distributed initialization.
+
+1. Add your parallel mode in `colossalai.context.parallel_mode.ParallelMode`.
+ ```python
+ class ParallelMode(Enum):
+ GLOBAL = 'global'
+ DATA = 'data'
+ PIPELINE = 'pipe'
+ ...
+
+ NEW_MODE = 'new_mode' # define your mode here
+ ```
+
+2. Create a `ProcessGroupInitializer`. You can refer to examples given in `colossalai.context.dist_group_initializer`. The
+ first six arguments are fixed. `ParallelContext` will pass in these arguments for you. If you need to set other
+ arguments, you can add it behind like the `arg1, arg2` in the example below. Lastly, register your initializer to the
+ registry by adding the decorator `@DIST_GROUP_INITIALIZER.register_module`.
+ ```python
+ # sample initializer class
+ @DIST_GROUP_INITIALIZER.register_module
+ class MyParallelInitializer(ProcessGroupInitializer):
+
+ def __init__(self,
+ rank: int,
+ world_size: int,
+ config: Config,
+ data_parallel_size: int,
+ pipeline_parlalel_size: int,
+ tensor_parallel_size: int,
+ arg1,
+ arg2):
+ super().__init__(rank, world_size, config)
+ self.arg1 = arg1
+ self.arg2 = arg2
+ # ... your variable init
+
+ def init_parallel_groups(self):
+ # initialize your process groups
+ pass
+
+ ```
+
+ Then, you can insert your new initializer to the current mode-to-initialize mapping
+ in `colossalai.constants.INITIALIZER_MAPPING`. You can modify the file or insert new key-value pair dynamically.
+
+ ```python
+ colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer'
+ ```
+
+3. Set your initializer in your config file. You can pass in your own arguments if there is any. This allows
+ the `ParallelContext` to create your initializer and initialize your desired process groups.
+
+ ```python
+ parallel = dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode
+ )
+ ```
+
+## Gradient Handler
+
+Gradient handlers are objects which execute the all-reduce operations on parameters' gradients. As different all-reduce
+strategies may be executed for different kinds of parallelism, users can
+inherit `colossalai.engine.gradient_handler.BaseGradientHandler` to implement their strategies. Currently, the library
+uses the normal data parallel gradient handler which all-reduces the gradients across data parallel ranks. The data
+parallel gradient handler is added to the engine automatically if data parallel is detected. You can add your own
+gradient handler like below:
+
+```python
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.engine import BaseGradientHandler
+
+@GRADIENT_HANDLER.register_module
+class YourGradientHandler(BaseGradientHandler):
+
+ def handle_gradient(self):
+ do_something()
+
+```
+
+Afterwards, you can specify the gradient handler you want to use in your configuration file.
+
+```python
+gradient_handlers = [
+ dict(type='YourGradientHandler'),
+]
+```
+
+## Schedule
+
+Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
+schedules. If you want to modify how the forward and backward passes are executed, you can
+inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_back_step` function.
diff --git a/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md
new file mode 100644
index 000000000000..8e48737d2f64
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/define_your_own_parallel_model.md
@@ -0,0 +1,36 @@
+# Define your own parallel model
+
+Author: Zhengda Bian, Yongbin Li
+
+> ⚠️ We are working on this documentation to make it more detailed. We will introduce the mechanism of different parallelism
+> and how to use them to write a model.
+
+Let's say that you have a huge MLP model with billions of parameters and its extremely large hidden layer size makes it
+impossible to fit into a single GPU directly. Don't worry, Colossal-AI is here to help you sort things out. With the help of Colossal-AI,
+you can write your model in the familiar way in which you used to write models for a single GPU, while Colossal-AI automatically
+splits your model weights and fit them perfectly into a set of GPUs. We give a simple example showing how to write a simple
+2D parallel model in the Colossal-AI context.
+
+## Write a simple 2D parallel model
+
+```python
+from colossalai.nn import Linear2D
+import torch.nn as nn
+
+class MLP_2D(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.linear_1 = Linear2D(in_features=1024, out_features=16384)
+ self.linear_2 = Linear2D(in_features=16384, out_features=1024)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+ return x
+```
+
+## Use pre-defined model
+
+For the sake of your convenience, we kindly provide you in our Model Zoo with some prevalent models such as *BERT*, *ViT*, *MoE*,
+and *GPT*. Feel free to customize them into different sizes to fit into your special needs.
diff --git a/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md
new file mode 100644
index 000000000000..e01caf76d2b3
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md
@@ -0,0 +1,139 @@
+# Integrate Mixture-of-Experts Into Your Model
+
+Author: Haichen Huang
+
+**Example Code**
+- [ColossalAI-Examples WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet)
+
+**Related Paper**
+- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961)
+- [Go Wider Instead of Deeper](https://arxiv.org/abs/2107.11817)
+
+
+## Introduction
+
+Since the advent of Switch Transformer, the AI community has found Mixture of Experts (MoE) a useful technique to enlarge the capacity of deep learning models.
+
+Colossal-AI provides an early access version of parallelism specifically designed for MoE models.
+The most prominent advantage of MoE in Colossal-AI is convenience.
+We aim to help our users to easily combine MoE with model parallelism and data parallelism.
+
+However, the current implementation has two main drawbacks now.
+The first drawback is its poor efficiency in large batch size and long sequence length training.
+The second drawback is incompatibility with tensor parallelism.
+We are working on system optimization to overcome the training efficiency problem.
+The compatibility problem with tensor parallelism requires more adaptation, and we will tackle this issue in the future.
+
+Here, we will introduce how to use MoE with model parallelism and data parallelism.
+
+## Table of Content
+In this tutorial we will cover:
+1. Set up MoE running environment
+2. Create MoE layer
+3. Train your model
+
+We provided the [example code](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) for this tutorial in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples).
+This example uses [WideNet](https://arxiv.org/abs/2107.11817) as an example of MoE-based model.
+
+
+## Set up MoE running environment
+In your project folder, create a `config.py`.
+
+This file is to specify some features you may want to use to train your model.
+In order to enable MoE, you need to add a dict called parallel and specify the value of key moe.
+You can assign a value for the key size of moe, which represents the model parallel size of experts (i.e. the number of experts in one group to parallelize training).
+
+For example, if the size is 4, 4 processes will be assigned to 4 consecutive GPUs and these 4 processes form a moe model parallel group.
+Each process on the 4 GPUs will only get a portion of experts. Increasing the model parallel size will reduce communication cost, but increase computation cost in each GPU and activation cost in memory.
+The total data parallel size is auto-detected and set as the number of GPUs by default.
+
+```python
+MOE_MODEL_PARALLEL_SIZE = ...
+parallel = dict(
+ moe=dict(size=MOE_MODEL_PARALLEL_SIZE)
+)
+```
+
+If `MOE_MODEL_PARALLEL_SIZE = E` and set the number of experts as `E` where `E` is a constant number, the process flow of forward pass of a transformer encoder in a model parallel group is shown below.
+
+
+
+MoE Transformer, image source: GShard
+
+
+Since all experts are allocated to all GPUs in a model parallel group and a GPU only owns a portion of experts,
+original data parallel groups are no longer correct for the parameters of experts during gradient handling in backward pass anymore.
+So we create a new kind of parallel group called moe data parallel group.
+The difference among different kinds of parallel group, when the configuration is set as `WORLD_SIZE=4`,
+`MOE_MODEL_PARALLEL_SIZE=2`, is shown here.
+
+
+
+MoE process group
+
+
+
+As for gradient handling, we provide MoeGradientHandler to all-reduce every parameter of the model.
+If you use `colossalai.initialize` function to create your training engine, the MoE gradient handler will be added to your engine automatically.
+Otherwise, you should take care of gradient by yourself.
+All parameters of MoE running environment are stored in colossalai.global_variables.moe_env.
+You can access your configuration parameters to check whether your setup is correct.
+```python
+from colossalai.global_variables import moe_env
+```
+
+## Create MoE layer
+You can create a MoE layer from `colossalai.nn.moe`.
+But before doing that, you should set up random seeds for all processes like this.
+
+```python
+from colossalai.context.random import moe_set_seed
+from model_zoo.moe.models import Widenet
+
+moe_set_seed(42)
+model = Widenet(num_experts=4, capacity_factor=1.2)
+```
+
+`moe_set_seed` will set different seed for different processes in a moe model parallel group.
+This helps initialize parameters in experts.
+Then create an instance of experts and an instance of router.
+Here is the example in model zoo.
+
+```python
+from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
+
+
+noisy_func = NormalNoiseGenerator(num_experts)
+shared_router = Top2Router(capacity_factor,
+ noisy_func=noisy_func)
+shared_experts = Experts(expert=VanillaFFN,
+ num_experts=num_experts,
+ **moe_mlp_args(
+ d_model=d_model,
+ d_ff=d_ff,
+ drop_rate=drop_rate
+ ))
+ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
+ router=shared_router, experts=shared_experts)
+```
+
+Inside the initialization of Experts, the local expert number of each GPU will be calculated automatically. You just need to specify the class of each expert and its parameters used in its initialization. As for routers, we have provided top1 router and top2 router. You can find them in colossalai.nn.layer.moe. After creating the instance of experts and router, the only thing initialized in Moelayer is gate module. More definitions of each class can be found in our API document and code.
+
+
+## Train Your Model
+Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine.
+We handle the back-propagation of MoE models for you.
+In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients.
+You can find more information about the handler `MoeGradientHandler` in colossal directory.
+
+The loss criterion should be wrapped by `Moeloss` to add auxiliary loss of MoE. Example is like this.
+```python
+criterion = MoeLoss(
+ aux_weight=0.01,
+ loss_fn=nn.CrossEntropyLoss,
+ label_smoothing=0.1
+)
+```
+
+Finally, just use trainer or engine in `colossalai` to do your training.
+Otherwise, you should take care of gradient by yourself.
diff --git a/docs/source/en/advanced_tutorials/meet_gemini.md b/docs/source/en/advanced_tutorials/meet_gemini.md
new file mode 100644
index 000000000000..4889b30a6cf8
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/meet_gemini.md
@@ -0,0 +1,88 @@
+
+# Meet Gemini:The Heterogeneous Memory Manager of Colossal-AI
+
+Author: [Jiarui Fang](https://github.com/feifeibear), Yang You
+
+## Brief
+
+When you only have a few GPUs for large model training tasks, **heterogeneous training** is the most effective approach. By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, it can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel . We now describe the design details of **Gemini**, the heterogeneous memory space manager of Colossal-AI. Its idea comes from [PatrickStar](https://arxiv.org/abs/2108.05818), which has been adapted to Colossal-AI.
+
+## Usage
+
+At present, Gemini supports compatibility with ZeRO parallel mode, and it is really simple to use Gemini. Set attribute of zero model_config, i.e., tensor_placement_policy='auto'.
+
+```
+zero = dict(
+ model_config=dict(
+ tensor_placement_policy='auto',
+ shard_strategy=BucketTensorShardStrategy()
+ ),
+ optimizer_config=dict(
+ ...)
+)
+```
+
+Note that Gemini and parallel strategies such as tensor parallelism, data parallelism, pipeline parallelism and zero should be decoupled. However, Colossal-AI requires users to use Gemini with ZeRO. Although they are not necessarily coupled, we will improve it in the near future.
+
+## Concepts
+
+**OP**(**OP**erator):operation of a neural network layer, such as linear, LayerNorm, etc. The operator can be a forward propagation calculation or a back-propagation calculation.
+
+Neural networks must manage two types of training data during training.
+**model data**: consists of parameters, gradients and optimizer states, and its scale is related to the definition of model structure.
+
+**Non-model data**: mainly composed of the intermediate tensor generated by the operator and the temporary variables of the operator. Non-model data changes dynamically according to the configuration of training tasks, such as batch size. Model data and non-model data compete with each other for GPU memory.
+
+## Design Details
+
+
+In some solutions, the [Zero-offload](https://arxiv.org/abs/2101.06840) adopted by DeepSpeed statically divides model data between CPU and GPU memory, and their memory layout is constant for different training configurations. As shown on the left of the figure below, when the GPU memory is insufficient to meet its corresponding model data requirements, the system will crash even if there is still available memory on the CPU at that time. While Colossal-AI can complete the training by moving part of the model data to the CPU.
+
+
+
+Comparison of the memory management of Zero-Offload and Gemini
+
+
+
+Colossal-AI designed Gemini, just like two-stars, which manages the memory space of CPU and GPU efficiently. It can make the tensor dynamically distributed in the storage space of CPU-GPU during training, so that the model training can break through the memory wall of GPU. The memory manager consists of two parts: **MemStatsCollector (MSC)** and **StatefuleTensorMgr (STM)**.
+
+We take advantage of the iterative characteristics of the deep learning network training process. We divide iterations into two stages: warmup and non-warmup. One or several iterative steps at the beginning belong to the warmup stage, and the other iterative steps belong to the non-warmup stage. In the warmup stage, we collect information for the MSC, while in the non-warmup stage, STM gets the information collected by the MSC to move the tensor, so as to minimize the CPU-GPU data movement volume.
+
+
+
+The workflow of Gemini during warmup and non-warmup phase
+
+
+
+### StatefulTensorMgr
+
+STM manages the information of all model data tensors. In the process of model construction, Colossal-AI registers all model data tensors with STM. The memory manager marks each tensor with state information. The state set includes three types: HOLD, COMPUTE and FREE. The functions of STM are as follows:
+
+**Query memory usage:**by traversing the locations of all tensors in heterogeneous space, obtain the memory occupation of CPU and GPU by model data.
+
+**Transition tensor state:** it marks the tensor as COMPUTE state before each model data tensor participates in the operator calculation, and as HOLD state after calculation. The FREE state marked if the tensor is no longer in use.
+
+**Adjust tensor position:**tensor manager ensures that the tensor in COMPUTE state is placed on the computing device. If the storage space of the computing device is insufficient, it is necessary to move some tensors in HOLD state to other devices for storage. Tensor eviction strategy requires information from MSC, which will be introduced later.
+
+
+### MemStatsCollector
+In the warmup stage, the memory information statistician monitors the memory usage of model data and non-model data in CPU and GPU for reference in the non-warmup stage. We can obtain the memory usage of model data at a certain time by querying STM. However, the memory usage of non-model data is difficult to obtain. Owing to the life cycle of non-model data not being managed by users, the existing deep learning framework does not expose the tracking interface of non-model data to users. MSC obtains the usage of CPU and GPU memory by non-model in the warmup stage through sampling. The specific methods are as follows:
+
+We trigger the memory sampling operation at the beginning and end of the operator. We call this time point **sampling moment**, and the time between the two sampling moments is called **period**. The calculation process is a black box. Due to the possible allocation of temporary buffer, the memory usage is very complex. However, we can accurately obtain the maximum memory usage of the system during the period. The use of non-model data can be obtained by the maximum memory use of the system between two statistical moments-model memory use.
+
+How do we design the sampling time. Before we choose model data layout adjust of preOp. As shown in the figure below. We sample the system memory used of the previous period and the model data memory used of the next period. The parallel strategy will cause obstacles to the work of MSC. As shown in the figure, for example, for ZeRO or Tensor Parallel, because gathering model data is required before OP calculation, it will bring additional memory requirements. Therefore, we require to sample the system memory before the model data changes, so that the MSC will capture the model change memory of preOp within a period. For example, in period 2-3, we consider the memory changes brought by tensor gather and shard.
+
+Although the sampling time can be placed in other locations, such as excluding the new information of the change of the gather buffer, it will cause trouble. There are differences in the implementation of Op in different parallel modes. For example, for Linear Op, gather buffer in Tensor Parallel is allocated in Op. For ZeRO, the allocation of gather buffer is in PreOp. Sampling at the beginning of PreOp helps to unify the two situations.
+
+
+
+workflow
+
+
+### Tensor Eviction Strategy
+
+The important duty of MSC is to adjust the tensor layout position. For example, at S2 in the figure above, we reduce the model data on the device, and meet the peak memory requirement calculated in period 2-3.
+
+In the warmup stage, since we haven't finished a complete iteration yet, we don't know actual memory occupation. At this time, we limit the upper bound of memory usage of the model data. For example, only 30% of the GPU memory can be used. This ensures that we can successfully complete the warmup state.
+
+In the non-warmup stage, we need to use the memory information of non-model data collected in the warm-up stage to reserve the peak memory required by the computing device for the next Period, which requires us to move some model tensors. In order to avoid frequent replacement of the same tensor in and out of the CPU-GPU, causing a phenomenon similar to [cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science)). Using the iterative characteristics of DNN training, we design the OPT cache swap out strategy. Specifically, in the warmup stage, we record the sampling time required by each tensor computing device. If we need to expel some HOLD tensors, we will choose the latest tensor needed on this device as the victim.
diff --git a/docs/source/en/advanced_tutorials/opt_service.md b/docs/source/en/advanced_tutorials/opt_service.md
new file mode 100644
index 000000000000..b317de91bbdd
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/opt_service.md
@@ -0,0 +1,81 @@
+# Build an online OPT service using Colossal-AI in 5 minutes
+
+## Introduction
+
+This tutorial shows how to build your own service with OPT with the help of [Colossal-AI](https://github.com/hpcaitech/ColossalAI).
+
+## Colossal-AI Inference Overview
+Colossal-AI provides an inference subsystem [Energon-AI](https://github.com/hpcaitech/EnergonAI), a serving system built upon Colossal-AI, which has the following characteristics:
+
+- **Parallelism for Large-scale Models:** With the help of tensor parallel operations, pipeline parallel strategies from Colossal-AI, Colossal-AI inference enables efficient parallel inference for large-scale models.
+- **Pre-built large models:** There are pre-built implementations for popular models, such as OPT. It supports a caching technique for the generation task and checkpoints loading.
+- **Engine encapsulation:** There has an abstraction layer called an engine. It encapsulates the single instance multiple devices (SIMD) execution with the remote procedure call, making it act as the single instance single device (SISD) execution.
+- **An online service system:** Based on FastAPI, users can launch a web service of a distributed inference quickly. The online service makes special optimizations for the generation task. It adopts both left padding and bucket batching techniques to improve efficiency.
+
+## Basic Usage:
+
+1. Download OPT model
+
+To launch the distributed inference service quickly, you can download the OPT-125M from [here](https://huggingface.co/patrickvonplaten/opt_metaseq_125m/blob/main/model/restored.pt). You can get details for loading other sizes of models [here](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt/script).
+
+2. Prepare a prebuilt service image
+
+Pull a docker image from dockerhub installed with Colossal-AI inference.
+
+```bash
+docker pull hpcaitech/energon-ai:latest
+```
+
+3. Launch an HTTP service
+
+To launch a service, we need to provide python scripts to describe the model type and related configurations, and settings for the HTTP service.
+We have provided a set of [examples](https://github.com/hpcaitech/EnergonAI/tree/main/examples]). We will use the [OPT example](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt) in this tutorial.
+The entrance of the service is a bash script server.sh.
+The config of the service is at opt_config.py, which defines the model type, the checkpoint file path, the parallel strategy, and http settings. You can adapt it for your own case.
+For example, set the model class as opt_125M and set the correct checkpoint path as follows.
+
+```bash
+model_class = opt_125M
+checkpoint = 'your_file_path'
+```
+
+Set the tensor parallelism degree the same as your gpu number.
+
+```bash
+tp_init_size = #gpu
+```
+
+Now, we can launch a service using docker. You can map the path of the checkpoint and directory containing configs to local disk path `/model_checkpoint` and `/config`.
+
+
+```bash
+export CHECKPOINT_DIR="your_opt_checkpoint_path"
+# the ${CONFIG_DIR} must contain a server.sh file as the entry of service
+export CONFIG_DIR="config_file_path"
+
+docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest
+```
+
+Then open `https://[IP-ADDRESS]:8020/docs#` in your browser to try out!
+
+
+## Advance Features Usage:
+
+1. Batching Optimization
+
+To use our advanced batching technique to collect multiple queries in batches to serve, you can set the executor_max_batch_size as the max batch size. Note, that only the decoder task with the same top_k, top_p and temperature can be batched together.
+
+```
+executor_max_batch_size = 16
+```
+
+All queries are submitted to a FIFO queue. All consecutive queries whose number of decoding steps is less than or equal to that of the head of the queue can be batched together. Left padding is applied to ensure correctness. executor_max_batch_size should not be too large. This ensures batching won't increase latency. For opt-30b, `executor_max_batch_size=16` may be a good choice, while for opt-175b, `executor_max_batch_size=4` may be better.
+
+2. Cache Optimization.
+
+You can cache several recently served query results for each independent serving process. Set the cache_size and cache_list_size in config.py. The cache size is the number of queries cached. The cache_list_size is the number of results stored for each query. And a random cached result will be returned. When the cache is full, LRU is applied to evict cached queries. cache_size=0means no cache is applied.
+
+```
+cache_size = 50
+cache_list_size = 2
+```
diff --git a/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
new file mode 100644
index 000000000000..e7698e5e9d1b
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -0,0 +1,192 @@
+# Parallelize Your Training like Megatron-LM via ColoTensor
+
+Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear)
+
+**Prerequisite:**
+- [ColoTensor Concepts](../basics/colotensor_concept.md)
+
+## Introduction
+
+Thanks to the convenience given by ColoTensor, users can apply parallelism with the least edition to their serial code.
+In this tutorial, we will illustrate how to modify the training model to automatically adapt the code to parallel training like Megatron-LM.
+We take the GPT-2 model offered by HuggingFace as an example and provide a way for you to pre-train the GPT-2 model on a single GPU.
+
+Megatron-LM provided a profound paradigm to parallelize large transformer language models.
+However, in order to train large transformer language models at scale, users have to build their models with those modules provided by Megatron.
+It imposes several difficult jobs on users, such as loading the weights from the pre-trained models and constructing the parallelized models.
+To mitigate users' trouble, we offer ColoTensor to enable the tensor model parallelism automatically.
+
+## Definitions of the model and the loss function
+
+First we use the GPTModel and GPTLoss directly from the HuggingFace library.
+
+```python
+import torch
+import torch.nn as nn
+from transformers import GPT2Config, GPT2LMHeadModel
+
+class GPTLMModel(nn.Module):
+ def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
+ super().__init__()
+ self.checkpoint = checkpoint
+ self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
+ n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size))
+ if checkpoint:
+ self.model.gradient_checkpointing_enable()
+
+ def forward(self, input_ids, attention_mask):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+```
+
+## Brief Review of GPT-2
+
+Now, we recall the structure of each GPT-2 model.
+Every GPT-2 model can be represented as a DAG.
+As shown in the below pictures, each circle represents an operator and each square represents a weight.
+An arrow indicates the flow of the input data, and the notation alongside the arrow demonstrates the shape of the input data.
+
+Then, let's take an insight into this GPT-2 model. It consists of three parts.
+They are the **embedding module**, **transformer layers**, and the **classification head**.
+
+The embedding module contains two weights, token embedding weight and position embedding weight.
+After the forward operation of the embedding module, each word in all sequences of the raw input data will be embedded into a hidden state.
+
+
+
+The embedding module
+
+
+Each transformer layer contains two blocks. The self-attention operation is called in the first block and a two-layer percepton is located in the second block.
+
+
+
+The transformer layer
+
+
+In the end, the classification head is just a linear module without bias, which only has a weight inside.
+
+## Applied with ColoTensor
+
+Two steps make your serial code adapted to Megatron-LM tensor parallel style.
+1. Initialize the model in the context of ColoInitContext.
+2. Setting ColoTensorSpec for each parameter.
+
+### Initialize with ColoInitContext
+
+We should build the model in the ColoInitContext.
+In this context, any parameter initialized would be transformed to ColoParameter and moved to the corresponded device automatically.
+
+```python
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+with ColoInitContext(device=torch.device('cpu')):
+ model = GPTLMModel()
+```
+
+### Setting ColoTensorSpec for each parameter
+
+After the creation of the model, we establish the distributed environment through ProcessGroup.
+Here, we specify the degree of the tensor parallelism as the same as the number of all GPUs, which means the degree of data parallelism is 1.
+
+```python
+import torch.distributed as dist
+from colossalai.tensor import ProcessGroup
+
+pg = ProcessGroup(tp_degree=dist.get_world_size())
+```
+
+Now, some auxiliary functions are necessary for the next step. We define two functions to split a parameter.
+Megatron-LM-like tensor parallelism requires splitting a parameter tensor along its first dimension or its last dimension.
+
+```python
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup
+
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+ spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ if param.process_group.tp_world_size() == 1:
+ param.set_process_group(pg)
+ param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(-1, param, pg)
+```
+
+Then we adapt the model to the tensor parallelism.
+According to the tensor parallelism applied in Megatron, it is supposed to shard along the last dimension of tensors, including the weights of token embedding, position embedding, all linear weights and biases in self-attention blocks, the first weight linear and bias in each MLP.
+And it shards the second linear weight along its first dimension.
+
+```python
+for mn, module in model.named_modules():
+ for pn, param in module.named_parameters(recurse=False):
+ # set process group for all parameters
+ param.set_process_group(pg)
+
+ if 'mlp.c_fc' in mn:
+ if 'weight' in pn or 'bias' in pn:
+ split_param_col_tp1d(param, pg) # colmn slice
+ # keep the shape of the output from c_fc
+ param.compute_spec.set_output_replicate(False)
+ elif 'mlp.c_proj' in mn:
+ if 'weight' in pn:
+ split_param_row_tp1d(param, pg) # row slice
+ elif 'wte' in mn or 'wpe' in mn:
+ split_param_col_tp1d(param, pg) # colmn slice
+ elif 'c_attn' in mn or 'c_proj' in mn:
+ split_param_col_tp1d(param, pg) # colmn slice
+```
+
+The modified model is illustrated below.
+
+The embedding module:
+
+
+
+The modified embedding module
+
+
+The transformer layers:
+
+
+
+The modified transformer layer
+
+
+Once users have specified the distributed pattern of each parameter, ColoTensor is capable of inferring the computation patterns of all operators, including matrix multiplication, the linear function, other elementwise functions in torch.nn.functional, etc.
+In this way, users can train their models as usual.
+
+In our latest example, a Gemini + ZeRO DDP model is also defined to reduce overhead and improve efficiency.For the details of this part, please refer to [ZeRO](../features/zero_with_chunk.md). You can combine these two parts to understand our entire training process:
+
+```python
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+ from colossalai.nn.parallel import GeminiDDP
+ model = GeminiDDP(model,
+ device=get_current_device(),
+ placement_policy=placememt_policy,
+ pin_memory=True,
+ search_range_mb=32)
+ return model
+```
+
+## Pretrain GPT-2 On Single GPU
+
+The above optimization we made allows us to pretrain the GPT-2 model on a single GPU. We only need to set the parameter `GPUNUM`=1 in `run.sh`, and then we can complete the model training on a single GPU when running the file.
+
+The GPT-2 example is accessible at [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt).
diff --git a/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
new file mode 100644
index 000000000000..715c15eb6300
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -0,0 +1,270 @@
+# Train GPT Using Hybrid Parallelism
+
+Author: Hongxin Liu, Yongbin Li
+
+**Example Code**
+- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2)
+- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3)
+
+**Related Paper**
+- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+
+## Introduction
+
+In the previous tutorial, we introduce how to train ViT with pipeline. In this tutorial, you will learn a more complex scenario -- train GPT with hybrid parallelism. In this case, GPT-3 is so large that CPU memory cannot fit it as well. Therefore, you must split the model by yourself.
+
+## Table of content
+
+In this tutorial we will cover:
+
+1. The definition of GPT model, based on colossalai/model_zoo
+2. Processing the dataset
+3. Training GPT using hybrid parallelism
+
+## Import libraries
+
+```python
+import json
+import os
+from typing import Callable
+
+import colossalai
+import colossalai.utils as utils
+import model_zoo.gpt.gpt as col_gpt
+import torch
+import torch.nn as nn
+from colossalai import nn as col_nn
+from colossalai.amp import AMP_TYPE
+from colossalai.builder.pipeline import partition_uniform
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils.timer import MultiTimer
+from model_zoo.gpt import GPTLMLoss
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+from transformers import GPT2Tokenizer
+```
+
+
+
+## Define GPT model
+
+In the previous tutorial, we introduced 3 ways to build a pipelined model. But for huge models like GPT-3, you can't even build the model in CPU. In this case, you must split the model by yourself.
+
+GPT dataloader returns `input_ids` and `attention_mask`, so we use two keyword arguments in `forward()` to get them. Note that for stages except the first stage, the first positional argument of `forward()` is the output tensor from the previous stage. So the `hidden_states` is from the previous stage, and for the first stage it's `None`.
+
+For GPT, the *word embedding layer* shares the weights with the *output head*. We provide `PipelineSharedModuleWrapper` to share parameters among pipeline stages. It takes a `list` of `int` as argument, which means those ranks share the parameters. You can use `register_module()` or `register_parameter()` to register a module or a parameter as the shared module or parameter. If you have multiple sets of shared modules / parameters, you should have multiple `PipelineSharedModuleWrapper` instance. If the parameter is shared within **one** stage, you should not use `PipelineSharedModuleWrapper`, and just use the same module / parameter instance. In this example, the *word embedding layer* is at the first stage, and the *output head* is at the last stage. Thus, they are shared among ranks `[0, pipeline_size - 1]`.
+
+For the first stage, it maintains the embedding layer and some transformer blocks. For the last stage, it maintains some transformer blocks and the output head layer. For other stages, they just maintain some transformer blocks. `partition_uniform(num_layers, pipeline_size, num_chunks)` returns the parts of all ranks, and the part is a `tuple` of `(start, end)` (exclude end). `start == 0` means that it's the first stage, and `end == num_layers` means it's the last stage.
+
+```python
+class PipelineGPTHybrid(nn.Module):
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: Callable = F.gelu,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ first: bool = False,
+ last: bool = False):
+ super().__init__()
+ self.embedding = None
+ self.norm = None
+ self.head = None
+ if first:
+ self.embedding = col_gpt.GPTEmbedding(
+ hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype)
+ self.blocks = nn.ModuleList([
+ col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate,
+ dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func)
+ for _ in range(num_layers)
+ ])
+ if last:
+ self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ self.head = col_gpt.GPTLMHead(vocab_size=vocab_size,
+ dim=hidden_size,
+ dtype=dtype,
+ bias=False)
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ batch_size = hidden_states.shape[0]
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = attention_mask[:, None, None, :]
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
+
+
+def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ rank = gpc.get_global_rank()
+ wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+ for start, end in parts:
+ kwargs['num_layers'] = end - start
+ kwargs['first'] = start == 0
+ kwargs['last'] = end == num_layers
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = PipelineGPTHybrid(**kwargs).to(device)
+ if start == 0:
+ wrapper.register_module(chunk.embedding.word_embeddings)
+ elif end == num_layers:
+ wrapper.register_module(chunk.head)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+ return model
+
+
+def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
+ cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype)
+ return build_gpt_pipeline(48, num_chunks, **cfg)
+
+
+def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
+ cfg = dict(hidden_size=12288, num_attention_heads=96,
+ checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype)
+ return build_gpt_pipeline(96, num_chunks, **cfg)
+```
+
+## Process the dataset
+
+We provide a small GPT web-text dataset here. The original format is loose JSON, and we will save the processed dataset.
+
+```python
+class WebtextDataset(Dataset):
+ def __init__(self, path, seq_len=1024) -> None:
+ super().__init__()
+ root = os.path.dirname(path)
+ encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+ if os.path.isfile(encoded_data_cache_path):
+ seq_len_, data, attention_mask = torch.load(
+ encoded_data_cache_path)
+ if seq_len_ == seq_len:
+ self.data = data
+ self.attention_mask = attention_mask
+ return
+ raw_data = []
+ with open(path) as f:
+ for line in f.readlines():
+ raw_data.append(json.loads(line)['text'])
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.unk_token
+ encoded_data = tokenizer(
+ raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+ self.data = encoded_data['input_ids']
+ self.attention_mask = encoded_data['attention_mask']
+ torch.save((seq_len, self.data, self.attention_mask),
+ encoded_data_cache_path)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return {
+ 'input_ids': self.data[index],
+ 'attention_mask': self.attention_mask[index]
+ }, self.data[index]
+```
+
+## Training GPT using hybrid parallelism
+
+In the previous tutorial, we explained the meanings of some pipeline arguments. In this case, we can determine the shape of each output tensor which is exchanged among pipeline stages. For GPT, the shape is `(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`. By setting this, we can avoid exchanging the tensor shape of each stage. When you are not sure of the tensor shape, you can just leave it `None`, and the shape is inferred automatically. Make sure that the `dtype` of your model is correct. When you use `fp16`, the `dtype` of your model must be `torch.half`. Otherwise, the `dtype` must be `torch.float`. For pipeline parallelism, only `AMP_TYPE.NAIVE` is supported.
+
+You can easily use tensor parallel by setting `parallel` in `CONFIG`. The data parallelism size is automatically set based on the number of GPUs.
+
+```python
+NUM_EPOCHS = 60
+SEQ_LEN = 1024
+BATCH_SIZE = 192
+NUM_CHUNKS = None
+TENSOR_SHAPE = (1, 1024, 1600)
+# only pipeline parallel
+# CONFIG = dict(parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE))
+# pipeline + 1D model parallel
+CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE))
+
+
+def train():
+ disable_existing_loggers()
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch_from_torch(config=CONFIG, backend=args.backend)
+ logger = get_dist_logger()
+
+ train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN)
+ train_dataloader = utils.get_dataloader(train_ds,
+ seed=42,
+ batch_size=BATCH_SIZE,
+ pin_memory=True,
+ shuffle=True,
+ drop_last=True)
+
+ use_interleaved = NUM_CHUNKS is not None
+ num_chunks = 1 if not use_interleaved else NUM_CHUNKS
+ model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
+ # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
+ if use_interleaved and not isinstance(model, nn.ModuleList):
+ model = nn.ModuleList([model])
+
+ criterion = GPTLMLoss()
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,)
+
+ engine, train_dataloader, _, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader=train_dataloader)
+ global_batch_size = BATCH_SIZE * \
+ gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
+ logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
+
+ timer = MultiTimer()
+
+ trainer = Trainer(
+ engine=engine,
+ logger=logger,
+ timer=timer
+ )
+
+ hook_list = [
+ hooks.LossHook(),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.ThroughputHook(),
+ hooks.LogMetricByStepHook(),
+ ]
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True,
+ return_output_label=False,
+ )
+```
diff --git a/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
new file mode 100644
index 000000000000..b26599740c5f
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -0,0 +1,247 @@
+# Train ViT Using Pipeline Parallelism
+
+Author: Hongxin Liu, Yongbin Li
+
+**Example Code**
+- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel)
+
+**Related Paper**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+
+## Introduction
+
+In this tutorial, you will learn how to train Vision Transformer for image classification from scratch, using pipeline.
+Pipeline parallelism is a kind of model parallelism, which is useful when your GPU memory cannot fit your model.
+By using it, we split the original model into multi stages, and each stage maintains a part of the original model.
+We assume that your GPU memory cannot fit ViT/L-16, and your memory can fit this model.
+
+## Table of contents
+
+In this tutorial we will cover:
+
+1. The definition of ViT model, based on [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)
+2. Processing the dataset
+3. Training ViT using pipeline
+
+## Import libraries
+
+```python
+import os
+from collections import OrderedDict
+from functools import partial
+
+import colossalai
+import colossalai.nn as col_nn
+import torch
+import torch.nn as nn
+from colossalai.builder import build_pipeline_model
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import MultiTimer, get_dataloader
+from timm.models import vision_transformer as vit
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+
+
+## Define Vision Transformer model
+
+Generally, we provide 3 ways to build a pipelined model:
+
+1. `colossalai.builder.build_pipeline_model_from_cfg`
+2. `colossalai.builder.build_pipeline_model`
+3. Split the model by stages by yourself
+
+When your memory can fit the model, you can use the first two methods to build your model, otherwise you must split the model by yourself. The first two methods first build the whole model on CPU, then split the model, and finally you can just move the corresponding part of model to GPU.
+
+`colossalai.builder.build_pipeline_model_from_cfg()` receives a config file of model, and it can split the model uniformly (by layer) or balanced (by parameter size).
+
+If you are familiar with `PyTorch`, you can use `colossalai.builder.build_pipeline_model()` which receives a `torch.nn.Sequential` model and split it by layer uniformly.
+
+In this tutorial, we will modify [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential` and then use `colossalai.builder.build_pipeline_model()` to build the pipelined model.
+
+When the data is **one** `Tensor`, you can use the positional argument in `forward()` of your model to get the data tensor. For the first stage of pipeline, the first positional argument of `forward()` is the data tensor loaded from data loader. For other stages, the first positional argument of `forward()` is the output tensor from the previous stage. Note that if the stage is not the last stage, the return of `forward()` must be a `Tensor`.
+
+When the data is a `dict` of `Tensor`, you can use named keyword arguments in `forward()` of your model to get the data `dict`.
+
+```python
+class ViTEmbedding(nn.Module):
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False):
+ super().__init__()
+ self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ self.init_weights()
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ return x
+
+ def init_weights(self):
+ vit.trunc_normal_(self.pos_embed, std=.02)
+ if self.dist_token is not None:
+ vit.trunc_normal_(self.dist_token, std=.02)
+ vit.trunc_normal_(self.cls_token, std=.02)
+ self.apply(vit._init_vit_weights)
+
+
+class ViTHead(nn.Module):
+ def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None):
+ super().__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.norm = norm_layer(embed_dim)
+ self.num_classes = num_classes
+ self.distilled = distilled
+ self.num_features = embed_dim
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.init_weights()
+
+ def forward(self, x):
+ x = self.norm(x)
+ if self.distilled:
+ x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
+ if self.training and not torch.jit.is_scripting():
+ # during inference, return the average of both classifier predictions
+ return x, x_dist
+ else:
+ return (x + x_dist) / 2
+ else:
+ x = self.pre_logits(x[:, 0])
+ x = self.head(x)
+ return x
+
+ def init_weights(self):
+ self.apply(vit._init_vit_weights)
+
+
+def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None,
+ act_layer=None):
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+ embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ blocks = [vit.Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)]
+ for block in blocks:
+ block.apply(vit._init_vit_weights)
+ head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer,
+ distilled=distilled, representation_size=representation_size)
+ return nn.Sequential(embedding, *blocks, head)
+
+
+def vit_large_patch16_224(**kwargs):
+ model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ return sequential_vit(**model_kwargs)
+```
+
+## Process the dataset
+
+Generally, we train ViT on large dataset like Imagenet. For simplicity, we just use CIFAR-10 here, since this tutorial is just for pipeline training.
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+```
+
+## Training ViT using pipeline
+
+You can set the size of pipeline parallel and number of microbatches in config. `NUM_CHUNKS` is useful when using interleved-pipeline (for more details see [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) ). The original batch will be split into `num_microbatches`, and each stage will load a micro batch each time. Then we will generate an approriate schedule for you to execute the pipeline training. If you don't need the output and label of model, you can set `return_output_label` to `False` when calling `trainer.fit()` which can further reduce GPU memory usage.
+
+You should `export DATA=/path/to/cifar`.
+
+```python
+BATCH_SIZE = 16
+NUM_EPOCHS = 60
+NUM_CHUNKS = 1
+CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
+
+
+def train():
+ disable_existing_loggers()
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
+ logger = get_dist_logger()
+
+ # build model
+ model = vit_large_patch16_224()
+ model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True)
+
+ # build criterion
+ criterion = nn.CrossEntropyLoss()
+
+ # optimizer
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
+
+ # build dataloader
+ train_dataloader, test_dataloader = build_cifar(BATCH_SIZE)
+
+ engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion,
+ train_dataloader, test_dataloader)
+ timer = MultiTimer()
+
+ trainer = Trainer(engine=engine, timer=timer, logger=logger)
+
+ hook_list = [
+ hooks.LossHook(),
+ hooks.AccuracyHook(col_nn.metric.Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ ]
+
+ trainer.fit(train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True)
+```
diff --git a/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
new file mode 100644
index 000000000000..1f3086559939
--- /dev/null
+++ b/docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -0,0 +1,646 @@
+# Step By Step: Accelerate ViT Training With Colossal-AI (From Data Parallel to Hybrid Parallel)
+
+Author: Yuxuan Lou
+
+**Example Code**
+
+- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer)
+
+**Related Paper**
+- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf)
+
+
+## Introduction
+
+In this example for ViT model, Colossal-AI provides three different parallelism techniques which acclerate model training: data parallelism, pipeline parallelism and tensor parallelism.
+We will show you how to train ViT on CIFAR-10 dataset with these parallelism techniques. To run this example, you will need 2-4 GPUs.
+
+
+## Tabel of Contents
+1. Colossal-AI installation
+2. Steps to train ViT with data parallelism
+3. Steps to train ViT with pipeline parallelism
+4. Steps to train ViT with tensor parallelism or hybrid parallelism
+
+## Colossal-AI Installation
+You can install Colossal-AI pacakage and its dependencies with PyPI.
+```bash
+pip install colossalai
+```
+
+
+
+## Data Parallelism
+Data parallism is one basic way to accelerate model training process. You can apply data parallism to training by only two steps:
+1. Define a configuration file
+2. Change a few lines of code in train script
+
+### Define your configuration file (`data_parallel/config.py`)
+To use Colossal-AI, the first step is to define a configuration file. And there are two kinds of variables here:
+
+1. **Colossal-AI feature specification**
+
+There is an array of features Colossal-AI provides to speed up training (parallel mode, mixed precision, ZeRO, etc.). Each feature is defined by a corresponding field in the config file. If we apply data parallel only, we do not need to specify the parallel mode. In this example, we use mixed precision training natively provided by PyTorch by define the mixed precision configuration `fp16 = dict(mode=AMP_TYPE.TORCH)`.
+
+2. **Global hyper-parameters**
+
+Global hyper-parameters include model-specific hyper-parameters, training settings, dataset information, etc.
+
+```python
+from colossalai.amp import AMP_TYPE
+
+# ViT Base
+BATCH_SIZE = 256
+DROP_RATE = 0.1
+NUM_EPOCHS = 300
+
+# mix precision
+fp16 = dict(
+ mode=AMP_TYPE.TORCH,
+)
+
+gradient_accumulation = 16
+clip_grad_norm = 1.0
+
+dali = dict(
+ gpu_aug=True,
+ mixup_alpha=0.2
+)
+```
+
+### Modify train script (`/data_parallel/train_with_cifar10.py`)
+
+#### Import modules
+- Colossal-AI related modules
+```python
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from colossalai.nn.metric import Accuracy
+from colossalai.trainer import Trainer, hooks
+```
+
+- Other modules
+```python
+import os
+
+import torch
+from timm.models import vit_base_patch16_224
+
+
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+#### Lauch Colossal-AI
+
+In train script, you need to initialize the distributed environment for Colossal-AI after your config file is prepared. We call this process `launch`. In Colossal-AI, we provided several launch methods to initialize the distributed backend. In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the parameters via command line. Besides, Colossal-AI can utilize the existing launch tool provided by PyTorch as many users are familiar with by using `colossalai.launch_from_torch`. For more details, you can view the related [documents](https://www.colossalai.org/docs/basics/launch_colossalai).
+
+```python
+# initialize distributed setting
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+colossalai.launch_from_torch(config=args.config)
+
+disable_existing_loggers()
+logger = get_dist_logger()
+```
+
+After initialization, you can acess the variables in the config file by using `colossalai.core.global_context`.
+
+```python
+#access parameters
+print(gpc.config.BATCH_SIZE)
+```
+
+#### Build Model
+
+If only data parallelism is required, you do not need to make any changes to your model. Here, we use `vit_base_patch16_224` from `timm`.
+```python
+# build model
+model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES)
+```
+
+#### Build CIFAR-10 Dataloader
+`colossalai.utils.get_dataloader` can help you build dataloader easily.
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+
+
+# build dataloader
+train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE)
+```
+
+#### Define optimizer, loss function and LR scheduler
+
+Colossal-AI provides its own optimizer, loss function and LR scheduler. Those from PyTorch are also compatible.
+
+```python
+# build optimizer
+optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1)
+
+# build loss
+criterion = torch.nn.CrossEntropyLoss()
+
+# lr_scheduelr
+lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)
+```
+
+#### Start Colossal-AI engine
+
+Engine is essentially a wrapper class for model, optimizer and loss function. When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file. Further model training is based on Colossal-AI engine.
+
+```python
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
+ model, optimizer, criterion, train_dataloader, test_dataloader
+ )
+```
+
+#### Train: Trainer API
+Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. It is easy to create a trainer object by passing the engine object.
+
+Besides, In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages.
+
+```python
+# build trainer
+trainer = Trainer(engine=engine, logger=logger)
+
+# build hooks
+hook_list = [
+ hooks.LossHook(),
+ hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
+
+ # comment if you do not need to use the hooks below
+ hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
+ hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
+]
+```
+
+Use `trainer.fit` for training:
+
+```python
+# start training
+trainer.fit(
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ hooks=hook_list,
+ display_progress=True,
+ test_interval=1
+)
+```
+
+### Start training
+`DATA` is the filepath where CIFAR-10 dataset will be automatically downloaded and stored.
+
+`` is the number of GPUs you want to use to train ViT on CIFAR-10 with data parallelism.
+
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py
+# Otherwise
+# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py
+```
+
+
+
+## Pipeline Parallelism
+Aside from data parallelism, Colossal-AI also support pipleline parallelism. In specific, Colossal-AI uses 1F1B pipeline introduced by NVIDIA. For more details, you can view the related [documents](https://www.colossalai.org/tutorials/features/pipeline_parallel).
+
+### Define your configuration file(`hybrid_parallel/configs/vit_pipeline.py`)
+To apply pipleline parallel on the data parallel basis, you only need to add a **parallel dict**
+```python
+from colossalai.amp import AMP_TYPE
+
+parallel = dict(
+ pipeline=2
+)
+# pipeline config
+NUM_MICRO_BATCHES = parallel['pipeline']
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)
+
+fp16 = dict(mode=AMP_TYPE.NAIVE)
+clip_grad_norm = 1.0
+```
+
+Other configs:
+```python
+# hyperparameters
+# BATCH_SIZE is as per GPU
+# global batch size = BATCH_SIZE x data parallel size
+BATCH_SIZE = 256
+LEARNING_RATE = 3e-3
+WEIGHT_DECAY = 0.3
+NUM_EPOCHS = 300
+WARMUP_EPOCHS = 32
+
+# model config
+IMG_SIZE = 224
+PATCH_SIZE = 16
+HIDDEN_SIZE = 768
+DEPTH = 12
+NUM_HEADS = 12
+MLP_RATIO = 4
+NUM_CLASSES = 10
+CHECKPOINT = True
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
+```
+
+### Build pipeline model (`/hybrid_parallel/model/vit.py`)
+Colossal-AI provides two methods to build a pipeline model from the existing model.
+- `colossalai.builder.build_pipeline_model_from_cfg`
+- `colossalai.builder.build_pipeline_model`
+
+Besides, you can also build a pipeline model from scrath with Colossal-AI.
+```python
+import math
+from typing import Callable
+
+import inspect
+import torch
+from colossalai import nn as col_nn
+from colossalai.registry import LAYERS, MODELS
+from colossalai.logging import get_dist_logger
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from colossalai.builder.pipeline import partition_uniform
+from torch import dtype, nn
+from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
+
+
+@MODELS.register_module
+class PipelineVisionTransformer(nn.Module):
+ def __init__(self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ depth: int = 12,
+ num_heads: int = 12,
+ dim: int = 768,
+ mlp_ratio: int = 4,
+ attention_dropout: float = 0.,
+ dropout: float = 0.1,
+ drop_path: float = 0.,
+ layernorm_epsilon: float = 1e-6,
+ activation: Callable = nn.functional.gelu,
+ representation_size: int = None,
+ dtype: dtype = None,
+ bias: bool = True,
+ checkpoint: bool = False,
+ init_method: str = 'torch',
+ first_stage=True,
+ last_stage=True,
+ start_idx=None,
+ end_idx=None,):
+ super().__init__()
+
+ layers = []
+
+ if first_stage:
+ embed = ViTEmbedding(img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embedding_dim=dim,
+ dropout=dropout,
+ dtype=dtype,
+ init_method=init_method)
+ layers.append(embed)
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
+
+ if start_idx is None and end_idx is None:
+ start_idx = 0
+ end_idx = depth
+
+ blocks = [
+ ViTBlock(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ attention_dropout=attention_dropout,
+ dropout=dropout,
+ drop_path=dpr[i],
+ activation=activation,
+ dtype=dtype,
+ bias=bias,
+ checkpoint=checkpoint,
+ init_method=init_method,
+ ) for i in range(start_idx, end_idx)
+ ]
+ layers.extend(blocks)
+
+ if last_stage:
+ norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
+ head = ViTHead(dim=dim,
+ num_classes=num_classes,
+ representation_size=representation_size,
+ dtype=dtype,
+ bias=bias,
+ init_method=init_method)
+ layers.extend([norm, head])
+
+ self.layers = nn.Sequential(
+ *layers
+ )
+
+ def forward(self, x):
+ x = self.layers(x)
+ return x
+
+
+def _filter_kwargs(func, kwargs):
+ sig = inspect.signature(func)
+ return {k: v for k, v in kwargs.items() if k in sig.parameters}
+
+
+def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ else:
+ pipeline_size = 1
+ pipeline_rank = 0
+ rank = gpc.get_global_rank()
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+
+ for start, end in parts:
+ kwargs['first_stage'] = start == 0
+ kwargs['last_stage'] = end == num_layers
+ kwargs['start_idx'] = start
+ kwargs['end_idx'] = end
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+ return model
+
+
+def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs)
+```
+
+### Modify train script (`/hybrid_parallel/train_with_cifar10.py`)
+
+#### Import modules
+```python
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.utils import MultiTimer
+import os
+
+import colossalai
+
+import torch
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn import CrossEntropyLoss
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.utils import is_using_pp, get_dataloader
+from model.vit import build_pipeline_vit
+from model_zoo.vit.vit import _create_vit_model
+from tqdm import tqdm
+
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+#### Launch Colossal-AI
+`colossalai.utils.is_using_pp` can help check whether pipeline parallelism is required in config file.
+
+```python
+# initialize distributed setting
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+
+# launch from torch
+colossalai.launch_from_torch(config=args.config)
+
+# get logger
+logger = get_dist_logger()
+logger.info("initialized distributed environment", ranks=[0])
+
+if hasattr(gpc.config, 'LOG_PATH'):
+ if gpc.get_global_rank() == 0:
+ log_path = gpc.config.LOG_PATH
+ if not os.path.exists(log_path):
+ os.mkdir(log_path)
+ logger.log_to_file(log_path)
+
+use_pipeline = is_using_pp()
+```
+
+#### Define model
+
+```python
+# create model
+model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
+ patch_size=gpc.config.PATCH_SIZE,
+ dim=gpc.config.HIDDEN_SIZE,
+ depth=gpc.config.DEPTH,
+ num_heads=gpc.config.NUM_HEADS,
+ mlp_ratio=gpc.config.MLP_RATIO,
+ num_classes=gpc.config.NUM_CLASSES,
+ init_method='jax',
+ checkpoint=gpc.config.CHECKPOINT)
+
+if use_pipeline:
+ model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs)
+else:
+ model = _create_vit_model(**model_kwargs)
+```
+
+#### Count number of parameters
+
+You can count model parameters on different pipeline stages easily.
+
+```
+# count number of parameters
+total_numel = 0
+for p in model.parameters():
+ total_numel += p.numel()
+if not gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_stage = 0
+else:
+ pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
+logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
+```
+
+#### Build dataloader, optimizer, etc.
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+
+
+# craete dataloaders
+train_dataloader , test_dataloader = build_cifar()
+
+# create loss function
+criterion = CrossEntropyLoss(label_smoothing=0.1)
+
+# create optimizer
+optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
+
+# create lr scheduler
+lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
+ total_steps=gpc.config.NUM_EPOCHS,
+ warmup_steps=gpc.config.WARMUP_EPOCHS)
+```
+
+#### Start Colossal-AI engine
+
+```python
+# intiailize
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader)
+
+logger.info("Engine is built", ranks=[0])
+```
+
+#### Train: based on engine
+
+In the data parallelism example, we show how to train a model with Trainer API. We can also directly train a model based on engine. In this way, you can customize your training with more features.
+
+```python
+data_iter = iter(train_dataloader)
+
+for epoch in range(gpc.config.NUM_EPOCHS):
+ # training
+ engine.train()
+
+ if gpc.get_global_rank() == 0:
+ description = 'Epoch {} / {}'.format(
+ epoch,
+ gpc.config.NUM_EPOCHS
+ )
+ progress = tqdm(range(len(train_dataloader)), desc=description)
+ else:
+ progress = range(len(train_dataloader))
+ for _ in progress:
+ engine.zero_grad()
+ engine.execute_schedule(data_iter, return_output_label=False)
+ engine.step()
+ lr_scheduler.step()
+```
+
+### Start training
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py
+```
+
+
+
+
+## Tensor Parallelism and Hybrid Parallelism
+Tensor parallelism partitions each weight parameter across multiple devices in order to reduce memory load. Colossal-AI support 1D, 2D, 2.5D and 3D tensor parallelism. Besides, you can combine tensor parallelism with pipeline parallelism and data parallelism to reach hybrid parallelism. Colossal-AI also provides an easy way to apply tensor parallelism and hybrid parallelism. On the basis of pipeline parallelism, a few lines of code changing in config file is all you need.
+
+### Define your configuration file(`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`)
+To use tensor parallelism, you only need to add related information to the **parallel dict**. To be specific, `TENSOR_PARALLEL_MODE` can be '1d', '2d', '2.5d', '3d'. And the size of different parallelism should satisfy: `#GPUs = pipeline parallel size x tensor parallel size x data parallel size`. `data parallel size` will automatically computed after you specify the number of GPUs, pipeline parallel size and tensor parallel size.
+
+```python
+from colossalai.amp import AMP_TYPE
+# parallel setting
+TENSOR_PARALLEL_SIZE = 2
+TENSOR_PARALLEL_MODE = '1d'
+
+parallel = dict(
+ pipeline=2,
+ tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE)
+)
+
+fp16 = dict(mode=AMP_TYPE.NAIVE)
+clip_grad_norm = 1.0
+
+
+# pipeline config
+NUM_MICRO_BATCHES = parallel['pipeline']
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)
+```
+
+Ohter configs:
+```python
+# hyperparameters
+# BATCH_SIZE is as per GPU
+# global batch size = BATCH_SIZE x data parallel size
+BATCH_SIZE = 256
+LEARNING_RATE = 3e-3
+WEIGHT_DECAY = 0.3
+NUM_EPOCHS = 300
+WARMUP_EPOCHS = 32
+
+# model config
+IMG_SIZE = 224
+PATCH_SIZE = 16
+HIDDEN_SIZE = 768
+DEPTH = 12
+NUM_HEADS = 12
+MLP_RATIO = 4
+NUM_CLASSES = 10
+CHECKPOINT = True
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
+```
+
+### Start training
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py
+```
diff --git a/docs/source/en/basics/colotensor_concept.md b/docs/source/en/basics/colotensor_concept.md
new file mode 100644
index 000000000000..2d8acd88dfd4
--- /dev/null
+++ b/docs/source/en/basics/colotensor_concept.md
@@ -0,0 +1,97 @@
+# ColoTensor Concepts
+
+Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA)
+
+**Prerequisite:**
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+- [Distributed Training](../concepts/distributed_training.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+
+## Introduction
+
+After ColossalAI version 0.1.8, [ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) becomes the basic data structure for tensors in ColossalAI. It is a subclass of torch.Tensor and can be used as a PyTorch Tensor. Additionally, some unique features make it possible to represent a Global Tensor with a payload distributed across multiple GPU devices. With the help of ColoTensor, the users can write distributed DNN training program similar to a serial one.support the following features.
+
+ColoTensor contains extra attributes capsuled in a [ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec) instance to describe the tensor's payload distribution and computing pattern.
+
+- ProcessGroup: how processes are organized as communication groups.
+- Distributed Spec: how tensor is distributed among process groups.
+- Compute Spec: how the tensor is used during computation.
+
+We elaborate on them one by one.
+
+## ProcessGroup
+
+An instance of class [ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) describes how processes are organized in process groups. Processes in a process group can participate in the same collective communication operations together, such as allgather, allreduce, etc. The way the process group is organized is dominated by the Tensor's parallelism strategy. For example, if the user defines the tensor parallel (TP) and data parallel (DP) modes of a tensor, then the process organization of the process group will be automatically deduced. The process group settings can vary among different tensors. Therefore, it enables us to support more complicated hybrid parallel. The pipeline parallel (PP) definition is not in the ProcessGroup, it needs another set of mechanisms . We will supplement the related content of ColoTensor applied to PP in the future.
+
+Currently, a process group of ColoTensor is defined by two configurations, i.e. tp_degree and dp_degree. In the case of DP+TP hybrid parallelism, the device can be viewed as a 2D mesh. We place TP communication groups on the leading low dimension of the device mesh and then place the data parallel groups along the high dimension of the device mesh. The reason is that tensor parallelism has a larger communication overhead than data parallelism. Neighboring devices are placed inside a TP process group and are often placed in the same node.
+
+Considering that 8 processes are configured as tp_degree=4, and dp_degree=2, the layout is shown below. Process group tp0 contains gpu 0,1,2,3. Process dp1 contains gpu 1 and 5.
+
+
+
+Process Group using tp_degree=4, dp_degree=2
+
+
+## Distributed Spec
+
+An instance of [Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html) describes how a ColoTensor is distributed among the ProcessGroup.
+
+How tensors are distributed among DP process groups is automatically derived and does not need to be manually specified by the user. If this tensor is a model parameter, it is replicated within the DP process group. If it is an activation tensor, it is split along the process with the highest dimension and evenly distributed the tensor payload among processes in the DP process group.
+
+Therefore, when using Distributed Spec, we only need to describe the way that the tensor is distributed among TP process groups. There are currently two ways to distribute among TP process group, i.e. [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec) and [ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec). ShardSpec needs to specify the dimension index dim of the partition and the number of partitions num_partitions. Currently, we only support the split on a single dim. Different dist specs on the TP process groups can be converted to each other through the set_dist_spec() interface. The spec conversions are recorded by the autograd mechanism and it will trigger corresponding reverse operations during backward propagation.
+
+## Compute Spec
+
+An instance of class [ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec) describes how a Coloensor be used in DNN training. Currently, we will set the correct Compute Pattern for the ColoTensor as the parameters of the module. The specific application scenarios will be shown in the next document.
+
+## ColoParameter
+
+[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter) is a subclass of ColoTensor. Used to define a Global Parameter tensor. Its relationship with ColoTensor is consistent with Torch.Tensor and torch.Parameter. The latter allows the tensor to appear in the return values of the module's parameters() and name_parameters() methods.
+
+## Example
+
+Let's see an example. A ColoTensor is initialized and sharded on 8 GPUs using tp_degree=4, dp_dgree=2. And then the tensor is sharded along the last dim among the TP process groups. Finally, we reshard it along the first dim (0 dim) among the TP process groups. We encourage users to run the code and observe the shape of each tensor.
+
+
+```python
+import torch
+import torch.multiprocessing as mp
+from colossalai.utils import free_port, print_rank_0
+from functools import partial
+
+import colossalai
+from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
+from colossalai.utils import free_port
+
+import torch
+
+def run_dist_tests(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ pg = ProcessGroup(tp_degree=2, dp_degree=2)
+
+ torch.manual_seed(0)
+ local_tensor = torch.randn(2, 3, 1).cuda()
+ print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}")
+
+ spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ t1 = ColoTensor.from_torch_tensor(local_tensor, spec)
+ t1 = t1.to_replicate()
+ print_rank_0(f"shape {t1.shape}, {t1.data}")
+
+ spec2 = ShardSpec([0], [pg.tp_world_size()])
+ t1.set_dist_spec(spec2)
+ print_rank_0(f"shape {t1.shape}, {t1.data}")
+
+def test_dist_cases(world_size):
+ run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+if __name__ == '__main__':
+ test_dist_cases(4)
+```
+
+:::caution
+
+The ColoTensor is an experimental feature and may be updated.
+
+:::
diff --git a/docs/source/en/basics/command_line_tool.md b/docs/source/en/basics/command_line_tool.md
new file mode 100644
index 000000000000..48b199cf78e9
--- /dev/null
+++ b/docs/source/en/basics/command_line_tool.md
@@ -0,0 +1,53 @@
+# Command Line Tool
+
+Author: Shenggui Li
+
+**Prerequisite:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+
+## Introduction
+
+Colossal-AI provides command-line utilities for the user.
+The current command line tools support the following features.
+
+- verify Colossal-AI build
+- launch distributed jobs
+- tensor parallel micro-benchmarking
+
+## Check Installation
+
+To verify whether your Colossal-AI is built correctly, you can use the command `colossalai check -i`.
+This command will inform you information regarding the version compatibility and cuda extension.
+
+
+
+Check Installation Demo
+
+
+## Launcher
+
+To launch distributed jobs on single or multiple nodes, the command `colossalai run` can be used for process launching.
+You may refer to [Launch Colossal-AI](./launch_colossalai.md) for more details.
+
+## Tensor Parallel Micro-Benchmarking
+
+As Colossal-AI provides an array of tensor parallelism methods, it is not intuitive to choose one for your hardware and
+model. Therefore, we provide a simple benchmarking to evaluate the performance of various tensor parallelisms on your system.
+This benchmarking is run on a simple MLP model where the input data is of the shape `(batch_size, seq_length, hidden_size)`.
+Based on the number of GPUs, the CLI will look for all possible tensor parallel configurations and display the benchmarking results.
+You can customize the benchmarking configurations by checking out `colossalai benchmark --help`.
+
+```shell
+# run on 4 GPUs
+colossalai benchmark --gpus 4
+
+# run on 8 GPUs
+colossalai benchmark --gpus 8
+```
+
+:::caution
+
+Only single-node benchmarking is supported currently.
+
+:::
diff --git a/docs/source/en/basics/configure_parallelization.md b/docs/source/en/basics/configure_parallelization.md
new file mode 100644
index 000000000000..4ac0299eac14
--- /dev/null
+++ b/docs/source/en/basics/configure_parallelization.md
@@ -0,0 +1,156 @@
+# Configure Parallelization
+
+Author: Shenggui Li, Siqi Mai
+
+**Prerequisite:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+- [Define Your Configuration](./define_your_config.md)
+
+
+## Introduction
+
+We support multiple parallelization in Colossal-AI. Hybrid parallelism in our codebase refers to namely the combination
+of data parallelism, pipeline parallelism and tensor parallelism (1D, 2D, 2.5D, 3D).
+
+Each parallelism requires different network topology and thus initialize different process groups.
+You can initialize the corresponding process group by setting `parallel` in the config file.
+The configuration for `parallel` must obey the following format. Data parallel size will be
+inferred automatically based on your inputs to pipeline parallelism and tensor parallelism.
+`colossalai.launch` will initialize these distributed process groups automatically based on your configuration.
+
+Some sample configurations are shown below:
+
+```python
+# sampler format
+parallel = dict(
+ pipeline=dict("size": int),
+ tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any)
+)
+
+# this is ok
+parallel = dict(
+ pipeline=dict(size=2),
+ tensor=dict(size=4, mode='2d')
+)
+
+# this is ok
+parallel = dict(
+ pipeline=2,
+ tensor=dict(size=4, mode='2d')
+)
+
+# this is not ok
+# as you need to specify the mode for tensor parallelism
+parallel = dict(
+ pipeline=2,
+ tensor=4
+)
+
+# this is ok as well as tensor will be default to size 1
+# and mode None
+parallel = dict(
+ pipeline=2
+)
+
+# this is ok as well as pipeline will default to size 1
+parallel = dict(
+ tensor=dict(size=4, mode='2d')
+)
+
+```
+
+The key name `size` refers to the parallel size of the parallelism dimension. For example, pipeline size 2 means there
+will be 2 pipeline stages. The key name `mode` in tensor parallel config means the corresponding tensor parallelism
+will be initialized.
+
+**You can choose to not have 'parallel' in your configuration and both pipeline and tensor will default to size 1.**
+
+**Total number of GPUs must be equal to `data parallel size * tensor parallel size * pipeline parallel size`**
+
+## Data Parallel
+
+Data parallel is the most common way to distribute your training task by splitting data into several shards and train on
+a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
+have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI.
+
+1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers
+2. Otherwise, PyTorch DistributedDataParallel will be used
+
+In most cases, you will be using the second mode unless you have complex handling of the gradients.
+
+## 1D, 2D, 2.5D and 3D Parallel
+
+To enable hybrid parallelism, we provide an array of tensor parallelism. We provide the list of papers which match each
+tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.
+
+- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+
+- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
+ 2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
+ outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of `P = N^2` devices where
+ `N` is the number of tensor chunks in a single dimension.
+
+- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
+ Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which
+ further parallelizes 2D tensor parallelism. An amount of `P = N^2 ∗ d` processors are arranged into `d` layers, where
+ each layer performs matrix multiplication operations independently with a dimension `N`.
+
+- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
+ We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method
+ achieves the optimal, `O(P^{1/3})` communication overhead on $P$ processors, while both computation and memory usage
+ are evenly distributed through optimized load balancing of parameters as well as activations.
+
+```python
+# 1D parallel
+parallel = dict(
+ tensor=dict(size=4, mode='1d')
+)
+
+# 2D parallel
+parallel = dict(
+ tensor=dict(size=4, mode='2d')
+)
+
+# 2.5D parallel
+parallel = dict(
+ tensor=dict(size=8, mode='2.5d', depth=2)
+)
+
+# 3D parallel
+parallel = dict(
+ tensor=dict(size=8, mode='3d')
+)
+```
+
+Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed
+operator. For example, if you mode is '2d', you can use `colossalai.nn.Linear2D` in you model construction.
+
+
+## Pipeline Parallel
+
+Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
+model which consists of two linear layer. We have two GPUs, and we can allocate the first linear layer to the first GPU
+and the second layer to the second GPU.
+
+You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
+will automatically creates the pipeline schedule which defines the forward and backward step.
+
+```python
+parallel = dict(
+ pipeline=dict(size=4), # number of pipeline stages
+)
+```
+
+## Sequence Parallel
+
+Sequence parallel is to support long-sequence modelling such as document-level text understanding and medical imaging.
+This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
+You can use specify the mode to be `sequence` to initialize its process group.
+
+
+```python
+parallel = dict(
+ tensor=dict(size=4, mode='sequence')
+)
+```
diff --git a/docs/source/en/basics/define_your_config.md b/docs/source/en/basics/define_your_config.md
new file mode 100644
index 000000000000..d2569691b7dc
--- /dev/null
+++ b/docs/source/en/basics/define_your_config.md
@@ -0,0 +1,82 @@
+# Define Your Configuration
+
+Author: Guangyang Lu, Shenggui Li, Siqi Mai
+
+**Prerequisite:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+
+
+## Introduction
+
+In Colossal-AI, a configuration file is required to specify the features the system will inject into the training process.
+In this tutorial, we will introduce you how to construct your configuration file and how this config file will be used.
+Using configuration file has several advantages:
+
+1. You can store your feature configuration and training hyper-parameters in different configuration files
+2. New features released in the future can be specified in the configuration without code change in the training script
+
+In this tutorial, we will cover how to define your configuration file.
+
+## Configuration Definition
+
+In a configuration file, there are two types of variables. One serves as feature specification and the other serves
+as hyper-parameters. All feature-related variables are reserved keywords. For example, if you want to use mixed precision
+training, you need to use the variable name `fp16` in the config file and follow a pre-defined format.
+
+### Feature Specification
+
+There is an array of features Colossal-AI provides to speed up training. Each feature is defined by a corresponding field
+in the config file. In this tutorial, we are not giving the config details for all the features, but rather we are providing
+an illustration of how to specify a feature. **The details of each feature can be found in its respective tutorial.**
+
+To illustrate the use of config file, we use mixed precision training as an example here. In order to do so, you need to
+follow the steps below.
+
+1. create a configuration file (e.g. `config.py`, the file name can be anything)
+2. define the mixed precision configuration in the config file. For example, in order to use mixed precision training
+natively provided by PyTorch, you can just write these lines of code below into your config file.
+
+ ```python
+ from colossalai.amp import AMP_TYPE
+
+ fp16 = dict(
+ mode=AMP_TYPE.TORCH
+ )
+ ```
+
+3. Tell Colossal-AI where your config file is when launch the distributed environment. For example, the config file is in
+the current directory.
+
+ ```python
+ import colossalai
+
+ colossalai.launch(config='./config.py', ...)
+ ```
+
+In this way, Colossal-AI knows what features you want to use and will inject this feature during `colossalai.initialize`.
+
+### Global Hyper-parameters
+
+Besides feature specification, the config file can also serve as a place to define your training hyper-parameters. This
+comes handy when you want to perform multiple experiments, each experiment details can be put into a single config file
+to avoid confusion. These parameters will be stored in the global parallel context and can be accessed in the training script.
+
+For example, you can specify the batch size in your config file.
+
+```python
+BATCH_SIZE = 32
+```
+
+After launch, you are able to access your hyper-parameters through global parallel context.
+
+```python
+import colossalai
+from colossalai.core import global_context as gpc
+
+colossalai.launch(config='./config.py', ...)
+
+# access your parameter
+print(gpc.config.BATCH_SIZE)
+
+```
diff --git a/docs/source/en/basics/engine_trainer.md b/docs/source/en/basics/engine_trainer.md
new file mode 100644
index 000000000000..39792f622aa9
--- /dev/null
+++ b/docs/source/en/basics/engine_trainer.md
@@ -0,0 +1,387 @@
+# Use Engine and Trainer in Training
+
+Author: Shenggui Li, Siqi Mai
+
+**Prerequisite:**
+- [Initialize Features](./initialize_features.md)
+
+## Introduction
+
+In this tutorial, you will learn how to use the engine and trainer provided in Colossal-AI to train your model.
+Before we delve into the details, we would like to first explain the concept of engine and trainer.
+
+### Engine
+
+Engine is essentially a wrapper class for model, optimizer and loss function.
+When we call `colossalai.initialize`, an engine object will be returned, and it has already been equipped with
+functionalities such as gradient clipping, gradient accumulation and zero optimizer as specified in your configuration file.
+An engine object will use similar APIs to those of PyTorch training components such that the user has minimum change
+to their code.
+
+Below is a table which shows the commonly used APIs for the engine object.
+
+| Component | Function | PyTorch | Colossal-AI |
+| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- |
+| optimizer | Set all gradients to zero before an iteration | optimizer.zero_grad() | engine.zero_grad() |
+| optimizer | Update the parameters | optimizer.step() | engine.step() |
+| model | Run a forward pass | outputs = model(inputs) | outputs = engine(inputs) |
+| criterion | Calculate the loss value | loss = criterion(output, label) | loss = engine.criterion(output, label) |
+| criterion | Execute back-propagation on the model | loss.backward() | engine.backward(loss) |
+
+The reason why we need such an engine class is that we can add more functionalities while hiding the implementations in
+the `colossalai.initialize` function.
+Imaging we are gonna add a new feature, we can manipulate the model, optimizer, dataloader and loss function in the
+`colossalai.initialize` function and only expose an engine object to the user.
+The user only needs to modify their code to the minimum extent by adapting the normal PyTorch APIs to the Colossal-AI
+engine APIs. In this way, they can enjoy more features for efficient training.
+
+A normal training iteration using engine can be:
+
+```python
+import colossalai
+
+# build your model, optimizer, criterion, dataloaders
+...
+
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader)
+for img, label in train_dataloader:
+ engine.zero_grad()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+```
+
+### Trainer
+
+Trainer is a more high-level wrapper for the user to execute training with fewer lines of code. However, in pursuit of more abstraction, it loses some flexibility compared to engine. The trainer is designed to execute a forward and backward step to perform model weight update. It is easy to create a trainer object by passing the engine object. The trainer has a default value `None` for the argument `schedule`. In most cases, we leave this value to `None` unless we want to use pipeline parallelism. If you wish to explore more about this parameter, you can go to the tutorial on pipeline parallelism.
+
+```python
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import Trainer, hooks
+
+# build components and initialize with colossalai.initialize
+...
+
+# create a logger so that trainer can log on the console
+logger = get_dist_logger()
+
+# create a trainer object
+trainer = Trainer(
+ engine=engine,
+ logger=logger
+)
+```
+
+
+
+In trainer, the user can customize some hooks and attach these hooks to the trainer object. A hook object will execute life-cycle methods periodically based on the training scheme. For example, The `LRSchedulerHook` will execute `lr_scheduler.step()` to update the learning rate of the model during either `after_train_iter` or `after_train_epoch` stages depending on whether the user wants to update the learning rate after each training iteration or only after the entire training epoch. You can store the hook objects in a list and pass it to `trainer.fit` method. `trainer.fit` method will execute training and testing based on your parameters. If `display_process` is True, a progress bar will be displayed on your console to show the training process.
+
+```python
+# define the hooks to attach to the trainer
+hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.AccuracyHook(accuracy_func=Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+]
+
+# start training
+trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True
+)
+```
+
+If you want to customize your own hook class, you can inherit `hooks.BaseHook` and override the life-cycle methods of your interest. A dummy example to demonstrate how to create a simple log message hook is provided below for your reference.
+
+```python
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import hooks
+
+class LogMessageHook(hooks.BaseHook):
+
+ def __init__(self, priority=10):
+ self._logger = get_dist_logger()
+
+ def before_train(self, trainer):
+ self._logger.info('training starts')
+
+ def after_train(self, trainer):
+ self._logger.info('training finished')
+
+
+...
+
+# then in your training script
+hook_list.append(LogMessageHook())
+```
+
+
+
+In the sections below, I will guide you through the steps required to train a ResNet model with both engine and trainer.
+
+
+
+## Explain with ResNet
+
+### Overview
+
+In this section we will cover:
+
+1. Use an engine object to train a ResNet34 model on CIFAR10 dataset
+2. Use a trainer object to train a ResNet34 model on CIFAR10 dataset
+
+The project structure will be like:
+
+```bash
+-- config.py
+-- run_resnet_cifar10_with_engine.py
+-- run_resnet_cifar10_with_trainer.py
+```
+
+Steps 1-4 below are commonly used regardless of using engine or trainer. Thus, steps 1-4 + step 5 will be your `run_resnet_cifar10_with_engine.py` and steps 1-4 + step 6 will form `run_resnet_cifar10_with_trainer.py`.
+
+### Hands-on Practice
+
+#### Step 1. Create a Config File
+
+In your project folder, create a `config.py`. This file is to specify some features you may want to use to train your model. A sample config file is as below:
+
+```python
+from colossalai.amp import AMP_TYPE
+
+BATCH_SIZE = 128
+NUM_EPOCHS = 200
+
+fp16=dict(
+ mode=AMP_TYPE.TORCH
+)
+```
+
+In this config file, we specify that we want to use batch size 128 per GPU and run for 200 epochs. These two parameters are exposed by `gpc.config`. For example, you can use `gpc.config.BATCH_SIZE` to access the value you store in your config file. The `fp16` configuration tells `colossalai.initialize` to use mixed precision training provided by PyTorch to train the model with better speed and lower memory consumption.
+
+#### Step 2. Initialize Distributed Environment
+
+We need to initialize the distributed training environment. This has been introduced in the tutorial on how to
+[launch Colossal-AI](./launch_colossalai.md). For this demostration, we use `launch_from_torch` and PyTorch launch utility.
+
+```python
+import colossalai
+
+# ./config.py refers to the config file we just created in step 1
+colossalai.launch_from_torch(config='./config.py')
+```
+
+#### Step 3. Create all the training components
+
+In this step, we can create all the components used for training. These components include:
+
+1. Model
+2. Optimizer
+3. Criterion/loss function
+4. Training/Testing dataloaders
+5. Learning rate Scheduler
+6. Logger
+
+
+
+To build these components, you need to import the following modules:
+
+```python
+from pathlib import Path
+from colossalai.logging import get_dist_logger
+import torch
+import os
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader
+from torchvision import transforms
+from colossalai.nn.lr_scheduler import CosineAnnealingLR
+from torchvision.datasets import CIFAR10
+from torchvision.models import resnet34
+```
+
+
+
+Then build your components in the same way as how to normally build them in your PyTorch scripts. In the script below, we set the root path for CIFAR10 dataset as an environment variable `DATA`. You can change it to any path you like, for example, you can change `root=Path(os.environ['DATA'])` to `root='./data'` so that there is no need to set the environment variable.
+
+```python
+# build logger
+logger = get_dist_logger()
+
+# build resnet
+model = resnet34(num_classes=10)
+
+# build datasets
+train_dataset = CIFAR10(
+ root='./data',
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.RandomCrop(size=32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+)
+
+test_dataset = CIFAR10(
+ root='./data',
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+)
+
+# build dataloaders
+train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=gpc.config.BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+test_dataloader = get_dataloader(dataset=test_dataset,
+ add_sampler=False,
+ batch_size=gpc.config.BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+# build criterion
+criterion = torch.nn.CrossEntropyLoss()
+
+# optimizer
+optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+
+# lr_scheduler
+lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
+```
+
+#### Step 4. Initialize with Colossal-AI
+
+Next, the essential step is to obtain the engine class by calling `colossalai.initialize`. As stated in `config.py`, we will be using mixed precision training for training ResNet34 model. `colossalai.initialize` will automatically check your config file and assign relevant features to your training components. In this way, our engine object has already been able to train with mixed precision, but you do not have to explicitly take care of it.
+
+```python
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader,
+ )
+```
+
+
+
+#### Step 5. Train with engine
+
+With all the training components ready, we can train ResNet34 just like how to normally deal with PyTorch training.
+
+```python
+for epoch in range(gpc.config.NUM_EPOCHS):
+ # execute a training iteration
+ engine.train()
+ for img, label in train_dataloader:
+ img = img.cuda()
+ label = label.cuda()
+
+ # set gradients to zero
+ engine.zero_grad()
+
+ # run forward pass
+ output = engine(img)
+
+ # compute loss value and run backward pass
+ train_loss = engine.criterion(output, label)
+ engine.backward(train_loss)
+
+ # update parameters
+ engine.step()
+
+ # update learning rate
+ lr_scheduler.step()
+
+ # execute a testing iteration
+ engine.eval()
+ correct = 0
+ total = 0
+ for img, label in test_dataloader:
+ img = img.cuda()
+ label = label.cuda()
+
+ # run prediction without back-propagation
+ with torch.no_grad():
+ output = engine(img)
+ test_loss = engine.criterion(output, label)
+
+ # compute the number of correct prediction
+ pred = torch.argmax(output, dim=-1)
+ correct += torch.sum(pred == label)
+ total += img.size(0)
+
+ logger.info(
+ f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0])
+```
+
+#### Step 6. Train with trainer
+
+If you wish to train with a trainer object, you can follow the code snippet below:
+
+```python
+from colossalai.nn.metric import Accuracy
+from colossalai.trainer import Trainer, hooks
+
+
+# create a trainer object
+trainer = Trainer(
+ engine=engine,
+ logger=logger
+)
+
+# define the hooks to attach to the trainer
+hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.AccuracyHook(accuracy_func=Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.LogMemoryByEpochHook(logger)
+]
+
+# start training
+# run testing every 1 epoch
+trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True
+)
+```
+
+
+
+#### Step 7. Start Distributed Training
+
+Lastly, we can invoke the scripts using the distributed launcher provided by PyTorch as we used `launch_from_torch` in Step 2. You need to replace `` with the number of GPUs available on your machine. This number can be 1 if you only want to use 1 GPU. If you wish to use other launchers, you can refer to the tutorial on How to Launch Colossal-AI.
+
+```bash
+# with engine
+python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py
+# with trainer
+python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
+```
diff --git a/docs/source/en/basics/initialize_features.md b/docs/source/en/basics/initialize_features.md
new file mode 100644
index 000000000000..e768d2022ad8
--- /dev/null
+++ b/docs/source/en/basics/initialize_features.md
@@ -0,0 +1,49 @@
+# Initialize Features
+
+Author: Shenggui Li, Siqi Mai
+
+**Prerequisite:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+
+## Introduction
+
+In this tutorial, we will cover the use of `colossalai.initialize` which injects features into your training components
+(e.g. model, optimizer, dataloader) seamlessly. Calling `colossalai.initialize` is the standard procedure before you run
+into your training loops.
+
+In the section below, I will cover how `colossalai.initialize` works and what we should take note of.
+
+## Usage
+
+In a typical workflow, we will launch distributed environment at the beginning of our training script.
+Afterwards, we will instantiate our objects such as model, optimizer, loss function, dataloader etc. At this moment, `colossalai.initialize`
+can come in to inject features into these objects. A pseudo-code example is like below:
+
+```python
+import colossalai
+import torch
+...
+
+
+# launch distributed environment
+colossalai.launch(config='./config.py', ...)
+
+# create your objects
+model = MyModel()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+criterion = torch.nn.CrossEntropyLoss()
+train_dataloader = MyTrainDataloader()
+test_dataloader = MyTrainDataloader()
+
+# initialize features
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader)
+```
+
+The `colossalai.initialize` function will return an `Engine` object. The engine object is a wrapper
+for model, optimizer and loss function. **The engine object will run with features specified in the config file.**
+More details about the engine can be found in the [Use Engine and Trainer in Training](./engine_trainer.md).
diff --git a/docs/source/en/basics/launch_colossalai.md b/docs/source/en/basics/launch_colossalai.md
new file mode 100644
index 000000000000..be487f8539a5
--- /dev/null
+++ b/docs/source/en/basics/launch_colossalai.md
@@ -0,0 +1,232 @@
+# Launch Colossal-AI
+
+Author: Chuanrui Wang, Shenggui Li, Siqi Mai
+
+**Prerequisite:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+
+
+## Introduction
+
+As mentioned in the previous tutorials stated in the prerequisite, you need to initialize the distributed environment
+for Colossal-AI after your config file is prepared.
+We call this process `launch`.
+In this tutorial, you will learn how to launch Colossal-AI on your server, be it a small one or big one.
+
+In Colossal-AI, we provided several launch methods to initialize the distributed backend.
+In most cases, you can use `colossalai.launch` and `colossalai.get_default_parser` to pass the
+parameters via command line.
+If you happen to use launchers such as SLURM, OpenMPI and PyTorch launch utility,
+we also provide several launching helper methods to access the rank and world size from the environment variables
+set by these launchers directly for your convenience.
+
+In this tutorial we will cover how to launch Colossal-AI to initialize the distributed backends:
+- Launch with `colossalai.launch`
+- Launch with Colossal-AI CLI
+- Launch with SLURM
+- Launch with OpenMPI
+
+## Launch Distributed Environment
+
+In order to launch Colossal-AI, we need two types of arguments:
+1. config file
+2. distributed settings
+
+The config file is always required regardless of the launch method but distributed settings can vary. The config file
+can be a path to the configuration file or a Python dictionary. The distributed settings can be passed via command line
+or multi-process launchers.
+
+### Command Line Parser
+
+Before we jump to `launch`, we firstly need to understand what parameters we need for initialization.
+As stated in the `Basic Concepts in Distributed Training` section of [Distributed Training](../concepts/distributed_training.md),
+the important parameters are:
+
+1. host
+2. port
+3. rank
+4. world_size
+5. backend
+
+In Colossal-AI, we provided a command line parser which has added these arguments in advance. You can get this parser by calling
+`colossalai.get_default_parser()`. This parser is usually used with `colossalai.launch`.
+
+```python
+# add these lines in your train.py
+import colossalai
+
+# get default parser
+parser = colossalai.get_default_parser()
+
+# if you want to add your own arguments
+parser.add_argument(...)
+
+# parse arguments
+args = parser.parse_args()
+```
+
+Then in your terminal, you can pass in these arguments:
+```shell
+
+python train.py --host --rank --world_size --port --backend
+```
+
+`backend` is optional and the default value is `nccl`.
+
+### Native Launch
+
+To initialize the distributed environment, we provided a general `colossalai.launch` API. The `colossalai.launch` function takes in the parameters
+listed above and create a default process group in the communication network. This function is often used with the default
+parser for convenience.
+
+```python
+import colossalai
+
+# parse arguments
+args = colossalai.get_default_parser().parse_args()
+
+# launch distributed environment
+colossalai.launch(config=,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+)
+
+```
+
+
+### Launch with Colossal-AI CLI
+
+To enable easy launching on both single or multi nodes, we have implemented a launcher for Colossal-AI. This launcher is
+a wrapper of the torch distributed launch utility but enhanced with the capability of launching multi-node jobs easily.
+
+First, we need to set the launch method in our code. As this is a wrapper of the torch distributed launch utility, we will
+use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch
+launcher and can be read from the environment variable directly.
+
+```python
+import colossalai
+
+colossalai.launch_from_torch(
+ config=,
+)
+```
+
+Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code
+on a single node with 4 GPUs. You can change the number of GPUs by `nproc_per_node` and the default port by `master_port`.
+
+```shell
+# run on the local node with 4 GPUs (default port: 29500)
+colossalai run --nproc_per_node 4 train.py
+
+# run on the local node with 4 GPUs with a different port
+colossalai run --nproc_per_node 4 --master_port 29505 test.py
+```
+
+If you are in a cluster and want to launch multi-node training, the CLI can help you start processes on different nodes
+with one simple command. There are two ways you can launch multi-node jobs.
+
+- Run with `--hosts`
+
+This is suitable when you only have a few nodes. Let's say I have two nodes, namely `host1` and `host2`, I can start
+multi-node training with the following command. Compared to single-node training, you must specify the `master_addr`
+option, which is auto-set to localhost if running on a single node only.
+
+:::caution
+
+`master_addr` cannot be localhost when running on multiple nodes, it should be the hostname or IP address of a node.
+
+:::
+
+```shell
+# run on these two nodes
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+```
+- Run with `--hostfile`
+
+This method is suitable when you have a lot of nodes. The host file is a simple text file listing the available nodes.
+The list of nodes is commonly provided by cluster managers such as SLURM and PBS Pro. For example, you can get the list
+of nodes allocated to you via the environment variable `SLURM_NODELIST` in SLURM and `PBS_NODEFILE` in PBS Pro.
+Just do `echo $SLURM_NODELIST` or `cat $PBS_NODEFILE` to check it out. If you do not have such cluster managers, you can
+manually create one for your own use.
+
+The host file given to Colossal-AI launcher must be in the following format where each line is the host name of a node.
+
+```text
+host1
+host2
+```
+
+With the host file ready, we can launch multi-node jobs with the following commands. Just like using `--host`, you also
+need to specify the `master_addr` option. Some extra options are provided for `--hostfile` as listed below:
+
+- `--include`: specify the hosts to include for multi-node jobs. For example, if your host file has 8 nodes, but you
+happen to only want to run on 6 nodes instead, you can add `--include host1,host2,host3,...,host6` so that the job will only
+be launcher on the 6 nodes.
+- `--exclude`: specify the hosts to exclude for multi-node jobs. This is useful when some nodes are faulty. For example,
+if host1 GPU has some problems and you do not wish to run on host1 but all other nodes, you can add `--exclude host1` so that
+the job will only be launched on the remaining nodes.
+
+```shell
+# run with a hostfile
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 test.py
+
+# only include certain hosts to execute commands
+# this is used to manually select nodes to run
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --include host1 test.py
+
+# exclude certain hosts to execute commands
+# this can be used when certain nodes are faulty
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --exclude host2 test.py
+```
+
+### Launch with SLURM
+
+If you are on a system managed by the SLURM scheduler, you can also rely on the `srun` launcher to kickstart your Colossal-AI scripts.
+We provided the helper function `launch_from_slurm` for compatibility with the SLURM scheduler.
+`launch_from_slurm` will automatically read the rank and world size from the environment variables `SLURM_PROCID` and `SLURM_NPROCS` respectively
+and use them to start the distributed backend.
+Do this in your training script:
+
+```python
+import colossalai
+
+colossalai.launch_from_slurm(
+ config=,
+ host=args.host,
+ port=args.port
+)
+```
+
+You can initialize the distributed environment by using this command in terminal.
+
+```bash
+srun python train.py --host --port 29500
+```
+
+### Launch with OpenMPI
+If you are more familiar with OpenMPI, you can use `launch_from_openmpi` instead.
+`launch_from_openmpi` will automatically read the local rank, global rank and world size from the environment variables
+`OMPI_COMM_WORLD_LOCAL_RANK`, `MPI_COMM_WORLD_RANK` and `OMPI_COMM_WORLD_SIZE` respectively and
+use them to start the distributed backend.
+
+Do this in your train.py:
+```python
+colossalai.launch_from_openmpi(
+ config=,
+ host=args.host,
+ port=args.port
+)
+```
+
+A sample command to launch multiple processes with OpenMPI would be:
+
+```bash
+mpirun --hostfile -np python train.py --host --port 29500
+```
+
+- --hostfile: use this option to specify a list of hosts on which to run
+- --np: set the number of processes (GPUs) to launch in total. For example, if --np 4, 4 python processes will be initialized to run train.py.
diff --git a/docs/source/en/basics/model_checkpoint.md b/docs/source/en/basics/model_checkpoint.md
new file mode 100644
index 000000000000..09d44e7c2709
--- /dev/null
+++ b/docs/source/en/basics/model_checkpoint.md
@@ -0,0 +1,61 @@
+# Model Checkpoint
+
+Author : Guangyang Lu
+
+**Prerequisite:**
+- [Launch Colossal-AI](./launch_colossalai.md)
+- [Initialize Colossal-AI](./initialize_features.md)
+
+**Example Code:**
+- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint)
+
+**This function is experiential.**
+
+## Introduction
+
+In this tutorial, you will learn how to save and load model checkpoints.
+
+To leverage the power of parallel strategies in Colossal-AI, modifications to models and tensors are needed, for which you cannot directly use `torch.save` or `torch.load` to save or load model checkpoints. Therefore, we have provided you with the API to achieve the same thing.
+
+Moreover, when loading, you are not demanded to use the same parallel strategy as saving.
+
+## How to use
+
+### Save
+
+There are two ways to train a model in Colossal-AI, by engine or by trainer.
+**Be aware that we only save the `state_dict`.** Therefore, when loading the checkpoints, you need to define the model first.
+
+#### Save when using engine
+
+```python
+from colossalai.utils import save_checkpoint
+model = ...
+engine, _, _, _ = colossalai.initialize(model=model, ...)
+for epoch in range(num_epochs):
+ ... # do some training
+ save_checkpoint('xxx.pt', epoch, model)
+```
+
+#### Save when using trainer
+```python
+from colossalai.trainer import Trainer, hooks
+model = ...
+engine, _, _, _ = colossalai.initialize(model=model, ...)
+trainer = Trainer(engine, ...)
+hook_list = [
+ hooks.SaveCheckpointHook(1, 'xxx.pt', model)
+ ...]
+
+trainer.fit(...
+ hook=hook_list)
+```
+
+### Load
+
+```python
+from colossalai.utils import load_checkpoint
+model = ...
+load_checkpoint('xxx.pt', model)
+... # train or test
+```
diff --git a/docs/source/en/concepts/colossalai_overview.md b/docs/source/en/concepts/colossalai_overview.md
new file mode 100644
index 000000000000..d75d20196b08
--- /dev/null
+++ b/docs/source/en/concepts/colossalai_overview.md
@@ -0,0 +1,36 @@
+# Colossal-AI Overview
+
+Author: Shenggui Li, Siqi Mai
+
+## About Colossal-AI
+
+With the development of deep learning model size, it is important to shift to a new training paradigm. The traditional training method with no parallelism and optimization became a thing of the past and new training methods are the key to make training large-scale models efficient and cost-effective.
+
+Colossal-AI is designed to be a unfied system to provide an integrated set of training skills and utilities to the user. You can find the common training utilities such as mixed precision training and gradient accumulation. Besides, we provide an array of parallelism including data, tensor and pipeline parallelism. We optimize tensor parallelism with different multi-dimensional distributed matrix-matrix multiplication algorithm. We also provided different pipeline parallelism methods to allow the user to scale their model across nodes efficiently. More advanced features such as offloading can be found in this tutorial documentation in detail as well.
+
+## General Usage
+
+We aim to make Colossal-AI easy to use and non-instrusive to user code. There is a simple general workflow if you want to use Colossal-AI.
+
+
+
+Workflow
+
+
+1. Prepare a configiguration file where specifies the features you want to use and your parameters.
+2. Initialize distributed backend with `colossalai.launch`
+3. Inject the training features into your training components (e.g. model, optimizer) with `colossalai.initialize`.
+4. Run training and testing
+
+We will cover the whole workflow in the `basic tutorials` section.
+
+## Future Development
+
+The Colossal-AI system will be expanded to include more training skills, these new developments may include but are not limited to:
+
+1. optimization of distributed operations
+2. optimization of training on heterogenous system
+3. implementation of training utilities to reduce model size and speed up training while preserving model performance
+4. expansion of existing parallelism methods
+
+We welcome ideas and contribution from the community and you can post your idea for future development in our forum.
diff --git a/docs/source/en/concepts/distributed_training.md b/docs/source/en/concepts/distributed_training.md
new file mode 100644
index 000000000000..5038714f754b
--- /dev/null
+++ b/docs/source/en/concepts/distributed_training.md
@@ -0,0 +1,120 @@
+# Distributed Training
+
+Author: Shenggui Li, Siqi Mai
+
+## What is a distributed system?
+
+
+
+Image source: Towards Data Science
+
+
+A distributed system consists of multiple software components which run on multiple machines. For example, the traditional
+database runs on a single machine. As the amount of data gets incredibly large, a single machine can no longer deliver desirable
+performance to the business, especially in situations such as Black Friday where network traffic can be unexpectedly high.
+To handle such pressure, modern high-performance database is designed to run on multiple machines, and they work together to provide
+high throughput and low latency to the user.
+
+One important evaluation metric for distributed system is scalability. For example, when we run an application on 4 machines,
+we naturally expect that the application can run 4 times faster. However, due to communication overhead and difference in
+hardware performance, it is difficult to achieve linear speedup. Thus, it is important to consider how to make the application
+faster when we implement it. Algorithms of good design and system optimization can help to deliver good performance. Sometimes,
+it is even possible to achieve linear and super-linear speedup.
+
+
+## Why we need distributed training for machine learning?
+
+Back in 2012, [AlexNet](https://arxiv.org/abs/1404.5997) won the champion of the ImageNet competition, and it was trained
+on two GTX 580 3GB GPUs.
+Today, most models that appear in the top AI conferences are trained on multiple GPUs. Distributed training is definitely
+a common practice when researchers and engineers develop AI models. There are several reasons behind this trend.
+
+1. Model size increases rapidly. [ResNet50](https://arxiv.org/abs/1512.03385) has 20 million parameters in 2015,
+[BERT-Large](https://arxiv.org/abs/1810.04805) has 345 million parameters in 2018,
+[GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
+has 1.5 billion parameters in 2018, and [GPT-3](https://arxiv.org/abs/2005.14165) has 175 billion parameters in 2020.
+It is obvious that the model size grows exponentially with time. The current largest model has exceeded more than 1000
+billion parameters. Super large models generally deliver more superior performance compared to their smaller counterparts.
+
+
+Image source: HuggingFace
+
+
+
+2. Dataset size increases rapidly. For most machine learning developers, MNIST and CIFAR10 datasets are often the first few
+datasets on which they train their models. However, these datasets are very small compared to well-known ImageNet datasets.
+Google even has its own (unpublished) JFT-300M dataset which has around 300 million images, and this is close to 300 times
+larger than the ImageNet-1k dataset.
+
+
+3. Computing power gets stronger. With the advancement in the semiconductor industry, graphics cards become more and more
+powerful. Due to its larger number of cores, GPU is the most common compute platform for deep learning.
+From K10 GPU in 2012 to A100 GPU in 2020, the computing power has increased several hundred times. This allows us to performance
+compute-intensive tasks faster and deep learning is exactly such a task.
+
+Nowadays, the model can be too large to fit into a single GPU, and the dataset can be large enough to train for a hundred
+days on a single GPU. Only by training our models on multiple GPUs with different parallelization techniques, we are able
+to speed up the training process and obtain results in a reasonable amount of time.
+
+
+## Basic Concepts in Distributed Training
+
+Distributed training requires multiple machines/GPUs. During training, there will be communication among these devices.
+To understand distributed training better, there are several important terms to be made clear.
+
+- host: host is the main device in the communication network. It is often required as an argument when initializing the
+distributed environment.
+- port: port here mainly refers to master port on the host for communication.
+- rank: the unique ID given to a device in the network.
+- world size: the number of devices in the network.
+- process group: a process group is a communication network which include a subset of the devices. There is always a default
+process group which contains all the devices. A subset devices can form a process group so that they only communicate among
+the devices within the group.
+
+
+
+A distributed system example
+
+
+To illustrate these concepts, let's assume we have 2 machines (also called nodes), and each machine has 4 GPUs. When we
+initialize distributed environment over these two machines, we essentially launch 8 processes (4 processes on each machine)
+and each process is bound to a GPU.
+
+Before initializing the distributed environment, we need to specify the host (master address) and port (master port). In
+this example, we can let host be node 0 and port be a number such as 29500. All the 8 processes will then look for the
+address and port and connect to one another.
+The default process group will then be created. The default process group has a world size of 8 and details are as follows:
+
+| process ID | rank | Node index | GPU index |
+| ---------- | ---- | ---------- | --------- |
+| 0 | 0 | 0 | 0 |
+| 1 | 1 | 0 | 1 |
+| 2 | 2 | 0 | 2 |
+| 3 | 3 | 0 | 3 |
+| 4 | 4 | 1 | 0 |
+| 5 | 5 | 1 | 1 |
+| 6 | 6 | 1 | 2 |
+| 7 | 7 | 1 | 3 |
+
+
+We can also create a new process group. This new process group can contain any subset of the processes.
+For example, we can create one containing only even-number processes, and the details of this new group will be:
+
+| process ID | rank | Node index | GPU index |
+| ---------- | ---- | ---------- | --------- |
+| 0 | 0 | 0 | 0 |
+| 2 | 1 | 0 | 2 |
+| 4 | 2 | 1 | 0 |
+| 6 | 3 | 1 | 2 |
+
+**Please note that rank is relative to the process group and one process can have a different rank in different process
+groups. The max rank is always `world size of the process group - 1`.**
+
+In the process group, the processes can communicate in two ways:
+1. peer-to-peer: one process send data to another process
+2. collective: a group of process perform operations such as scatter, gather, all-reduce, broadcast together.
+
+
+
+Collective communication, source: PyTorch distributed tutorial
+
diff --git a/docs/source/en/concepts/paradigms_of_parallelism.md b/docs/source/en/concepts/paradigms_of_parallelism.md
new file mode 100644
index 000000000000..1a5dab7a76f7
--- /dev/null
+++ b/docs/source/en/concepts/paradigms_of_parallelism.md
@@ -0,0 +1,124 @@
+# Paradigms of Parallelism
+
+Author: Shenggui Li, Siqi Mai
+
+## Introduction
+
+With the development of deep learning, there is an increasing demand for parallel training. This is because that model
+and datasets are getting larger and larger and training time becomes a nightmare if we stick to single-GPU training. In
+this section, we will provide a brief overview of existing methods to parallelize training. If you wish to add on to this
+post, you may create a discussion in the [GitHub forum](https://github.com/hpcaitech/ColossalAI/discussions).
+
+## Data Parallel
+
+Data parallel is the most common form of parallelism due to its simplicity. In data parallel training, the dataset is split
+into several shards, each shard is allocated to a device. This is equivalent to parallelize the training process along the
+batch dimension. Each device will hold a full copy of the model replica and trains on the dataset shard allocated. After
+back-propagation, the gradients of the model will be all-reduced so that the model parameters on different devices can stay
+synchronized.
+
+
+
+Data parallel illustration
+
+
+## Model Parallel
+
+In data parallel training, one prominent feature is that each GPU holds a copy of the whole model weights. This brings
+redundancy issue. Another paradigm of parallelism is model parallelism, where model is split and distributed over an array
+of devices. There are generally two types of parallelism: tensor parallelism and pipeline parallelism. Tensor parallelism is
+to parallelize computation within an operation such as matrix-matrix multiplication. Pipeline parallelism is to parallelize
+computation between layers. Thus, from another point of view, tensor parallelism can be seen as intra-layer parallelism and
+pipeline parallelism can be seen as inter-layer parallelism.
+
+### Tensor Parallel
+
+Tensor parallel training is to split a tensor into `N` chunks along a specific dimension and each device only holds `1/N`
+of the whole tensor while not affecting the correctness of the computation graph. This requires additional communication
+to make sure that the result is correct.
+
+Taking a general matrix multiplication as an example, let's say we have C = AB. We can split B along the column dimension
+into `[B0 B1 B2 ... Bn]` and each device holds a column. We then multiply `A` with each column in `B` on each device, we
+will get `[AB0 AB1 AB2 ... ABn]`. At this moment, each device still holds partial results, e.g. device rank 0 holds `AB0`.
+To make sure the result is correct, we need to all-gather the partial result and concatenate the tensor along the column
+dimension. In this way, we are able to distribute the tensor over devices while making sure the computation flow remains
+correct.
+
+
+
+Tensor parallel illustration
+
+
+In Colossal-AI, we provide an array of tensor parallelism methods, namely 1D, 2D, 2.5D and 3D tensor parallelism. We will
+talk about them in detail in `advanced tutorials`.
+
+
+Related paper:
+- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668)
+- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
+- [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
+- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
+
+### Pipeline Parallel
+
+Pipeline parallelism is generally easy to understand. If you recall your computer architecture course, this indeed exists
+in the CPU design.
+
+
+
+Pipeline parallel illustration
+
+
+The core idea of pipeline parallelism is that the model is split by layer into several chunks, each chunk is
+given to a device. During the forward pass, each device passes the intermediate activation to the next stage. During the backward pass,
+each device passes the gradient of the input tensor back to the previous pipeline stage. This allows devices to compute simultaneously,
+and increases the training throughput. One drawback of pipeline parallel training is that there will be some bubble time where
+some devices are engaged in computation, leading to waste of computational resources.
+
+
+
+Source: GPipe
+
+
+Related paper:
+- [PipeDream: Fast and Efficient Pipeline Parallel DNN Training](https://arxiv.org/abs/1806.03377)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925)
+
+
+## Optimizer-Level Parallel
+
+Another paradigm works at the optimizer level, and the current most famous method of this paradigm is ZeRO which stands
+for [zero redundancy optimizer](https://arxiv.org/abs/1910.02054). ZeRO works at three levels to remove memory redundancy
+(fp16 training is required for ZeRO):
+
+- Level 1: The optimizer states are partitioned across the processes
+- Level 2: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process
+only stores the gradients corresponding to its partition of the optimizer states.
+- Level 3: The 16-bit model parameters are partitioned across the processes
+
+Related paper:
+- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
+
+
+## Parallelism on Heterogeneous System
+
+The methods mentioned above generally require a large number of GPU to train a large model. However, it is often neglected
+that CPU has a much larger memory compared to GPU. On a typical server, CPU can easily have several hundred GB RAM while each GPU
+typically only has 16 or 32 GB RAM. This prompts the community to think why CPU memory is not utilized for distributed training.
+
+Recent advances rely on CPU and even NVMe disk to train large models. The main idea is to offload tensors back to CPU memory
+or NVMe disk when they are not used. By using the heterogeneous system architecture, it is possible to accommodate a huge
+model on a single machine.
+
+
+
+Heterogenous system illustration
+
+
+Related paper:
+- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
+- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
+- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818)
diff --git a/docs/source/en/features/1D_tensor_parallel.md b/docs/source/en/features/1D_tensor_parallel.md
new file mode 100644
index 000000000000..530c2e7b64bc
--- /dev/null
+++ b/docs/source/en/features/1D_tensor_parallel.md
@@ -0,0 +1,111 @@
+# 1D Tensor Parallelism
+
+Author: Zhengda Bian, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+
+**Example Code**
+- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py)
+
+**Related Paper**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf)
+
+## Introduction
+
+Tensor parallelism partitions model weights across multiple devices in order to reduce memory load.
+An efficient 1D tensor parallelism implementation was introduced by [Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf).
+
+Let's take a linear layer as an example, which consists of a GEMM $Y = XA$. Given 2 processors, we split the columns of $A$ into $[A_1 ~ A_2]$, and calculate $Y_i = XA_i$ on each processor, which then forms $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. This is called a column-parallel fashion.
+
+When a second linear layer $Z=YB$ follows the column-parallel one, we split $B$ into $\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]$,
+which is called a row-parallel fashion.
+To calculate $Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]$, we first calculate $Y_iB_i$ on each processor, then use an all-reduce to aggregate the results as $Z=Y_1B_1+Y_2B_2$.
+
+We also need to note that in the backward pass, the column-parallel linear layer needs to aggregate the gradients of the input tensor $X$, because on each processor $i$ we only have $\dot{X_i}=\dot{Y_i}A_i^T$.
+Thus, we apply an all-reduce across the processors to get $\dot{X}=\dot{Y}A^T=\dot{Y_1}A_1^T+\dot{Y_2}A_2^T$.
+
+## Efficiency
+Given $P$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 1D tensor parallelism.
+
+| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ |
+
+## Usage
+
+To enable 1D tensor parallelism for our model, e.g. on 2 GPUs, we need to configure the parallism setting as below.
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'),
+))
+```
+Then Colossal-AI will automatically apply 1D parallelism to all the layers from `colossalai.nn`.
+
+Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+
+Launch Colossal-AI on 2 GPUs and build the model.
+
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
+```shell
+Weight of the first linear layer: torch.Size([256, 512])
+Weight of the second linear layer: torch.Size([512, 256])
+```
+The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the column-parallel partitioning, it becomes `[256, 512]`.
+Similarly, the second row-parallel layer partitions the weight `[1024, 256]` into `[512, 256]`.
+
+We can run the model with some random inputs.
+```python
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+torch.distributed.broadcast(x, src=0) # synchronize input
+
+x = m(x)
+```
+Then we can see the shapes of activation results.
+```shell
+Output of the first linear layer: torch.Size([16, 512])
+Output of the second linear layer: torch.Size([16, 256])
+```
+The output of the first linear layer is split into 2 partitions (each has the shape `[16, 512]`), while the second layer has identical outputs across the GPUs.
diff --git a/docs/source/en/features/2D_tensor_parallel.md b/docs/source/en/features/2D_tensor_parallel.md
new file mode 100644
index 000000000000..582614c2f2f4
--- /dev/null
+++ b/docs/source/en/features/2D_tensor_parallel.md
@@ -0,0 +1,142 @@
+# 2D Tensor Parallelism
+
+Author: Zhengda Bian, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+- [1D Tensor Parallelism](./1D_tensor_parallel.md)
+
+**Example Code**
+- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py)
+
+**Related Paper**
+- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf)
+
+## Introduction
+
+1D tensor parallelism does not partition activations, which can also consume a great amount of memory in terms of large-scale models.
+To evenly distribute the computation and memory load, [an efficient 2D tensor parallelism algorithm](https://arxiv.org/pdf/2104.05343.pdf) was introduced based on SUMMA (Scalable Universal Matrix Multiplication Algorithm).
+
+Let's still take a linear layer $Y = XA$ as an example.
+Given $P=q\times q$ processors (necessary condition), e.g. $q=2$, we split both the input $X$ and weight $A$ into
+
+$$
+\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right]
+\text{~and~}
+\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right].
+$$
+
+The calculation includes $q$ steps. When $t=1$, $X_{i0}$ is broadcasted in its row, and $A_{0j}$ is broadcasted in its column. So, we have
+
+$$
+\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right].
+$$
+
+Then we multiply $X_{i0}$ and $A_{0j}$ on each processor $(i, j)$ as
+
+$$
+\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1).
+$$
+
+Similarly, when $t=2$, $X_{i1}$ is broadcasted in its row, $A_{1j}$ is broadcasted in its column, and we multiply them as
+
+$$
+\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2).
+$$
+
+By adding $(1)$ and $(2)$ up, we have
+
+$$
+Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right].
+$$
+
+## Efficiency
+Given $P=q\times q$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 2D tensor parallelism.
+
+| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/q^2)$ | $O(1/q^2)$ | $O(1/q^2)$ | $O(6(q-1)/q)$ | $O(6(q-1))$ |
+
+## Usage
+
+To enable 2D tensor parallelism for our model, e.g. on 4 GPUs, we need to configure the parallism setting as below.
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=4, mode='2d'),
+))
+```
+Then Colossal-AI will automatically apply 2D parallelism to all the layers from `colossalai.nn`.
+
+Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+Launch Colossal-AI on 4 GPUs and build the model
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
+```shell
+Weight of the first linear layer: torch.Size([128, 512])
+Weight of the second linear layer: torch.Size([512, 128])
+```
+The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2D parallelism, it becomes `[128, 512]` on each GPU.
+Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
+
+We can run the model with some random inputs.
+```python
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+# partition input
+torch.distributed.broadcast(x, src=0)
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
+x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
+print_rank_0(f'Input: {x.shape}')
+
+x = m(x)
+```
+Then we can see the shapes of activation results.
+```shell
+Input: torch.Size([8, 128])
+Output of the first linear layer: torch.Size([8, 512])
+Output of the second linear layer: torch.Size([8, 128])
+```
+The activation tensors in 2D parallelism are all split in both row and column.
+E.g. the output of the first linear layer has the shape `[8, 512]`, while the second layer has the output of `[8, 128]`.
diff --git a/docs/source/en/features/2p5D_tensor_parallel.md b/docs/source/en/features/2p5D_tensor_parallel.md
new file mode 100644
index 000000000000..34a261ea0aa0
--- /dev/null
+++ b/docs/source/en/features/2p5D_tensor_parallel.md
@@ -0,0 +1,142 @@
+# 2.5D Tensor Parallelism
+
+Author: Zhengda Bian, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+- [1D Tensor Parallelism](./1D_tensor_parallel.md)
+- [2D Tensor Parallelism](./2D_tensor_parallel.md)
+
+**Example Code**
+- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py)
+
+**Related Paper**
+- [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf)
+
+## Introduction
+
+Compared with 1D tensor parallelism, 2D parallelism reduces the memory cost, but may introduce more communication.
+Therefore, a [2.5D tensor parallelism algorithm](https://arxiv.org/pdf/2105.14500.pdf) was proposed based on 2.5D SUMMA to reduce communication by using more devices.
+
+Let's still take a linear layer $Y = XA$ as an example.
+Given $P=q \times q \times d$ processors (necessary condition), e.g. $q=d=2$, we split the input $X$ into $d\times q$ rows and $q$ columns as
+
+$$
+\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right],
+$$
+which can be reshaped into $d$ layers as
+
+$$
+\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right].
+$$
+
+Also, the weight $A$ is split into
+
+$$
+\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right].
+$$
+
+For each layer of $X$, we use the SUMMA algorithm to multiply $X$ and $A$.
+Then, we have the output
+
+$$
+\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]
+\text{~and~}
+$$
+$$
+\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right].
+$$
+
+## Efficiency
+Given $P=q \times q \times d$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 2.5D tensor parallelism.
+
+| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $\small O(3(q-1)(d+1)/dq)$ | $O(6(q-1))$ |
+
+## Usage
+
+To enable 2.5D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallism setting as below.
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=8, mode='2.5d', depth=2),
+))
+
+```
+Then Colossal-AI will automatically apply 2.5D parallelism to all the layers from `colossalai.nn`.
+
+Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+Launch Colossal-AI on 8 GPUs and build the model
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
+```shell
+Weight of the first linear layer: torch.Size([128, 512])
+Weight of the second linear layer: torch.Size([512, 128])
+```
+The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 2.5D parallelism, it becomes `[128, 512]` on each GPU.
+Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 128]`.
+
+We can run the model with some random inputs.
+```python
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+# partition input
+torch.distributed.broadcast(x, src=0)
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
+x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
+print_rank_0(f'Input: {x.shape}')
+
+x = m(x)
+```
+Then we can see the shapes of activation results.
+```shell
+Input: torch.Size([4, 128])
+Output of the first linear layer: torch.Size([4, 512])
+Output of the second linear layer: torch.Size([4, 128])
+```
+The activation tensors in 2.5D parallelism are all split by $d \times q$ in the row and $q$ in the column.
+E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
+Note, 2.5D parallelism use the same partition method as 2D parallelism for weights, where the difference is the partition of input.
diff --git a/docs/source/en/features/3D_tensor_parallel.md b/docs/source/en/features/3D_tensor_parallel.md
new file mode 100644
index 000000000000..1207376335ce
--- /dev/null
+++ b/docs/source/en/features/3D_tensor_parallel.md
@@ -0,0 +1,151 @@
+# 3D Tensor Parallelism
+
+Author: Zhengda Bian, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+- [1D Tensor Parallelism](./1D_tensor_parallel.md)
+- [2D Tensor Parallelism](./2D_tensor_parallel.md)
+
+**Example Code**
+- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py)
+
+**Related Paper**
+- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf)
+
+## Introduction
+
+The [3D tensor parallelism](https://arxiv.org/pdf/2105.14450.pdf) is an approach to parallelize the computation of neural models, hoping to obtain the optimal communication cost.
+
+Let's still take a linear layer $Y = XA$ as an example.
+Given $P=q \times q \times q$ processors (necessary condition), e.g. $q=2$, we split the input $X$ and weight $A$ into
+
+$$
+\left[\begin{matrix}
+ X_{000} & X_{001} \\
+ X_{010} & X_{011} \\
+ X_{100} & X_{101} \\
+ X_{110} & X_{111} \end{matrix}
+\right]
+\text{~and~}
+\left[\begin{matrix}
+ A_{000} & A_{001} & A_{010} & A_{011} \\
+ A_{100} & A_{101} & A_{110} & A_{111} \end{matrix}
+\right]
+\text{~respectively,}$$
+where each $X_{ijl}$ and $A_{lji}$ are stored at processor $(i,j,l)$, as shown in the figure below.
+
+
+
+
+
+
+
+
+Then we all-gather $X_{ijl}$ across $(i, 0...q,l)$, as well as $A_{lji}$ across $(0...q, j, l)$.
+So, we have $X_{il}$ and $A_{lj}$ on each processor $(i,j,l)$ to get $X_{il}A_{lj}$.
+Finally, we reduce-scatter the results across $(i, j, 0...q)$ to get $Y_{ijl}$, which forms
+$$
+Y=
+\left[\begin{matrix}
+ Y_{000} & Y_{001} \\
+ Y_{010} & Y_{011} \\
+ Y_{100} & Y_{101} \\
+ Y_{110} & Y_{111} \end{matrix}
+\right].
+$$
+
+We also need to note that in the backward pass, we need to all-gather the gradient $\dot{Y_{ijl}}$, and then reduce-scatter the gradient $\dot{X_{il}}=\dot{Y_{ij}}A_{lj}^T$ and $\dot{A_{lj}}=X_{il}^T\dot{Y_{ij}}$.
+
+## Efficiency
+Given $P=q \times q \times q$ processors, we present the theoretical computation and memory cost, as well as the communication cost based on the ring algorithm in both the forward and backward pass of 3D tensor parallelism.
+
+| Computation | Memory (parameters) | Memory (activations) | Communication (bandwidth) | Communication (latency) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/q^3)$ | $O(1/q^3)$ | $O(1/q^3)$ | $O(6(q-1)/q^3)$ | $O(6(q-1))$ |
+
+## Usage
+
+To enable 3D tensor parallelism for our model, e.g. on 8 GPUs, we need to configure the parallism setting as below.
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=8, mode='3d'),
+))
+```
+Then Colossal-AI will automatically apply 3D parallelism to all the layers from `colossalai.nn`.
+
+Let's define a model that consists of a two-layer multi-layer perceptron (MLP) as below.
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+Launch Colossal-AI on 8 GPUs and build the model
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+We will see the shapes of partitioned parameters(e.g. weights) in the MLP model.
+```shell
+Weight of the first linear layer: torch.Size([128, 256])
+Weight of the second linear layer: torch.Size([512, 64])
+```
+The complete weight of the first linear layer is supposed to have the shape `[256, 1024]`. After the partitioning of 3D parallelism, it becomes `[128, 256]` on each GPU.
+Similarly, the second layer partitions the weight `[1024, 256]` into `[512, 64]`.
+
+We can run the model with some random inputs.
+```python
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+# partition input
+torch.distributed.broadcast(x, src=0)
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)]
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)]
+x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)]
+print_rank_0(f'Input: {x.shape}')
+
+x = m(x)
+```
+Then we can see the shapes of activation results.
+```shell
+Input: torch.Size([4, 128])
+Output of the first linear layer: torch.Size([4, 512])
+Output of the second linear layer: torch.Size([4, 128])
+```
+The activation tensors in 3D parallelism are all split by $q^2$ in the row and $q$ in the column.
+E.g. the output of the first linear layer has the shape `[4, 512]`), while the second layer has the output of `[4, 128]`.
+Note, although the results of 3D parallelism have the same shape as that of 2.5D parallelism for weights here, the content of each partition is different.
diff --git a/docs/source/en/features/gradient_accumulation.md b/docs/source/en/features/gradient_accumulation.md
new file mode 100644
index 000000000000..d8781ee691bc
--- /dev/null
+++ b/docs/source/en/features/gradient_accumulation.md
@@ -0,0 +1,45 @@
+# Gradient Accumulation
+
+Author: Shenggui Li, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
+
+**Example Code**
+- [ColossalAI-Examples Gradient Accumulation](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation)
+
+## Introduction
+
+Gradient accumulation is a common way to enlarge your batch size for training.
+When training large-scale models, memory can easily become the bottleneck and the batch size can be very small, (e.g. 2),
+leading to unsatisfactory convergence. Gradient accumulation works by adding up the gradients calculated in multiple iterations,
+and only update the parameters in the preset iteration.
+
+## Usage
+
+It is simple to use gradient accumulation in Colossal-AI. Just add this following configuration into your config file.
+The integer represents the number of iterations to accumulate gradients.
+
+```python
+gradient_accumulation =
+```
+
+## Hands-on Practice
+
+We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_accumulation)
+to demonstrate gradient accumulation. In this example, we set the gradinet accumulation size to be 4. You can run the script using this command:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py
+```
+
+You will see output similar to the text below. This shows gradient is indeed accumulated as the parameter is not updated
+in the first 3 steps, but only updated in the last step.
+
+```text
+iteration 0, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=)
+iteration 1, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=)
+iteration 2, first 10 elements of param: tensor([-0.0208, 0.0189, 0.0234, 0.0047, 0.0116, -0.0283, 0.0071, -0.0359, -0.0267, -0.0006], device='cuda:0', grad_fn=)
+iteration 3, first 10 elements of param: tensor([-0.0141, 0.0464, 0.0507, 0.0321, 0.0356, -0.0150, 0.0172, -0.0118, 0.0222, 0.0473], device='cuda:0', grad_fn=)
+```
diff --git a/docs/source/en/features/gradient_clipping.md b/docs/source/en/features/gradient_clipping.md
new file mode 100644
index 000000000000..f606dde6c393
--- /dev/null
+++ b/docs/source/en/features/gradient_clipping.md
@@ -0,0 +1,62 @@
+# Gradient Clipping
+
+Author: Boxiang Wang, Haichen Huang, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
+
+**Example Code**
+- [ColossalAI-Examples Gradient Clipping](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping)
+
+**Related Paper**
+- [On the difficulty of training Recurrent Neural Networks](https://arxiv.org/abs/1211.5063)
+
+## Introduction
+
+In order to speed up training process and seek global optimum for better performance, more and more learning
+rate schedulers have been proposed. People turn to control learning rate to adjust descent pace during training,
+which makes gradient vector better to be uniformed in every step. In that case, the descent pace can be
+controlled as expected. As a result, gradient clipping, a technique which can normalize the gradient vector
+to circumscribe it in a uniformed length, becomes indispensable for those who desire their better
+performance of their models.
+
+You do not have to worry about implementing gradient clipping when using Colossal-AI, we support gradient
+clipping in a powerful and convenient way. All you need is just an additional command in your configuration
+file.
+
+## Why you should use gradient clipping provided by Colossal-AI
+
+The reason of why we do not recommend users to write gradient clipping by themselves is that naive gradient clipping
+may fail when applying tensor parallelism, pipeline parallelism or MoE.
+
+According to the illustration below, each GPU only owns a portion of parameters of the weight in a linear layer.
+To get correct norm of gradient vector of the weight of the linear layer, the norm of every gradient vector in each GPU
+should be summed together.
+More complicated thing is that the distribution of bias is different from the distribution of the weight.
+The communication group is different in the sum operation.
+
+(PS: This situation is an old version of 2D parallelism, the implementation in the code is not the same.
+But it is a good example about the difficulty to unify all communication in gradient clipping.)
+
+
+
+Layout of parameters
+
+
+Do not worry about it, since Colossal-AI have handled it for you.
+
+### Usage
+To use gradient clipping, you can just simply add gradient clipping norm in your configuration file.
+```python
+clip_grad_norm = 1.0
+```
+
+### Hands-On Practice
+
+We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_clipping)
+to demonstrate gradient clipping. In this example, we set the gradient clipping vector norm to be 1.0. You can run the script using this command:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node 1 --master_addr localhost --master_port 29500 train_with_engine.py
+```
diff --git a/docs/source/en/features/gradient_handler.md b/docs/source/en/features/gradient_handler.md
new file mode 100644
index 000000000000..757016fcb53a
--- /dev/null
+++ b/docs/source/en/features/gradient_handler.md
@@ -0,0 +1,63 @@
+# Gradient Handler
+
+Author: Shenggui Li, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
+
+**Example Code**
+- [ColossalAI-Examples Gradient Handler](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler)
+
+## Introduction
+
+In distributed training, gradient synchronization is required at the end of each iteration. This is important because we
+need to make sure the parameters are updated with the same gradients in different machines so that the resulting parameters
+are the same. This is often seen in data parallel as the model is replicated across data parallel ranks.
+
+In Colossal-AI, we provide an interface for users to customize how they want to handle the synchronization. This brings
+flexibility in cases such as implementing a new parallelism method.
+
+When gradient handlers are used, PyTorch `DistributedDataParallel` will not be used as it will synchronize automatically.
+
+## Customize Your Gradient Handlers
+
+To implement a customized gradient handler, you need to follow these steps.
+1. inherit `BaseGradientHandler` in Colossal-AI.
+2. register the gradient handler into the `GRADIENT_HANDLER`.
+3. implement `handle_gradient` method.
+
+```python
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.engine.gradient_handler import BaseGradientHandler
+
+
+@GRADIENT_HANDLER.register_module
+class MyGradientHandler(BaseGradientHandler):
+
+ def handle_gradient(self):
+ do_something()
+
+
+```
+
+
+## Usage
+
+To use a gradient handler, you need to specify your gradient handler in the config file. The gradient handler
+will be automatically built and attached to the engine.
+
+```python
+gradient_handler = [dict(type='MyGradientHandler')]
+```
+
+
+### Hands-On Practice
+
+We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/gradient_handler)
+to demonstrate the use of gradient handler. In this example, we used `DataParallelGradientHandler` instead of PyTorch
+`DistributedDataParallel` for data parallel training.
+
+```shell
+python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py
+```
diff --git a/docs/source/en/features/mixed_precision_training.md b/docs/source/en/features/mixed_precision_training.md
new file mode 100644
index 000000000000..71cb6971d346
--- /dev/null
+++ b/docs/source/en/features/mixed_precision_training.md
@@ -0,0 +1,367 @@
+# Auto Mixed Precision Training
+
+Author: Chuanrui Wang, Shenggui Li, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
+
+**Example Code**
+- [ColossalAI-Examples AMP](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp)
+
+**Related Paper**
+- [Accelerating Scientific Computations with Mixed Precision Algorithms](https://arxiv.org/abs/0808.2794)
+
+
+## Introduction
+
+AMP stands for automatic mixed precision training.
+In Colossal-AI, we have incorporated different implementations of mixed precision training:
+
+1. torch.cuda.amp
+2. apex.amp
+3. naive amp
+
+
+| Colossal-AI | support tensor parallel | support pipeline parallel | fp16 extent |
+| ----------- | ----------------------- | ------------------------- | ----------- |
+| AMP_TYPE.TORCH | ✅ | ❌ | Model parameters, activation, gradients are downcast to fp16 during forward and backward propagation |
+| AMP_TYPE.APEX | ❌ | ❌ | More fine-grained, we can choose opt_level O0, O1, O2, O3 |
+| AMP_TYPE.NAIVE | ✅ | ✅ | Model parameters, forward and backward operations are all downcast to fp16 |
+
+The first two rely on the original implementation of PyTorch (version 1.6 and above) and NVIDIA Apex.
+The last method is similar to Apex O2 level.
+Among these methods, apex AMP is not compatible with tensor parallelism.
+This is because that tensors are split across devices in tensor parallelism, thus, it is required to communicate among different processes to check if inf or nan occurs in the whole model weights.
+We modified the torch amp implementation so that it is compatible with tensor parallelism now.
+
+> ❌️ fp16 and zero configuration are not compatible
+>
+> ⚠️ Pipeline only support naive AMP currently
+
+We recommend you to use torch AMP as it generally gives better accuracy than naive AMP if no pipeline is used.
+
+## Table of Contents
+
+In this tutorial we will cover:
+
+1. AMP introduction
+2. AMP in Colossal-AI
+3. Hands-on Practice
+
+## AMP Introduction
+
+Automatic Mixed Precision training is a mixture of FP16 and FP32 training.
+
+Half-precision float point format (FP16) has lower arithmetic complexity and higher compute efficiency.
+Besides, fp16 requires half of the storage needed by fp32 and saves memory & network bandwidth, which makes more memory
+available for large batch size and model size.
+
+However, there are other operations, like reductions, which require the dynamic range of fp32 to avoid numeric overflow/underflow. That's the reason why we introduce automatic mixed precision, attempting to match each operation to its appropriate data type, which can reduce the memory footprint and augment training efficiency.
+
+
+
+Illustration of an ordinary AMP (figure from PatrickStar paper)
+
+
+## AMP in Colossal-AI
+
+We supported three AMP training methods and allowed the user to train with AMP with no code. You can just simply add `fp16`
+configuration in your configuration file to use AMP.
+
+
+```python
+from colossalai.amp import AMP_TYPE
+
+# use Torch AMP
+fp16=dict(
+ mode = AMP_TYPE.TORCH
+)
+
+# use naive AMP
+fp16=dict(
+ mode = AMP_TYPE.NAIVE
+)
+
+# use NVIDIA Apex AMP
+fp16=dict(
+ mode = AMP_TYPE.APEX
+)
+
+```
+
+> These are the minimum configuration, full configuration are stated in the section later
+
+### AMP Modularity
+
+AMP module is designed to be completely modular and can be used independently.
+If you wish to only use AMP in your code base without `colossalai.initialize`,
+you can use `colossalai.amp.convert_to_amp`.
+
+```python
+from colossalai.amp import AMP_TYPE
+
+# exmaple of using torch amp
+model, optimizer, criterion = colossalai.amp.convert_to_amp(model,
+ optimizer,
+ criterion,
+ AMP_TYPE.TORCH)
+```
+
+### Torch AMP Configuration
+
+```python
+from colossalai.amp import AMP_TYPE
+
+fp16=dict(
+ mode=AMP_TYPE.TORCH,
+
+ # below are default values for grad scaler
+ init_scale=2.**16,
+ growth_factor=2.0,
+ backoff_factor=0.5,
+ growth_interval=2000,
+ enabled=True
+)
+```
+
+With optional arguments:
+- init_scale(float, optional, default=2.**16): Initial scale factor
+- growth_factor(float, optional, default=2.0): Factor by which the scale is multiplied during `update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+- backoff_factor(float, optional, default=0.5): Factor by which the scale is multiplied during `update` if inf/NaN gradients occur in an iteration.
+- growth_interval(int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
+- enabled(bool, optional, default=True): If ``False``, disables gradient scaling. `step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+
+### Apex AMP Configuration
+
+For this mode, we rely on the Apex implementation for mixed precision training.
+We support this plugin because it allows for finer control on the granularity of mixed precision.
+For example, O2 level (optimization level 2) will keep batch normalization in fp32.
+
+If you look for more details, please refer to [Apex Documentation](https://nvidia.github.io/apex/).
+
+```python
+from colossalai.amp import AMP_TYPE
+
+fp16 = dict(
+ mode=AMP_TYPE.APEX,
+
+ # below are the default values
+ enabled=True,
+ opt_level='O1',
+ cast_model_type=None,
+ patch_torch_functions=None,
+ keep_batchnorm_fp32=None,
+ master_weights=None,
+ loss_scale=None,
+ cast_model_outputs=None,
+ num_losses=1,
+ verbosity=1,
+ min_loss_scale=None,
+ max_loss_scale=16777216.0
+)
+```
+
+Parameters:
+- enabled(bool, optional, default=True): If False, renders all AMP calls no-ops, so your script should run as if Amp were not present.
+
+- opt_level(str, optional, default="O1" ): Pure or mixed precision optimization level.
+Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above Apex AMP Documentation.
+
+- num_losses(int, optional, default=1): Option to tell AMP in advance how many losses/backward passes you plan to use.
+When used in conjunction with the loss_id argument to `amp.scale_loss`, enables Amp to use a different loss scale per
+loss/backward pass, which can improve stability. If num_losses is left to 1, Amp will still support multiple
+losses/backward passes, but use a single global loss scale for all of them.
+
+- verbosity(int, default=1): Set to 0 to suppress Amp-related output.
+
+- min_loss_scale(float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic loss scaling.
+The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.
+
+- max_loss_scale(float, default=2.**24 ): Sets a ceiling for the loss scale values that can be chosen by dynamic loss
+scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.
+
+Currently, the under-the-hood properties that govern pure or mixed precision training are the following:
+cast_model_type, patch_torch_functions, keep_batchnorm_fp32, master_weights, loss_scale.
+They are optional properties override once opt_level is determined
+
+- cast_model_type: Casts your model’s parameters and buffers to the desired type.
+- patch_torch_functions: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32.
+- keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it’s often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
+- master_weights: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients.
+- loss_scale: If loss_scale is a float value, use this value as the static (fixed) loss scale. If loss_scale is the string "dynamic", adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically.
+
+
+### Naive AMP Configuration
+
+In Naive AMP mode, we achieved mixed precision training while maintaining compatibility with complex tensor and pipeline parallelism.
+This AMP mode will cast all operations into fp16.
+The following code block shows the `config.py` file for this mode.
+
+```python
+from colossalai.amp import AMP_TYPE
+
+fp16 = dict(
+ mode=AMP_TYPE.NAIVE,
+
+ # below are the default values
+ log_num_zeros_in_grad=False,
+ initial_scale=2 ** 32,
+ min_scale=1,
+ growth_factor=2,
+ backoff_factor=0.5,
+ growth_interval=1000,
+ hysteresis=2
+)
+```
+
+The default parameters of Naive AMP:
+- log_num_zeros_in_grad(bool): return number of zeros in the gradients.
+- initial_scale(int): initial scale of gradient scaler
+- growth_factor(int): the growth rate of loss scale
+- backoff_factor(float): the decrease rate of loss scale
+- hysterisis(int): delay shift in dynamic loss scaling
+- max_scale(int): maximum loss scale allowed
+- verbose(bool): if set to `True`, will print debug info
+
+When using `colossalai.initialize`, you are required to first instantiate a model, an optimizer and a criterion.
+The output model is converted to AMP model of smaller memory consumption.
+If your input model is already too large to fit in a GPU, please instantiate your model weights in `dtype=torch.float16`.
+Otherwise, try smaller models or checkout more parallelization training techniques!
+
+
+## Hands-on Practice
+
+We provide a [runnable example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/amp) which demonstrates
+the use of AMP with Colossal-AI. In this practice, we will use Torch AMP as an example, but do note that config files are provided for all AMP modes.
+
+### Step 1. Create a config file
+
+Create a `config.py` and add the `fp16` configuration.
+
+```python
+# in config.py
+from colossalai.amp import AMP_TYPE
+
+BATCH_SIZE = 128
+DROP_RATE = 0.1
+NUM_EPOCHS = 300
+
+fp16 = dict(
+ mode=AMP_TYPE.TORCH,
+)
+
+clip_grad_norm = 1.0
+```
+
+### Step 2. Import libraries in train_with_engine.py
+
+Create a `train_with_engine.py` and import the necessary dependencies. Remember to install `scipy` and `timm` by running
+`pip install timm scipy`.
+
+```python
+import os
+import colossalai
+import torch
+from pathlib import Path
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.utils import get_dataloader
+from colossalai.trainer import Trainer, hooks
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from timm.models import vit_base_patch16_224
+from torchvision import datasets, transforms
+
+```
+
+### Step 3. Initialize Distributed Environment
+
+We then need to initialize distributed environment. For demo purpose, we uses `launch_from_torch`. You can refer to [Launch Colossal-AI](../basics/launch_colossalai.md)
+for other initialization methods.
+
+```python
+# initialize distributed setting
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+
+# launch from torch
+colossalai.launch_from_torch(config=args.config)
+
+```
+
+### Step 4. Create training components
+
+Build your model, optimizer, loss function, lr scheduler and dataloaders. Note that the root path of the dataset is
+obtained from the environment varialbe `DATA`. You may `export DATA=/path/to/data` or change `Path(os.environ['DATA'])`
+to a path on your machine. Data will be automatically downloaded to the root path.
+
+```python
+# build model
+ model = vit_base_patch16_224(drop_rate=0.1)
+
+ # build dataloader
+ train_dataset = datasets.Caltech101(
+ root=Path(os.environ['DATA']),
+ download=True,
+ transform=transforms.Compose([
+ transforms.Resize(256),
+ transforms.RandomResizedCrop(224),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ Gray2RGB(),
+ transforms.Normalize([0.5, 0.5, 0.5],
+ [0.5, 0.5, 0.5])
+ ]))
+
+ train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=gpc.config.BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+ # build optimizer
+ optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=0.1)
+
+ # build loss
+ criterion = torch.nn.CrossEntropyLoss()
+
+ # lr_scheduelr
+ lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)
+```
+
+### Step 5. Inject AMP Feature
+
+Call `colossalai.initialize` to convert the training components to be running with FP16.
+
+```python
+engine, train_dataloader, _, _ = colossalai.initialize(
+ model, optimizer, criterion, train_dataloader,
+ )
+```
+
+### Step 6. Train with Engine
+
+Use engine in a normal training loops.
+
+```python
+engine.train()
+for epoch in range(gpc.config.NUM_EPOCHS):
+ for img, label in enumerate(train_dataloader):
+ img = img.cuda()
+ label = label.cuda()
+ engine.zero_grad()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+ lr_scheduler.step()
+```
+
+### Step 7. Invoke Training Scripts
+
+Use the following command to start the training scripts. You can change `--nproc_per_node` to use a different number of GPUs.
+
+```python
+python -m torch.distributed.launch --nproc_per_node 4 --master_addr localhost --master_port 29500 train_with_engine.py --config config/config_AMP_torch.py
+```
diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md
new file mode 100644
index 000000000000..2933c3db6c58
--- /dev/null
+++ b/docs/source/en/features/nvme_offload.md
@@ -0,0 +1,263 @@
+# NVMe offload
+
+Author: Hongxin Liu
+
+**Prerequisite:**
+- [Zero Redundancy Optimizer with chunk-based memory management](../features/zero_with_chunk.md)
+
+**Related Paper**
+
+- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
+- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
+
+## Introduction
+
+If a model has `N` parameters, when using Adam, it has `8N` optimizer states. For billion-scale models, optimizer states take at least 32 GB memory. GPU memory limits the model scale we can train, which is called GPU memory wall. If we offload optimizer states to the disk, we can break through GPU memory wall.
+
+We implement a user-friendly and efficient asynchronous Tensor I/O library: [TensorNVMe](https://github.com/hpcaitech/TensorNVMe). With this library, we can simply implement NVMe offload.
+
+> This library is compatible with all kinds of disk (HDD, SATA SSD, and NVMe SSD). As I/O bandwidth of HDD or SATA SSD is low, it's recommended to use this lib only on NVMe disk.
+
+When optimizing a parameter, we can divide the optimization process into three stages: read, compute and offload. We perform the optimization process in a pipelined fashion, which can overlap computation and I/O.
+
+
+
+Optimization process
+
+
+## Usage
+
+First, please make sure you installed [TensorNVMe](https://github.com/hpcaitech/TensorNVMe):
+
+```shell
+pip install packaging
+pip install tensornvme
+```
+
+We implement NVMe offload of optimizer states for Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) and [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)).
+
+
+
+
+```python
+from colossalai.nn.optimizer import CPUAdam, HybridAdam
+
+optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./')
+```
+
+
+
+`nvme_offload_fraction` is the fraction of optimizer states to be offloaded to NVMe. `nvme_offload_dir` is the directory to save NVMe offload files. If `nvme_offload_dir` is `None`, a random temporary directory will be used.
+
+It's compatible with all parallel methods in ColossalAI.
+
+> ⚠ It only offloads optimizer states on CPU. This means it only affects CPU training or Zero/Gemini with offloading.
+
+## Exampls
+
+Let's start from two simple examples -- training GPT with different methods. These examples relies on `transformers`.
+
+We should install denpendencies first:
+
+```shell
+pip install psutil transformers
+```
+
+First, we import essential packages and modules:
+
+```python
+import os
+import time
+from typing import Dict, Optional
+
+import psutil
+import torch
+import torch.nn as nn
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.utils.model.colo_init_context import ColoInitContext
+```
+
+Then we define a loss function:
+
+```python
+class GPTLMLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1))
+```
+
+And we define some utility functions, which generates random data, computes the number of paramters of a model and get memory usage of current process:
+
+```python
+def get_data(batch_size: int, seq_len: int,
+ vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]:
+ device = torch.cuda.current_device() if device is None else device
+ input_ids = torch.randint(vocab_size, (batch_size, seq_len),
+ device=device)
+ attn_mask = torch.ones_like(input_ids)
+ return dict(input_ids=input_ids, attention_mask=attn_mask)
+
+
+def get_model_numel(model: nn.Module) -> int:
+ return sum(p.numel() for p in model.parameters())
+
+
+def get_mem_usage() -> int:
+ proc = psutil.Process(os.getpid())
+ return proc.memory_info().rss
+```
+
+We first try to train GPT model on CPU:
+
+```python
+def train_cpu(nvme_offload_fraction: float = 0.0):
+ config = GPT2Config()
+ model = GPT2LMHeadModel(config)
+ criterion = GPTLMLoss()
+ optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
+ print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')
+
+ start = time.time()
+ for step in range(3):
+ data = get_data(4, 128, config.vocab_size, device='cpu')
+ outputs = model(**data)
+ loss = criterion(outputs.logits, data['input_ids'])
+ loss.backward()
+ optimizer.step()
+ optimizer.zero_grad()
+ print(f'[{step}] loss: {loss.item():.3f}')
+
+ print(f'Time: {time.time() - start:.3f} s')
+ print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
+```
+
+Run without NVME offload:
+
+```python
+train_cpu(0.0)
+```
+
+We may get below output:
+
+```
+Model numel: 0.116 B
+[0] loss: 10.953
+[1] loss: 10.974
+[2] loss: 10.965
+Time: 7.739 s
+Mem usage: 5966.445 MB
+```
+
+And then run with (full) NVME offload:
+
+```python
+train_cpu(1.0)
+```
+
+We may get:
+
+```
+Model numel: 0.116 B
+[0] loss: 10.951
+[1] loss: 10.994
+[2] loss: 10.984
+Time: 8.527 s
+Mem usage: 4968.016 MB
+```
+
+For GPT2-S, which has 0.116 billion parameters, its optimizer states take about 0.928 GB memory. And NVME offload saves about 998 MB memory, which meets our expectations.
+
+Then we can train GPT model with Gemini. The placement policy of Gemini should be `"auto"`, `"cpu"` or `"const"`.
+
+```python
+def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
+ colossalai.launch_from_torch({})
+ config = GPT2Config()
+ with ColoInitContext(device=torch.cuda.current_device()):
+ model = GPT2LMHeadModel(config)
+ criterion = GPTLMLoss()
+ optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
+ print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')
+
+ gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
+ placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
+ model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
+ optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)
+
+ start = time.time()
+ for step in range(3):
+ data = get_data(4, 128, config.vocab_size)
+ outputs = model(**data)
+ loss = criterion(outputs.logits, data['input_ids'])
+ optimizer.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+ print(f'[{step}] loss: {loss.item():.3f}')
+
+ print(f'Time: {time.time() - start:.3f} s')
+ print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
+```
+
+Run without NVME offload:
+
+```python
+train_gemini_cpu(0.0)
+```
+
+We may get:
+
+```
+Model numel: 0.116 B
+searching chunk configuration is completed in 0.27 s.
+used number: 118.68 MB, wasted number: 0.75 MB
+total wasted percentage is 0.63%
+[0] loss: 10.953
+[1] loss: 10.938
+[2] loss: 10.969
+Time: 2.997 s
+Mem usage: 5592.227 MB
+```
+
+And run with (full) NVME offload:
+
+```python
+train_gemini_cpu(1.0)
+```
+
+We may get:
+
+```
+Model numel: 0.116 B
+searching chunk configuration is completed in 0.27 s.
+used number: 118.68 MB, wasted number: 0.75 MB
+total wasted percentage is 0.63%
+[0] loss: 10.953
+[1] loss: 10.938
+[2] loss: 10.969
+Time: 3.691 s
+Mem usage: 5298.344 MB
+```
+
+NVME offload saves about 294 MB memory. Note that enabling `pin_memory` of Gemini can accelerate training but increase memory usage. So this result also meets our expectation. If we disable `pin_memory`, we can aslo observe a memory usage drop about 900 MB.
+
+## API Reference
+
+{{ autodoc:colossalai.nn.optimizer.HybridAdam }}
+
+{{ autodoc:colossalai.nn.optimizer.CPUAdam }}
+
+
+
diff --git a/docs/source/en/features/pipeline_parallel.md b/docs/source/en/features/pipeline_parallel.md
new file mode 100644
index 000000000000..ac49863b3c71
--- /dev/null
+++ b/docs/source/en/features/pipeline_parallel.md
@@ -0,0 +1,159 @@
+# Pipeline Parallel
+
+Author: Guangyang Lu, Hongxin Liu, Yongbin Li
+
+**Prerequisite**
+- [Define Your Configuration](../basics/define_your_config.md)
+- [Use Engine and Trainer in Training](../basics/engine_trainer.md)
+- [Configure Parallelization](../basics/configure_parallelization.md)
+
+**Example Code**
+- [ColossalAI-Examples ResNet with pipeline](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/pipeline_parallel)
+
+**Related Paper**
+- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+
+## Quick introduction
+
+In this tutorial, you will learn how to use pipeline parallel. In Colossal-AI, we use 1F1B pipeline, introduced by Nvidia. In this case, ViT and Imagenet are too large to use. Therefore, here we use ResNet and Cifar as example.
+
+## Table Of Content
+
+In this tutorial we will cover:
+
+1. Introduction of 1F1B pipeline.
+2. Usage of non-interleaved and interleaved schedule.
+3. Training ResNet with pipeline.
+
+## Introduction of 1F1B pipeline
+
+First of all, we will introduce you GPipe for your better understanding.
+
+
+
+Figure1: GPipe. This figure is from Megatron-LM paper.
+
+
+
+As you can see, for GPipe, only when the forward passes of all microbatches in a batch finish, the backward passes would be executed.
+
+In general, 1F1B(one forward pass followed by one backward pass) is more efficient than GPipe(in memory or both memory and time). There are two schedules of 1F1B pipeline, the non-interleaved and the interleaved. The figures are shown below.
+
+
+
+Figure2: This figure is from Megatron-LM paper. The top part shows the default non-interleaved schedule. And the bottom part shows the interleaved schedule.
+
+
+### Non-interleaved Schedule
+
+The non-interleaved schedule can be divided into three stages. The first stage is the warm-up stage, where workers perform differing numbers of forward passes. At the following stage, workers perform one forward pass followed by one backward pass. Workers will finish backward passes at the last stage.
+
+This mode is more memory-efficient than GPipe. However, it would take the same time to finish a turn of passes as GPipe.
+
+### Interleaved Schedule
+
+This schedule requires **the number of microbatches to be an integer multiple of the stage of pipeline**.
+
+In this schedule, each device can perform computation for multiple subsets of layers(called a model chunk) instead of a single contiguous set of layers. i.e. Before device 1 had layer 1-4; device 2 had layer 5-8; and so on. But now device 1 has layer 1,2,9,10; device 2 has layer 3,4,11,12; and so on. With this scheme, each device in the pipeline is assigned multiple pipeline stages and each pipeline stage has less computation.
+
+This mode is both memory-efficient and time-efficient.
+
+## Usage of non-interleaved and interleaved schedule
+
+In Colossal-AI, we provided both non-interleaved(as `PipelineSchedule`) and interleaved schedule(as `InterleavedPipelineSchedule`).
+
+You just need to set `NUM_MICRO_BATCHES` in config file and set `NUM_CHUNKS` in config file if you want to use Interleaved Pipeline Schedule. If you certainly know the shape of each pipeline stage's output tensor and the shapes are all the same, you can set `TENSOR_SHAPE` in config file to further reduce communication. Otherwise, you can just ignore `tensor_shape`, and the shape will be exchanged over pipeline stages automatically. Then we will generate an appropriate schedule for you.
+
+## Training ResNet with pipeline
+
+Let's build the `ResNet` model first with Colossal PipelinableContext:
+```python
+import os
+from typing import Callable, List, Optional, Type, Union
+import torch
+import torch.nn as nn
+import colossalai
+import colossalai.nn as col_nn
+
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import MultiTimer, get_dataloader
+from colossalai.context import ParallelMode
+from colossalai.pipeline.pipelinable import PipelinableContext
+
+from titans.dataloader.cifar10 import build_cifar
+from torchvision.models import resnet50
+from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1
+
+# Define some config
+BATCH_SIZE = 64
+NUM_EPOCHS = 2
+NUM_CHUNKS = 1
+CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
+
+# Train
+disable_existing_loggers()
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
+logger = get_dist_logger()
+pipelinable = PipelinableContext()
+
+# build model
+with pipelinable:
+ model = resnet50()
+```
+
+Define an execution sequence.
+```python
+exec_seq = [
+ 'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool',
+ (lambda x: torch.flatten(x, 1), "behind"), 'fc'
+]
+pipelinable.to_layer_list(exec_seq)
+```
+
+Partition the model into pipeline.
+```python
+model = pipelinable.partition(NUM_CHUNKS, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+```
+
+In this tutorial, we use `Trainer` to train `ResNet`:
+```python
+# build criterion
+criterion = nn.CrossEntropyLoss()
+
+# optimizer
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+# build dataloader
+root = os.environ.get('DATA', './data')
+train_dataloader, test_dataloader = build_cifar(BATCH_SIZE, root, padding=4, crop=32, resize=32)
+
+lr_scheduler = col_nn.lr_scheduler.LinearWarmupLR(optimizer, NUM_EPOCHS, warmup_steps=1)
+engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(model, optimizer, criterion,
+ train_dataloader, test_dataloader,
+ lr_scheduler)
+timer = MultiTimer()
+
+trainer = Trainer(engine=engine, timer=timer, logger=logger)
+
+hook_list = [
+ hooks.LossHook(),
+ hooks.AccuracyHook(col_nn.metric.Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.LRSchedulerHook(lr_scheduler, by_epoch=True)
+]
+
+trainer.fit(train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True)
+```
+
+We use `2` pipeline stages and the batch will be splitted into `4` micro batches.
diff --git a/docs/source/en/features/zero_with_chunk.md b/docs/source/en/features/zero_with_chunk.md
new file mode 100644
index 000000000000..6b0a9585af85
--- /dev/null
+++ b/docs/source/en/features/zero_with_chunk.md
@@ -0,0 +1,265 @@
+# Zero Redundancy Optimizer with chunk-based memory management
+
+Author: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
+
+**Prerequisite:**
+- [Define Your Configuration](../basics/define_your_config.md)
+
+**Example Code**
+
+- [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt)
+
+**Related Paper**
+
+- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
+- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
+- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
+- [DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters](https://dl.acm.org/doi/10.1145/3394486.3406703)
+- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818)
+
+## Introduction
+
+The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning three
+model states (optimizer states, gradients, and parameters) instead of replicating them.
+By doing so, memory efficiency is boosted drastically compared to classic data parallelism, while the computational granularity
+and communication efficiency is retained.
+
+1. **Shard Optimizer States**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights,
+and the first and second momentum estimates) are partitioned across the processes, so that each process updates only its partition.
+
+
+2. **Shard Gradient**: After reduction inside data parallel process group, gradient tensors are also partitioned such that each process only stores the gradients corresponding to its partition of the optimizer states. Note, Colossal converts gradient into fp32 format to participate in parameter updating.
+
+3. **Shard Parameter**: The 16-bit model parameters are partitioned across the processes of a data parallel group.
+
+4. **[Gemini](../advanced_tutorials/meet_gemini.md)**: Dynamic heterogeneous memory space manager for paramters, gradients and optimizer states.
+
+Besides, this article will introduce the Zero Redundancy Optimizer with chunk-based memory management.
+
+When using ZeRO, we distributed the model by sharding the parameters. The advantage of this method is that the memory of each node is load balanced. But this approach has two significiant disadvantages. First, during communication, a temporary memory buffer needs to be allocated and released afterwards, leading to the memory fragmentation problem. Secondly, using tensor as the granularity for communication will cause the network bandwidth underutilized. Generally, the longer the transmitted message length, the higher the bandwidth utilization.
+
+Using the Chunk mechanism introduced in ColossalAI v0.1.8, we can improve the efficiency of ZeRO. We store a continuous set of parameters in initialization order into a Chunk (a chunk is a continuous memory space), and each Chunk has the same size. Organizing memory in chunks can lead to efficient use of network bandwidth between PCI-e and GPU-GPU, reduce the number of communications, and avoid potential memory fragmentation.
+
+Before v0.1.8, ZeRO had a high communication cost for parameter communications. If a parameter was used multiple times in several consecutive operators, there will be repeated communications operations, and the efficiency was highly damaged. This situation is very common when using the Gradient Checkpoint technique, and the parameter will recompute the forward propagation during backward propagation.
+
+Taking GPT as an example, its Checkpoint will be applied to each GPT Block, and each GPT Block contains a Self-Attention layer and an MLP layer. During the backward pass, the forward of the Self-Attention layer and the MLP layer will be computed in turn, and then the backward of the MLP layer and the Self-Attention layer will be computed in turn.
+
+In addition, due to the communication and memory movement of small Tensors, the bandwidth of NVLINK and PCI-E cannot be fully utilized, and each communication and memory movement has the overhead of kernel launch. After using Chunk, multiple small Tensor communication and memory movement can be changed into one large Tensor communication and memory movement, which not only improves bandwidth utilization but also reduces the overhead of kernel launch.
+
+We also provide a lightweight chunk search mechanism to help users automatically find the chunk size with the smallest memory fragmentation.
+
+## Usage
+
+### GeminiDDP
+
+We will use `GeminiDDP` to use ZeRO with chunk-based memory management. This is our new torch.Module wrapper which uses ZeRO-DP and Gemini. ZeRO is for parallelism and Gemini is for memory management.
+
+Also Make sure that your model is initialized under the context of ColoInitContext.
+
+```python
+with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
+ model = gpt2_medium(checkpoint=True)
+```
+
+Define the model parameters as follows:
+
+```python
+chunk_manager = init_chunk_manager(model=module,
+ init_device=device,
+ hidden_dim=hidden_dim,
+ search_range_mb=search_range_mb,
+ min_chunk_size_mb=min_chunk_size_mb)
+gemini_manager = GeminiManager(placement_policy, chunk_manager)
+```
+
+`hidden_dim` is the hidden dimension of DNN. Users can provide this argument to speed up searching. If users do not know this argument before training, it is ok. We will use a default value 1024. `min_chunk_size_mb` is the the minimum chunk size in MegaByte. If the aggregate size of parameters is still samller than the minimum chunk size, all parameters will be compacted into one small chunk.
+
+Initialization of the optimizer.
+```python
+optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+```
+
+Training
+```python
+optimizer.zero_grad()
+outputs = model(input_ids, attn_mask)
+loss = criterion(outputs, input_ids)
+optimizer.backward(loss)
+optimizer.step()
+```
+> ⚠️ Note: Please do not use `loss.backward()`, the standard way of writing is `optimizer.backward(loss)`.
+
+### Train GPT
+
+In this example, we use `Hugging Face Transformers`. You have to install `transformers` before running this example. We will take `GPT2 Medium` as an example here.
+
+For simplicity, we just use randomly generated data here.
+
+First we only need to import `GPT2LMHeadModel` from `Huggingface transformers` to define our model, which does not require users to define or modify the model, so that users can use it more conveniently.
+
+```python
+class GPTLMModel(nn.Module):
+
+ def __init__(self,
+ hidden_size=768,
+ num_layers=12,
+ num_attention_heads=12,
+ max_seq_len=1024,
+ vocab_size=50257,
+ checkpoint=False):
+ super().__init__()
+ self.checkpoint = checkpoint
+ self.model = GPT2LMHeadModel(
+ GPT2Config(n_embd=hidden_size,
+ n_layer=num_layers,
+ n_head=num_attention_heads,
+ n_positions=max_seq_len,
+ n_ctx=max_seq_len,
+ vocab_size=vocab_size))
+ if checkpoint:
+ self.model.gradient_checkpointing_enable()
+
+ def forward(self, input_ids, attention_mask):
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+def gpt2_medium(checkpoint=False):
+ return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+```
+
+Define our loss function:
+
+```python
+class GPTLMLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+```
+
+Define tensor parallel and parameter sharding strategies for tensor parallelism:
+
+```python
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+ for mn, module in model.named_modules():
+ for pn, param in module.named_parameters(recurse=False):
+ if hasattr(param, 'visited'):
+ continue
+ param.set_dist_spec(ReplicaSpec())
+ if 'mlp.c_fc' in mn:
+ if 'weight' in pn or 'bias' in pn:
+ split_param_col_tp1d(param, pg)
+ param.compute_spec.set_output_replicate(False)
+ else:
+ param.set_dist_spec(ReplicaSpec())
+ elif 'mlp.c_proj' in mn:
+ if 'weight' in pn:
+ split_param_row_tp1d(param, pg)
+ else:
+ param.set_dist_spec(ReplicaSpec())
+ elif 'wte' in mn or 'wpe' in mn:
+ split_param_col_tp1d(param, pg)
+ elif 'c_attn' in mn or 'c_proj' in mn:
+ split_param_col_tp1d(param, pg)
+ else:
+ param.set_dist_spec(ReplicaSpec())
+
+ param.visited = True
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+ spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(-1, param, pg)
+```
+
+Define a model which uses Gemini + ZeRO DDP:
+
+```python
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+ cai_version = colossalai.__version__
+ if version.parse(cai_version) > version.parse("0.1.10"):
+ from colossalai.nn.parallel import GeminiDDP
+ model = GeminiDDP(model,
+ device=get_current_device(),
+ placement_policy=placememt_policy,
+ pin_memory=True,
+ search_range_mb=32)
+ elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
+ from colossalai.gemini import ChunkManager, GeminiManager
+ chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
+ gemini_manager = GeminiManager(placememt_policy, chunk_manager)
+ chunk_manager = ChunkManager(chunk_size,
+ pg,
+ enable_distributed_storage=True,
+ init_device=GeminiManager.get_default_device(placememt_policy))
+ model = ZeroDDP(model, gemini_manager)
+ else:
+ raise NotImplemented(f"CAI version {cai_version} is not supported")
+ return model
+```
+
+As we pre-train GPT in this example, we just use a simple language model loss.
+
+Write a function to get random inputs:
+
+```python
+def get_data(batch_size, seq_len, vocab_size):
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
+ attention_mask = torch.ones_like(input_ids)
+ return input_ids, attention_mask
+```
+
+Finally, we can define our training loop:
+
+```python
+def main():
+ args = parse_args()
+ BATCH_SIZE = 8
+ SEQ_LEN = 1024
+ VOCAB_SIZE = 50257
+ NUM_STEPS = 10
+ colossalai.launch_from_torch(config={})
+
+ # build criterion
+ criterion = GPTLMLoss()
+
+ torch.manual_seed(123)
+ default_pg = ProcessGroup(tp_degree=args.tp_degree)
+ default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+ # build GPT model
+ with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
+ model = gpt2_medium(checkpoint=True)
+ pg = default_pg
+ # Tensor Parallelism (TP)
+ tensor_parallelize(model, pg)
+ # Gemini + ZeRO DP, Note it must be used after TP
+ model = gemini_zero_dpp(model, pg, args.placement)
+ # build optimizer
+ optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
+ numel = sum([p.numel() for p in model.parameters()])
+ get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+ torch.cuda.synchronize()
+ model.train()
+ for n in range(NUM_STEPS):
+ # we just use randomly generated data here
+ input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
+ optimizer.zero_grad()
+ outputs = model(input_ids, attn_mask)
+ loss = criterion(outputs, input_ids)
+ optimizer.backward(loss)
+ optimizer.step()
+
+ torch.cuda.synchronize()
+```
+> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation.md) we mentioned before。
+The complete example can be found on [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt).
diff --git a/docs/source/en/get_started/installation.md b/docs/source/en/get_started/installation.md
new file mode 100644
index 000000000000..672fd8ae03a4
--- /dev/null
+++ b/docs/source/en/get_started/installation.md
@@ -0,0 +1,50 @@
+# Setup
+
+Requirements:
+- PyTorch >= 1.11 (PyTorch 2.x in progress)
+- Python >= 3.7
+- CUDA >= 11.0
+
+If you encounter any problem about installation, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository.
+
+
+## Download From PyPI
+
+You can install Colossal-AI with
+
+```shell
+pip install colossalai
+```
+
+**Note: only Linux is supported for now**
+
+If you want to build PyTorch extensions during installation, you can use the command below. Otherwise, the PyTorch extensions will be built during runtime.
+
+```shell
+CUDA_EXT=1 pip install colossalai
+```
+
+
+## Download From Source
+
+> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :)
+
+```shell
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+
+# install dependency
+pip install -r requirements/requirements.txt
+
+# install colossalai
+pip install .
+```
+
+If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
+
+```shell
+CUDA_EXT=1 pip install .
+```
+
+
+
diff --git a/docs/source/en/get_started/reading_roadmap.md b/docs/source/en/get_started/reading_roadmap.md
new file mode 100644
index 000000000000..476c524ac011
--- /dev/null
+++ b/docs/source/en/get_started/reading_roadmap.md
@@ -0,0 +1,19 @@
+# Reading Roadmap
+
+Colossal-AI provides a collection of parallel training components for you. We aim to support you with your development
+of distributed deep learning models just like how you write single-GPU deep learning models. ColossalAI provides easy-to-use
+APIs to help you kickstart your training process. To better how ColossalAI works, we recommend you to read this documentation
+in the following order.
+
+- If you are not familiar with distributed system or have never used Colossal-AI, you should first jump into the `Concepts`
+section to get a sense of what we are trying to achieve. This section can provide you with some background knowledge on
+distributed training as well.
+- Next, you can follow the `basics` tutorials. This section will cover the details about how to use Colossal-AI.
+- Afterwards, you can try out the features provided in Colossal-AI by reading `features` section. We will provide a codebase for each tutorial. These tutorials will cover the
+basic usage of Colossal-AI to realize simple functions such as data parallel and mixed precision training.
+- Lastly, if you wish to apply more complicated techniques such as how to run hybrid parallel on GPT-3, the
+`advanced tutorials` section is the place to go!
+
+**We always welcome suggestions and discussions from the community, and we would be more than willing to help you if you
+encounter any issue. You can raise an [issue](https://github.com/hpcaitech/ColossalAI/issues) here or create a discussion
+topic in the [forum](https://github.com/hpcaitech/ColossalAI/discussions).**
diff --git a/docs/source/en/get_started/run_demo.md b/docs/source/en/get_started/run_demo.md
new file mode 100644
index 000000000000..f47bdbbd62fc
--- /dev/null
+++ b/docs/source/en/get_started/run_demo.md
@@ -0,0 +1,43 @@
+# Quick Demo
+
+Colossal-AI is an integrated large-scale deep learning system with efficient parallelization techniques. The system can
+accelerate model training on distributed systems with multiple GPUs by applying parallelization techniques. The system
+can also run on systems with only one GPU. Quick demos showing how to use Colossal-AI are given below.
+
+## Single GPU
+
+Colossal-AI can be used to train deep learning models on systems with only one GPU and achieve baseline
+performances. We provided an example to [train ResNet on CIFAR10 dataset](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet)
+with only one GPU. You can find the example in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples).
+Detailed instructions can be found in its `README.md`.
+
+## Multiple GPUs
+
+Colossal-AI can be used to train deep learning models on distributed systems with multiple GPUs and accelerate the
+training process drastically by applying efficient parallelization techniques. When we have several parallelism for you
+to try out.
+
+#### 1. data parallel
+
+You can use the same [ResNet example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/resnet) as the
+single-GPU demo above. By setting `--nproc_per_node` to be the number of GPUs you have on your machine, the example
+is turned into a data parallel example.
+
+#### 2. hybrid parallel
+
+Hybrid parallel includes data, tensor, and pipeline parallelism. In Colossal-AI, we support different types of tensor
+parallelism (i.e. 1D, 2D, 2.5D and 3D). You can switch between different tensor parallelism by simply changing the configuration
+in the `config.py`. You can follow the [GPT example](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt).
+Detailed instructions can be found in its `README.md`.
+
+#### 3. MoE parallel
+
+We provided [an example of WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) to demonstrate
+MoE parallelism. WideNet uses mixture of experts (MoE) to achieve better performance. More details can be found in
+[Tutorial: Integrate Mixture-of-Experts Into Your Model](../advanced_tutorials/integrate_mixture_of_experts_into_your_model.md)
+
+#### 4. sequence parallel
+
+Sequence parallel is designed to tackle memory efficiency and sequence length limit problems in NLP tasks. We provided
+[an example of BERT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel) in
+[ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples). You can follow the `README.md` to execute the code.
diff --git a/docs/source/en/sidebar_category_translation.json b/docs/source/en/sidebar_category_translation.json
new file mode 100644
index 000000000000..9cc320424e40
--- /dev/null
+++ b/docs/source/en/sidebar_category_translation.json
@@ -0,0 +1,26 @@
+{
+ "sidebar.tutorialSidebar.category.Get started": {
+ "message": "Get started",
+ "description": "The label for category Get started in sidebar tutorialSidebar"
+ },
+ "sidebar.tutorialSidebar.category.Concepts": {
+ "message": "Concepts",
+ "description": "The label for category Concepts in sidebar tutorialSidebar"
+ },
+ "sidebar.tutorialSidebar.category.Basics": {
+ "message": "Basics",
+ "description": "The label for category Basics in sidebar tutorialSidebar"
+ },
+ "sidebar.tutorialSidebar.category.Features": {
+ "message": "Features",
+ "description": "The label for category Features in sidebar tutorialSidebar"
+ },
+ "sidebar.tutorialSidebar.category.Tensor Parallel": {
+ "message": "Tensor Parallel",
+ "description": "The label for category Tensor Parallel in sidebar tutorialSidebar"
+ },
+ "sidebar.tutorialSidebar.category.Advanced Tutorials": {
+ "message": "Advanced Tutorials",
+ "description": "The label for category Advanced Tutorials in sidebar tutorialSidebar"
+ }
+}
diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/auto_checkpoint.md b/docs/source/zh-Hans/Colossal-Auto/feature/auto_checkpoint.md
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/device_mesh.md b/docs/source/zh-Hans/Colossal-Auto/feature/device_mesh.md
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md b/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md
new file mode 100644
index 000000000000..71bce57ea91b
--- /dev/null
+++ b/docs/source/zh-Hans/Colossal-Auto/feature/layout_converting_management.md
@@ -0,0 +1,12 @@
+当一个张量在上下游算子中被要求的sharding spec不同时,我们需要进行分布转换处理(Layout Conversion)。目前主流的方式有两种,打表转换和逐维度转换。打表转换就是将所有可能的情况枚举出来,然后在遇到需要转换的情况下,去表格中找到对应的转换方案。
+为了解决这个问题,我们提出一个新奇的想法,使用启发式的搜索,来解决sharding spec的转换问题。
+然而它有一个很大问题,就是随着设备块(Device Mesh)的维度增加,这个问题的规模极具膨胀,以至于无法通过这种枚举打表的方式来解决。逐维度转换是对于一个N-d tensor的sharding spec,X0X1...Xn-1,我们让i从0到n-1逐维度地进行转换,这样不管设备块和张量的维度多少,我们都只需要一次扫描,就可以得到一个可行的转换操作序列,然而它问题是这样的转换效率会很差。为了解决这个问题,我们提出一个新奇的想法,使用启发式算法,来解决sharding spec的转换问题。,这个算法可以描述为:
+ 1. 从source spec生成所有的one-step transform sharding specs
+ 2. 在one-step transform sharding specs中,根据相似度函数,挑选一个”区别最小“的sharding spec作为后续的source sharding spec,并将该sharding spec记录在transform path中,如果one-step transform sharding spec中,有与target sharding spec相同的sharding spec,则算法结束。
+ 3. 重复a,b直到算法结束
+
+| Source/target sharding spec pairs |All gather | Shard | All to All | One step transform | Best sharding spec |Transform path|
+| :-: | :-: | :-: | :-: | :-: | :-: |:-: |
+| $S_{01}RR, RS_{01}R$ | $S_0RR$ | - | $S_0RS_1, S_0S_1R$ | $S_0RR, S_0RS_1, S_0S_1R$ | $S_0RR$ | $S_0RR$
+| $S_0RR, RS_{01}RR$ | $RRR$ | $S_0S_1R, S_0RS_1$ | $RS_0R, RRS_0$ | $RRR$, $S_0S_1R$, $S_0RS_1$, $RS_0R$, $RRS_0$ | $RS_0R$ | $S_0RR$ -> $RS_0R$
+| $RS_0R, RS_{01}RR$ | $RRR$ | $RS_{01}R, S_1S_0R, RS_0S_1$ | $S_0RR, RRS_0$ | $RRR$, $RS_{01}R$, $S_1S_0R$, $RS_0S_1$, $S_0RR$, $RRS_0$ | $RS_{01}R$ | $S_0RR$ -> $RS_0R$ -> $RS_{01}R$
diff --git a/docs/source/zh-Hans/Colossal-Auto/feature/tracer.md b/docs/source/zh-Hans/Colossal-Auto/feature/tracer.md
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md b/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md
new file mode 100644
index 000000000000..054b709c92d0
--- /dev/null
+++ b/docs/source/zh-Hans/Colossal-Auto/get_started/installation.md
@@ -0,0 +1,28 @@
+# 安装
+
+## 声明
+
+我们的自动并行功能处于alpha版本,仍在快速的开发迭代中。我们会在兼容性和稳定性上做持续地改进。如果您遇到任何问题,欢迎随时提issue给我们。
+
+
+## 要求
+
+我们需要一些额外的依赖性来支持自动并行功能。 请在使用自动平行之前安装它们。
+
+### 安装PyTorch
+
+我们仅支持Pytorch 1.12,现在未测试其他版本。 将来我们将支持更多版本。
+
+```bash
+#conda
+conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
+#pip
+pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+### 安装pulp和coin-or-cbc
+
+```bash
+pip install pulp
+conda install -c conda-forge coin-or-cbc
+```
diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md b/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md
new file mode 100644
index 000000000000..bd5326d43220
--- /dev/null
+++ b/docs/source/zh-Hans/Colossal-Auto/get_started/introduction.md
@@ -0,0 +1,41 @@
+# 介绍
+
+近年来,大规模机器学习模型的部署受到越来越多的重视。然而,目前常见的分布式大模型训练方案,都依赖用户**人工反复尝试**和系统专家的经验来进行配置部署。这对绝大多数AI开发者来说十分不友好,因为他们不希望将时间精力花费在研究分布式系统和试错上。
+Colossal-AI的**Colossal-Auto** 帮助AI开发者简化了大规模机器学习模型的部署过程。相比现有其他手动配置复杂并行策略和修改模型的解决方案,Colossal-Auto 仅需增加一行代码,提供 cluster 信息以及单机训练模型即可获得分布式训练能力,并且**原生支持包括 Hugging Face,Timm 等热门 AI 模型库**。
+
+
+
+## 概览
+
+
+
+
+
+## 用法
+```python
+# wrap the model using auto_engine
+model = autoparallelize(model, meta_input_samples)
+# normal training loop
+...
+```
+
+
+## 图追踪
+Colossal-Auto 是**首个基于 PyTorch 框架使用静态图分析的自动并行系统**。PyTorch 作为一个动态图框架,获取其静态的执行计划是机器学习系统领域被长期研究的问题。Colossal-Auto 使用基于 torch.FX Tracer 的 ColoTracer 来完成对于最优并行策略的搜索。在 tracing 过程中推导并记录了每个 tensor 的元信息,例如 tensor shape,dims,dtype 等。因此 Colossal-AI 具有更好的模型泛化能力,而不是依靠模型名或手动修改来适配并行策略。
+
+
+## 细粒度分布式训练策略搜索
+
+我们调研了很多现有的自动并行系统( Tofu , Flexflow , Alpa ),以及自动激活值检查点算法( Rotor , Sublinear ),在他们的启发下,我们开发一个基于PyTorch框架的自动并行系统Colossal-Auto。Colossal-Auto会在满足内存预算的限制下,以最快运行时间为目标,为每个 op 进行策略搜索,最终得到真实训练时的策略,包括每个 tensor 的切分策略,不同计算节点间需要插入的通信算子类型,是否要进行算子替换等。现有系统中的张量并行,数据并行,NVIDIA 在 Megatron-LM 等并行系统中使用的 column 切分和 row 切分并行等混合并行,都是自动并行可以搜索到的策略的子集。除了这些可以手动指定的并行方式外,Colossal-AI 有能力为每个 op 指定独特的并行方式,因此有可能找到比依赖专家经验和试错配置的手动切分更好的并行策略。
+
+
+
+## 分布式 tensor 与 shape consistency 系统
+
+与 PyTorch 最新发布的 DTensor 类似,Colossal-AI 也使用了 device mesh 对集群进行了抽象管理。具体来说,Colossal-AI 使用 sharding spec 对 tensor 的分布式存储状态进行标注,使用 shape consistency manager 自动地对同一 tensor 在不同 sharding spec 间进行转换。这让 Colossal-AI 的通用性和易用性极大地提升,借助 shape consistency manager 可以没有负担地切分 tensor,而不用担心上游 op 的 output 与下游的 input 在集群中的存储方式不同。
+
+
+相较于 PyTorch DTensor,Colossal-AI 有以下优势:
++ Colossal-AI 的 device mesh 可以 profiling 到集群性能指标,对不同的通信算子进行耗时估算。
++ Colossal-AI 的 shape consistency 会贪心地搜索 sharding spec 间的转换方式,而不是朴素地逐 dimension 进行转换,这样能找到更高效的转换路径,进而使得 sharding spec 间的转换通信开销更小。
++ 加入了 all_to_all 操作,使得 Colossal-AI 的扩展性更强,这在大规模集群上进行训练时,可以展现出很大的优势。
diff --git a/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md b/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md
new file mode 100644
index 000000000000..19316e12b4d5
--- /dev/null
+++ b/docs/source/zh-Hans/Colossal-Auto/get_started/run_demo.md
@@ -0,0 +1,12 @@
+# 快速上手
+
+Colossal-AI 提供了业界急需的一套高效易用自动并行系统。相比现有其他手动配置复杂并行策略和修改模型的解决方案,Colossal-AI 仅需增加一行代码,提供 cluster 信息以及单机训练模型即可获得分布式训练能力。Colossal-Auto的快速上手示例如下。
+
+### 1. 基本用法
+Colossal-Auto 可被用于为每一次操作寻找一个包含数据、张量(如1D、2D、序列化)的混合SPMD并行策略。您可参考[GPT 示例](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt/experiments/auto_parallel)。
+详细的操作指引见其 `README.md`。
+
+### 2. 与 activation checkpoint 结合
+
+作为大模型训练中必不可少的显存压缩技术,Colossal-AI 也提供了对于 activation checkpoint 的自动搜索功能。相比于大部分将最大显存压缩作为目标的技术方案,Colossal-AI 的搜索目标是在显存预算以内,找到最快的 activation checkpoint 方案。同时,为了避免将 activation checkpoint 的搜索一起建模到 SPMD solver 中导致搜索时间爆炸,Colossal-AI 做了 2-stage search 的设计,因此可以在合理的时间内搜索到有效可行的分布式训练方案。 您可参考 [Resnet 示例](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel)。
+详细的操作指引见其 `README.md`。
diff --git a/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
new file mode 100644
index 000000000000..4825a6fa1d6c
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/add_your_parallel.md
@@ -0,0 +1,112 @@
+# 添加你自己的并行模式
+
+作者: Shenggui Li, Yongbin Li
+
+**前置教程**
+- [定义配置文件](../basics/define_your_config.md)
+- [并行配置](../basics/configure_parallelization.md)
+
+## 引言
+
+为了使研究人员和工程师能够以更少的努力将我们的系统扩展到其他新颖的大规模分布式训练算法,我们已经将训练生命周期中的各种组件解耦。你可以通过简单地继承基类来实现你自己的并行模式。
+
+主要组件有:
+
+1. `ProcessGroupInitializer`
+2. `GradientHandler`
+3. `Schedule`
+
+**目前这需要对源代码进行一些改动,因此我们建议你用`-e`标志从源代码安装。`-e`标志使得安装是可编辑的,因此,你的代码变化将反映在你的Python运行时中。我们将在这方面努力,以避免在未来的版本中改变源代码。**
+
+
+## 进程组初始化器
+
+并行通常由进程组来管理,参与相同并行算法的进程被置于同一进程组。对于不同的并行算法,需要创建不同的进程组。
+Colossal-AI 为用户提供了一个全局 context,使他们能够轻松地管理进程组。如果你想添加新的进程组,你可以很容易地定义一个新的类并在你的配置文件中设置它。为了定义你自己的进程组创建方式,你可以按照下面的步骤来创建一个新的分布式初始化。
+
+1. 在 `colossalai.context.parallel_mode.ParallelMode` 中添加你自己的并行模式。
+ ```python
+ class ParallelMode(Enum):
+ GLOBAL = 'global'
+ DATA = 'data'
+ PIPELINE = 'pipe'
+ ...
+
+ NEW_MODE = 'new_mode' # define your mode here
+ ```
+
+2. 创建一个 `ProcessGroupInitializer`。 你可以参考 `colossalai.context.dist_group_initializer` 中给出的例子,前六个参数是固定的。
+`ParallelContext` 将为你传入这些参数。如果你需要设置其他参数,可以像下面的例子中的 `arg1, arg2` 一样,在后面添加它。
+最后,通过添加装饰器 `@DIST_GROUP_INITIALIZER.register_module` 将你的初始化程序注册到注册表。
+ ```python
+ # sample initializer class
+ @DIST_GROUP_INITIALIZER.register_module
+ class MyParallelInitializer(ProcessGroupInitializer):
+
+ def __init__(self,
+ rank: int,
+ world_size: int,
+ config: Config,
+ data_parallel_size: int,
+ pipeline_parlalel_size: int,
+ tensor_parallel_size: int,
+ arg1,
+ arg2):
+ super().__init__(rank, world_size, config)
+ self.arg1 = arg1
+ self.arg2 = arg2
+ # ... your variable init
+
+ def init_parallel_groups(self):
+ # initialize your process groups
+ pass
+
+ ```
+ 然后,你可以将你的新初始化器插入到 `colossalai.constants.INITIALIZER_MAPPING` 当前的模式与初始化映射中。你可以修改该文件或动态插入新的键值对。
+
+ ```python
+ colossalai.constants.INITIALIZER_MAPPING['new_mode'] = 'MyParallelInitializer'
+ ```
+
+3. 在你的配置文件中设置你的初始化器。你可以传入你的自定义参数。这允许
+ `ParallelContext` 创建你的初始化器并初始化你期望的进程组。
+
+ ```python
+ parallel = dict(
+ pipeline=dict(size=1),
+ tensor=dict(size=x, mode='new_mode') # this is where you enable your new parallel mode
+ )
+ ```
+
+## 梯度 Handler
+
+梯度 handler 是对参数的梯度执行 all-reduce 操作的对象。由于不同的 all-reduce 策略或许在不同的并行中被执行,用户可以继承
+`colossalai.engine.gradient_handler.BaseGradientHandler` 来实现其策略。目前,Colossal-AI 使用普通的数据并行梯度 handler 在数据并行的 rank 间 all-reduce 梯度。
+如果数据并行被检测到,梯度 handler 会被自动添加进 engine。
+
+你可以添加你自己的梯度 handler,如下所示:
+
+```python
+from colossalai.registry import GRADIENT_HANDLER
+from colossalai.engine import BaseGradientHandler
+
+@GRADIENT_HANDLER.register_module
+class YourGradientHandler(BaseGradientHandler):
+
+ def handle_gradient(self):
+ do_something()
+
+```
+
+之后,你可以在配置文件中指定你要使用的梯度 handler。
+
+```python
+gradient_handlers = [
+ dict(type='YourGradientHandler'),
+]
+```
+
+## Schedule
+
+Schedule 包含了如何执行前向和后向计算。目前, Colossal-AI 提供了流水和非流水的 schedule。
+如果你想修改前向和后向计算的执行方式,你可以继承 `colossalai.engine.schedule.BaseSchedule` 并实现 `forward_back_step` 函数。
diff --git a/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md
new file mode 100644
index 000000000000..64e8d8bcd14a
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/define_your_own_parallel_model.md
@@ -0,0 +1,31 @@
+# 定义你自己的并行模型
+
+作者: Zhengda Bian, Yongbin Li
+
+> ⚠️ 我们正在编写此文档以使其更加详细。 我们将介绍不同并行的机制以及如何使用它们来编写模型。
+
+假设您有一个具有数十亿参数的巨大 MLP 模型,其极大的隐藏层大小使其无法直接被单个 GPU 容纳。别担心,Colossal-AI 可以帮你解决这个问题。
+在 Colossal-AI 的帮助下,您可以用所熟悉的为单个 GPU 编写模型的方式编写大模型,而 Colossal-AI 会自动拆分您的模型权重,并将它们完美地分配到一组 GPU 中。我们给出一个简单的示例,展示如何在 Colossal-AI 中编写简单的 2D 并行模型。
+
+## 写一个简单的2D并行模型
+
+```python
+from colossalai.nn import Linear2D
+import torch.nn as nn
+
+class MLP_2D(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.linear_1 = Linear2D(in_features=1024, out_features=16384)
+ self.linear_2 = Linear2D(in_features=16384, out_features=1024)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+ return x
+```
+
+## 使用预定义的模型
+
+为了方便您的使用,我们在 Colossal-AI 的 Model Zoo 中提供一些流行的模型,如*BERT*, *ViT*, *MoE* 和 *GPT*,请自由地将它们定制为不同的尺寸,以满足您的特殊需求。
diff --git a/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md
new file mode 100644
index 000000000000..456878caa147
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/integrate_mixture_of_experts_into_your_model.md
@@ -0,0 +1,140 @@
+# 将 MoE 整合进你的模型
+
+作者: Haichen Huang, Yongbin Li
+
+**前置教程**
+- [ColossalAI-Examples WideNet](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet)
+
+**相关论文**
+- [Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity](https://arxiv.org/abs/2101.03961)
+- [Go Wider Instead of Deeper](https://arxiv.org/abs/2107.11817)
+
+(中文版教程将会在近期提供)
+
+## Introduction
+
+Since the advent of Switch Transformer, the AI community has found Mixture of Experts (MoE) a useful technique to enlarge the capacity of deep learning models.
+
+Colossal-AI provides an early access version of parallelism specifically designed for MoE models.
+The most prominent advantage of MoE in Colossal-AI is convenience.
+We aim to help our users to easily combine MoE with model parallelism and data parallelism.
+
+However, the current implementation has two main drawbacks now.
+The first drawback is its poor efficiency in large batch size and long sequence length training.
+The second drawback is incompatibility with tensor parallelism.
+We are working on system optimization to overcome the training efficiency problem.
+The compatibility problem with tensor parallelism requires more adaptation, and we will tackle this issue in the future.
+
+Here, we will introduce how to use MoE with model parallelism and data parallelism.
+
+## Table of Content
+In this tutorial we will cover:
+1. Set up MoE running environment
+2. Create MoE layer
+3. Train your model
+
+We provided the [example code](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/widenet) for this tutorial in [ColossalAI-Examples](https://github.com/hpcaitech/ColossalAI-Examples).
+This example uses [WideNet](https://arxiv.org/abs/2107.11817) as an example of MoE-based model.
+
+
+## Set up MoE running environment
+In your project folder, create a `config.py`.
+
+This file is to specify some features you may want to use to train your model.
+In order to enable MoE, you need to add a dict called parallel and specify the value of key moe.
+You can assign a value for the key size of moe, which represents the model parallel size of experts (i.e. the number of experts in one group to parallelize training).
+
+For example, if the size is 4, 4 processes will be assigned to 4 consecutive GPUs and these 4 processes form a moe model parallel group.
+Each process on the 4 GPUs will only get a portion of experts. Increasing the model parallel size will reduce communication cost, but increase computation cost in each GPU and activation cost in memory.
+The total data parallel size is auto-detected and set as the number of GPUs by default.
+
+```python
+MOE_MODEL_PARALLEL_SIZE = ...
+parallel = dict(
+ moe=dict(size=MOE_MODEL_PARALLEL_SIZE)
+)
+```
+
+If `MOE_MODEL_PARALLEL_SIZE = E` and set the number of experts as `E` where `E` is a constant number, the process flow of forward pass of a transformer encoder in a model parallel group is shown below.
+
+
+
+MoE Transformer, image source: GShard
+
+
+Since all experts are allocated to all GPUs in a model parallel group and a GPU only owns a portion of experts,
+original data parallel groups are no longer correct for the parameters of experts during gradient handling in backward pass anymore.
+So we create a new kind of parallel group called moe data parallel group.
+The difference among different kinds of parallel group, when the configuration is set as `WORLD_SIZE=4`,
+`MOE_MODEL_PARALLEL_SIZE=2`, is shown here.
+
+
+
+MoE process group
+
+
+
+As for gradient handling, we provide MoeGradientHandler to all-reduce every parameter of the model.
+If you use `colossalai.initialize` function to create your training engine, the MoE gradient handler will be added to your engine automatically.
+Otherwise, you should take care of gradient by yourself.
+All parameters of MoE running environment are stored in colossalai.global_variables.moe_env.
+You can access your configuration parameters to check whether your setup is correct.
+```python
+from colossalai.global_variables import moe_env
+```
+
+## Create MoE layer
+You can create a MoE layer from `colossalai.nn.moe`.
+But before doing that, you should set up random seeds for all processes like this.
+
+```python
+from colossalai.context.random import moe_set_seed
+from model_zoo.moe.models import Widenet
+
+moe_set_seed(42)
+model = Widenet(num_experts=4, capacity_factor=1.2)
+```
+
+`moe_set_seed` will set different seed for different processes in a moe model parallel group.
+This helps initialize parameters in experts.
+Then create an instance of experts and an instance of router.
+Here is the example in model zoo.
+
+```python
+from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
+
+
+noisy_func = NormalNoiseGenerator(num_experts)
+shared_router = Top2Router(capacity_factor,
+ noisy_func=noisy_func)
+shared_experts = Experts(expert=VanillaFFN,
+ num_experts=num_experts,
+ **moe_mlp_args(
+ d_model=d_model,
+ d_ff=d_ff,
+ drop_rate=drop_rate
+ ))
+ffn=MoeLayer(dim_model=d_model, num_experts=num_experts,
+ router=shared_router, experts=shared_experts)
+```
+
+Inside the initialization of Experts, the local expert number of each GPU will be calculated automatically. You just need to specify the class of each expert and its parameters used in its initialization. As for routers, we have provided top1 router and top2 router. You can find them in colossalai.nn.layer.moe. After creating the instance of experts and router, the only thing initialized in Moelayer is gate module. More definitions of each class can be found in our API document and code.
+
+
+## Train Your Model
+Do not to forget to use `colossalai.initialize` function in `colosalai` to add gradient handler for the engine.
+We handle the back-propagation of MoE models for you.
+In `colossalai.initialize`, we will automatically create a `MoeGradientHandler` object to process gradients.
+You can find more information about the handler `MoeGradientHandler` in colossal directory.
+
+The loss criterion should be wrapped by `Moeloss` to add auxiliary loss of MoE. Example is like this.
+```python
+criterion = MoeLoss(
+ aux_weight=0.01,
+ loss_fn=nn.CrossEntropyLoss,
+ label_smoothing=0.1
+)
+```
+
+Finally, just use trainer or engine in `colossalai` to do your training.
+Otherwise, you should take care of gradient by yourself.
diff --git a/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md
new file mode 100644
index 000000000000..2bf0a9c98c3f
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/meet_gemini.md
@@ -0,0 +1,96 @@
+# 认识Gemini:ColossalAI的异构内存空间管理器
+
+作者: [Jiarui Fang](https://github.com/feifeibear)
+
+## 简介
+
+在GPU数量不足情况下,想要增加模型规模,异构训练是最有效的手段。它通过在 CPU 和 GPU 中容纳模型数据,并仅在必要时将数据移动到当前设备,可以同时利用 GPU 内存、CPU 内存(由 CPU DRAM 或 NVMe SSD内存组成)来突破单GPU内存墙的限制。并行,在大规模训练下,其他方案如数据并行、模型并行、流水线并行都可以在异构训练基础上进一步扩展GPU规模。这篇文章描述ColossalAI的异构内存空间管理模块Gemini的设计细节,它的思想来源于[PatrickStar](https://arxiv.org/abs/2108.05818),ColossalAI根据自身情况进行了重新实现。
+
+## 用法
+
+目前Gemini支持和ZeRO并行方式兼容,它的使用方法很简单,在训练策略的配置文件里设置zero的model_config属性tensor_placement_policy='auto'
+
+```
+zero = dict(
+ model_config=dict(
+ reduce_scatter_bucket_size_mb=25,
+ fp32_reduce_scatter=False,
+ gradient_predivide_factor=1.0,
+ tensor_placement_policy="auto",
+ shard_strategy=TensorShardStrategy(),
+ ...
+ ),
+ optimizer_config=dict(
+ ...
+ )
+)
+```
+
+注意,Gemini和并行策略,如Tensor Parallelism,Data Parallelism,Pipeline Parallelism,ZeRO是解耦合的。对TP,PP的支持还在开发中。
+
+## 术语
+
+**算子**(**OP**erator):一个神经网络层的计算操作,比如Linear,LayerNorm等。算子可以是正向传播的计算,也可以是反向传播的计算。
+
+神经网络在训练期间必须管理的两种类型的训练数据。
+
+**模型数据(model data)**: 由参数、梯度和优化器状态组成,其规模与模型结构定义相关
+
+**非模型数据(non-model data)**: 主要由算子生成的中间张量和算子的临时变量组成。非模型数据根据训练任务的配置动态变化,例如批量大小。模型数据和非模型数据相互竞争 GPU 内存。
+
+## 设计
+
+目前的一些解决方案,DeepSpeed采用的[Zero-offload](https://arxiv.org/abs/2101.06840)在CPU和GPU内存之间静态划分模型数据,并且它们的内存布局对于不同的训练配置是恒定的。如下图左边所示,当 GPU 内存不足以满足其相应的模型数据要求时,即使当时CPU上仍有可用内存,系统也会崩溃。而ColossalAI可以通过将一部分模型数据换出到CPU上来完成训练。
+
+
+
+比较Zero-Offload和Gemini的内存管理方案
+
+
+
+ColossalAI设计了Gemini,就像双子星一样,它管理CPU和GPU二者内存空间。它可以让张量在训练过程中动态分布在CPU-GPU的存储空间内,从而让模型训练突破GPU的内存墙。内存管理器由两部分组成,分别是MemStatsCollector(MSC)和StatefuleTensorMgr(STM)。
+
+
+我们利用了深度学习网络训练过程的迭代特性。我们将迭代分为warmup和non-warmup两个阶段,开始时的一个或若干迭代步属于预热阶段,其余的迭代步属于正式阶段。在warmup阶段我们为MSC收集信息,而在non-warmup阶段STM入去MSC收集的信息来移动tensor,以达到最小化CPU-GPU数据移动volume的目的。
+
+
+
+Gemini在不同训练阶段的运行流程
+
+
+
+### StatefulTensorMgr
+
+STM管理所有model data tensor的信息。在模型的构造过程中,ColossalAI把所有model data张量注册给STM。内存管理器给每个张量标记一个状态信息。状态集合包括HOLD,COMPUTE,FREE三种状态。STM的功能如下:
+
+**查询内存使用:**通过遍历所有tensor的在异构空间的位置,获取模型数据对CPU和GPU的内存占用。
+
+**转换张量状态:**它在每个模型数据张量参与算子计算之前,将张量标记为COMPUTE状态,在计算之后标记为HOLD状态。如果张量不再使用则标记的FREE状态。
+
+**调整张量位置:**张量管理器保证COMPUTE状态的张量被放置在计算设备上,如果计算设备的存储空间不足,则需要移动出一些HOLD状态的张量到其他设备上存储。Tensor eviction strategy需要MSC的信息,我们将在后面介绍。
+
+
+### MemStatsCollector
+在预热阶段,内存信息统计器监测CPU和GPU中模型数据和非模型数据的内存使用情况,供正式训练阶段参考。我们通过查询STM可以获得模型数据在某个时刻的内存使用。但是非模型的内存使用却难以获取。因为非模型数据的生存周期并不归用户管理,现有的深度学习框架没有暴露非模型数据的追踪接口给用户。MSC通过采样方式在预热阶段获得非模型对CPU和GPU内存的使用情况。具体方法如下:
+
+我们在算子的开始和结束计算时,触发内存采样操作,我们称这个时间点为**采样时刻(sampling moment)**,两个采样时刻之间的时间我们称为**period**。计算过程是一个黑盒,由于可能分配临时buffer,内存使用情况很复杂。但是,我们可以较准确的获取period的系统最大内存使用。非模型数据的使用可以通过两个统计时刻之间系统最大内存使用-模型内存使用获得。
+
+我们如何设计采样时刻呢。我们选择preOp的model data layout adjust之前。如下图所示。我们采样获得上一个period的system memory used,和下一个period的model data memoy used。并行策略会给MSC的工作造成障碍。如图所示,比如对于ZeRO或者Tensor Parallel,由于Op计算前需要gather模型数据,会带来额外的内存需求。因此,我们要求在模型数据变化前进行采样系统内存,这样在一个period内,MSC会把preOp的模型变化内存捕捉。比如在period 2-3内,我们考虑的tensor gather和shard带来的内存变化。
+尽管可以将采样时刻放在其他位置,比如排除gather buffer的变动新信息,但是会给造成麻烦。不同并行方式Op的实现有差异,比如对于Linear Op,Tensor Parallel中gather buffer的分配在Op中。而对于ZeRO,gather buffer的分配是在PreOp中。将放在PreOp开始时采样有利于将两种情况统一。
+
+
+尽管可以将采样时刻放在其他位置,比如排除gather buffer的变动新信息,但是会给造成麻烦。不同并行方式Op的实现有差异,比如对于Linear Op,Tensor Parallel中gather buffer的分配在Op中。而对于ZeRO,gather buffer的分配是在PreOp中。将放在PreOp开始时采样有利于将两种情况统一。
+
+
+
+Sampling based MemStatsCollector
+
+
+### Tensor Eviction Strategy
+
+MSC的重要职责是在调整tensor layout位置,比如在上图S2时刻,我们减少设备上model data数据,Period 2-3计算的峰值内存得到满足。
+
+在warmup阶段,由于还没执行完毕一个完整的迭代,我们对内存的真实使用情况尚一无所知。我们此时限制模型数据的内存使用上限,比如只使用30%的GPU内存。这样保证我们可以顺利完成预热状态。
+
+在non-warmup阶段,我们需要利用预热阶段采集的非模型数据内存信息,预留出下一个Period在计算设备上需要的峰值内存,这需要我们移动出一些模型张量。
+为了避免频繁在CPU-GPU换入换出相同的tensor,引起类似[cache thrashing](https://en.wikipedia.org/wiki/Thrashing_(computer_science))的现象。我们利用DNN训练迭代特性,设计了OPT cache换出策略。具体来说,在warmup阶段,我们记录每个tensor被计算设备需要的采样时刻。如果我们需要驱逐一些HOLD tensor,那么我们选择在本设备上最晚被需要的tensor作为受害者。
diff --git a/docs/source/zh-Hans/advanced_tutorials/opt_service.md b/docs/source/zh-Hans/advanced_tutorials/opt_service.md
new file mode 100644
index 000000000000..a213584fd41d
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/opt_service.md
@@ -0,0 +1,79 @@
+# Colossal-AI使用指南:5分钟搭建在线OPT服务
+
+## 介绍
+
+本指导手册将说明如何利用[Colossal-AI](https://github.com/hpcaitech/ColossalAI)搭建您自己的OPT服务。
+
+## Colossal-AI 推理概述
+Colossal-AI 提供了一个推理子系统 [Energon-AI](https://github.com/hpcaitech/EnergonAI), 这是一个基于Colossal-AI的服务系统,拥有以下特性:
+
+- **大模型并行:** 在Colossal-AI的张量并行和流水线并行策略的帮助下,Colossal-AI的推理可实现大模型的高效并行推理。
+- **预构建大模型:** Colossal-AI提供热门模型的预构建部署,例如OPT。其支持用于生成任务和加载检查点的缓存技术。
+- **引擎封装:** Colossal-AI中有一个抽象层被称作引擎。其将单实例多设备(SIMD) 执行与远程过程调用封装在一起。
+- **在线服务系统:** 基于FastAPI,用户可以快速启动分布式推理的网络服务。 在线服务对生成任务进行了特殊优化。它采用left padding和bucket batching两种技术来提高效率。
+
+## 基本用法
+
+1. 下载OPT模型
+
+想要快速发布分布式推理服务,您从[此处](https://huggingface.co/patrickvonplaten/opt_metaseq_125m/blob/main/model/restored.pt)下载OPT-125M。有关加载其他体量模型的详细方法,您可访问[此处](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt/script)。
+
+2. 准备提前构建的服务镜像
+
+从dockerhub拉取一个已经安装Colossal-AI推理的docker镜像。
+
+```bash
+docker pull hpcaitech/energon-ai:latest
+```
+
+3. 发布HTTP服务
+
+若想发布服务,我们需要准备python脚本来描述模型的类型和相关的部署,以及HTTP服务的设置。 我们为您提供了一组[示例](https://github.com/hpcaitech/EnergonAI/tree/main/examples])。 我们将在本指导手册中使用[OPT 示例](https://github.com/hpcaitech/EnergonAI/tree/main/examples/opt)。
+服务的入口是一个bash脚本 server.sh。
+本服务的配置文件参考 opt_config.py,该文件定义了模型的类型、 检查点文件路径、并行策略和http设置。您能按照您的需求来修改这些设置。
+例如,将模型的大小设置为opt_125M,将正确的检查点路径按照如下设置:
+
+```bash
+model_class = opt_125M
+checkpoint = 'your_file_path'
+```
+
+将张量并行度设置为您的gpu数量。
+
+```bash
+tp_init_size = #gpu
+```
+
+现在,我们就能利用docker发布一个服务。您能在`/model_checkpoint` 和 `/config`路径下找到检查点文件和配置文件。
+
+
+```bash
+export CHECKPOINT_DIR="your_opt_checkpoint_path"
+# the ${CONFIG_DIR} must contain a server.sh file as the entry of service
+export CONFIG_DIR="config_file_path"
+
+docker run --gpus all --rm -it -p 8020:8020 -v ${CHECKPOINT_DIR}:/model_checkpoint -v ${CONFIG_DIR}:/config --ipc=host energonai:lastest
+```
+
+接下来,您就可以在您的浏览器中打开 `https://[IP-ADDRESS]:8020/docs#` 进行测试。
+
+## 高级特性用法
+
+1. 批处理优化
+
+若想使用我们的高级批处理技术来批量收集多个查询,您可以将executor_max_batch_size设置为最大批处理大小。 请注意,只有具有相同 top_k、top_p 和温度的解码任务才能一起批处理。
+
+```
+executor_max_batch_size = 16
+```
+
+所有的查询将进入FIFO队列。解码步数小于或等于队列头部解码步数的所有连续查询可以一起批处理。 应用左填充以确保正确性。 executor_max_batch_size 不应该过大,从而确保批处理不会增加延迟。 以opt-30b为例, `executor_max_batch_size=16` 合适,但对于opt-175b而言, `executor_max_batch_size=4` 更合适。
+
+2. 缓存优化
+
+对于每一个独立的服务过程,您能将最近的多个查询结果缓存在一起。在config.py中设置 cache_size 和 cache_list_size。缓存的大小应为缓存的查询数目。cache_list_size 应为每次查询存储的结果数。一个随机缓存的结果将会被返回。当缓存已满,LRU策略被用于清理缓存过的查询。cache_size=0意味着不缓存。
+
+```
+cache_size = 50
+cache_list_size = 2
+```
diff --git a/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
new file mode 100644
index 000000000000..f3c6247c38e4
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/parallelize_your_training_like_Megatron.md
@@ -0,0 +1,176 @@
+# 使用ColoTensor让串行程序像Megatron-LM一样并行
+
+Author: [Haichen Huang](https://github.com/1SAA) and [Jiarui Fang](https://github.com/feifeibear)
+
+**Prerequisite:**
+- [ColoTensor Concepts](../basics/colotensor_concept.md)
+
+## 介绍
+
+在新版本中,我们引入了ColoTensor。ColoTensor为用户使用并行训练提供了极大的便利,使得用户可以在原本的串行代码上,通过较小的修改将训练改为并行。在本教程中,我们将说明如何修改训练模型以自动使代码采取像 Megatron-LM 一样的方式并行训练。我们以 HuggingFace 提供的 GPT-2 模型为例,并提供一种方式让你可以在单个GPU上预训练GPT-2模型。
+
+Megatron-LM 提供了一个具有影响力的并行化范式,这个范式主要应用于Transformer大模型的训练。然而,为了大规模训练 Transformer 语言大模型,用户必须使用Megatron-LM提供的特殊模块来构建他们的模型。这给用户带来了一些困难的工作,例如从预先训练的模型中加载权重,或是构建自己的并行训练模型。为了减轻用户的麻烦,我们提供 ColoTensor 类,以完成自动启用张量模型并行。
+
+## 定义模型和损失函数
+
+首先,我们直接调用 HuggingFace 库中的 GPTModel 和 GPTLoss。
+
+```python
+import torch
+import torch.nn as nn
+from transformers import GPT2Config, GPT2LMHeadModel
+
+class GPTLMModel(nn.Module):
+ def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=12, max_seq_len=1024, vocab_size=50257, checkpoint=False):
+ super().__init__()
+ self.checkpoint = checkpoint
+ self.model = GPT2LMHeadModel(GPT2Config(n_embd=hidden_size, n_layer=num_layers,
+ n_head=num_attention_heads, n_positions=max_seq_len, n_ctx=max_seq_len, vocab_size=vocab_size))
+ if checkpoint:
+ self.model.gradient_checkpointing_enable()
+
+ def forward(self, input_ids, attention_mask):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+```
+
+## 对GPT-2的简短回顾
+
+现在,我们回顾一下 GPT-2 模型的结构。每个 GPT-2 模型都可以表示为一个 DAG。如下图所示,每个圆圈代表一个算子,每个方块代表一个权重。每个箭头表示输入数据的流向,而箭头旁边的符号表示输入数据的形状。
+
+然后,让我们深入了解一下这个 GPT-2 模型。它由三部分组成,分别是**嵌入模块**、**转换器层**和**分类头**。
+
+嵌入模块包含两个权重,符号嵌入权重和位置嵌入权重。在嵌入模块的前向操作之后,原始输入数据的所有序列中的每个单词都会被嵌入到隐藏状态。
+
+
+
+嵌入模块
+
+
+每个转换器层包含两个块。自注意操作在第一个块中调用,同时一个双层感知器位于第二个块中。
+
+
+
+转换器层
+
+
+最后,分类头只是一个不加偏差的线性模块,里面只有一个线性权重。
+
+## 应用ColoTensor
+
+两个步骤使您的串行代码采取 Megatron-LM 张量并行风格。
+1. 在ColoInitContext的上下文中初始化模型。
+2. 为每个参数设置 ColoTensorSpec。
+
+### 使用 ColoInitContext 初始化
+
+我们应该在 ColoInitContext 中构建模型。在该种上下文中,任何初始化的参数都将转换为 ColoParameter 并自动移动到相应的设备上。
+
+```python
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+with ColoInitContext(device=torch.device('cpu')):
+ model = GPTLMModel()
+```
+
+### 为每个参数设置 ColoTensorSpec
+
+模型创建完成后,我们通过ProcessGroup建立分布式环境。这里,我们将张量并行度指定为所有GPU的数量,即数据并行度为一。
+
+```python
+import torch.distributed as dist
+from colossalai.tensor import ProcessGroup
+
+pg = ProcessGroup(tp_degree=dist.get_world_size())
+```
+
+现在,我们需要一些辅助函数为下一步做准备。我们定义了两个函数来切分参数。Megatron-LM张量并行需要沿参数的第一维或最后一维切分参数张量。
+
+```python
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, ColoParameter, ProcessGroup
+
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+ spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ if param.process_group.tp_world_size() == 1:
+ param.set_process_group(pg)
+ param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(-1, param, pg)
+```
+
+然后我们使模型采用张量并行。根据 Megatron 中使用的张量并行,应该沿着张量的最后一个维度进行切片,包括符号嵌入的权重,位置嵌入的权重,自注意力块中的所有线性权重和偏差,以及每个双层感知器中的第一个线性权重和偏差。且需要沿第一个维度切分双层感知器中的第二个线性权重。
+
+```python
+for mn, module in model.named_modules():
+ for pn, param in module.named_parameters(recurse=False):
+ # set process group for all parameters
+ param.set_process_group(pg)
+
+ if 'mlp.c_fc' in mn:
+ if 'weight' in pn or 'bias' in pn:
+ split_param_col_tp1d(param, pg) # colmn slice
+ # keep the shape of the output from c_fc
+ param.compute_spec.set_output_replicate(False)
+ elif 'mlp.c_proj' in mn:
+ if 'weight' in pn:
+ split_param_row_tp1d(param, pg) # row slice
+ elif 'wte' in mn or 'wpe' in mn:
+ split_param_col_tp1d(param, pg) # colmn slice
+ elif 'c_attn' in mn or 'c_proj' in mn:
+ split_param_col_tp1d(param, pg) # colmn slice
+```
+
+修改后的模型如下图所示。
+
+嵌入模块:
+
+
+
+修改后的嵌入模块
+
+
+转换器层:
+
+
+
+修改后的转换器层
+
+
+一旦用户指定了每个参数的在并行中的分布模式,ColoTensor 就能够推断出所有算子的计算模式,包括矩阵乘法、线性函数、torch.nn.functional 中的其他逐元素函数,以及其他的一些常用函数。这样,用户可以像往常一样训练他们的模型。
+
+在我们最新示例中还定义了一个Gemini + ZeRO DDP 的模型从而减小开销,提升效率。这一部分的详细内容可以参考[ZeRO](../features/zero_with_chunk.md),你可以将这两部分内容结合起来看从而理解我们整个训练流程:
+
+```python
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+ from colossalai.nn.parallel import GeminiDDP
+ model = GeminiDDP(model,
+ device=get_current_device(),
+ placement_policy=placememt_policy,
+ pin_memory=True,
+ search_range_mb=32)
+ return model
+```
+
+## 在单个GPU上预训练GPT-2
+
+我们做的上述优化让我们可以在单GPU上训练GPT-2模型,只需要将`run.sh`中设置参数`GPUNUM`=1,再运行文件时就可以在单个GPU上完成模型的训练。
+
+GPT-2 示例在[Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
new file mode 100644
index 000000000000..6c6dcf6e850d
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md
@@ -0,0 +1,275 @@
+# 使用混合并行训练 GPT
+
+作者: Hongxin Liu, Yongbin Li
+
+**示例代码**
+- [ColossalAI-Examples GPT2](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_2)
+- [ColossalAI-Examples GPT3](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/gpt_3)
+
+**相关论文**
+- [Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training](https://arxiv.org/abs/2110.14883)
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+
+## 引言
+
+在上一篇教程中,我们介绍了如何用流水并行训练 ViT。在本教程中,你将学习一个更复杂的场景--用混合并行方式训练GPT。在这种情况下,由于GPT-3过大,即使CPU内存也无法容纳它。因此,你必须自己分割模型。
+
+## 目录
+
+在本教程中,我们将介绍:
+
+1. 基于 colossalai/model_zoo 定义 GPT 模型
+2. 处理数据集
+3. 使用混合并行训练 GPT
+
+## 导入依赖库
+
+```python
+import json
+import os
+from typing import Callable
+
+import colossalai
+import colossalai.utils as utils
+import model_zoo.gpt.gpt as col_gpt
+import torch
+import torch.nn as nn
+from colossalai import nn as col_nn
+from colossalai.amp import AMP_TYPE
+from colossalai.builder.pipeline import partition_uniform
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils.timer import MultiTimer
+from model_zoo.gpt import GPTLMLoss
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+from transformers import GPT2Tokenizer
+```
+
+
+
+## 定义 GPT 模型
+
+在前面的教程中,我们介绍了3种建立流水并行模型的方法,但对于像 GPT-3 这样的巨大模型,你甚至不能在 CPU 中建立模型。在这种情况下,你必须自己分割模型。
+
+GPT 数据加载器返回 `input_ids` 和 `attention_mask`, 因此我们在 `forward()` 中使用两个关键字参数来获得它们。请注意,对于除第一阶段以外的其他阶段, `forward()` 的第一个位置参数是上一阶段的输出张量。所以 `hidden_states` 来自前一阶段,并且对于第一阶段来说,它是 `None`。
+
+对于 GPT, *word embedding layer* 与 *output head* 共享权重。我们提供 `PipelineSharedModuleWrapper` 在流水阶段间共享参数。它需要一个 `int` 型的 `list` 作为参数, 这意味着 rank 们共享这些参数。你可以使用 `register_module()`
+或 `register_parameter()` 来注册一个模块或一个参数作为共享模块或参数。如果你有多组共享模块/参数,你应该有多个 `PipelineSharedModuleWrapper` 实例。 如果参数在**一个**阶段内共享, 你不应该使用
+`PipelineSharedModuleWrapper`, 而只是使用同一个模块/参数实例。在这个例子中,*word embedding layer* 在第一阶段, 而 *output head* 在最后一个阶段。因此,他们在 rank `[0, pipeline_size - 1]` 之间共享参数。
+
+对于第一阶段,它维护 embedding layer 和一些 transformer blocks。对于最后一个阶段,它维护一些 transformer blocks 和 output head layer。对于其他阶段,他们只维护一些 transformer blocks。
+`partition_uniform(num_layers, pipeline_size, num_chunks)` 返回所有 rank 的 parts, part 是一个 `(start, end)` (不包括end) 的 `tuple`。`start == 0` 表示这是第一阶段, 而 `end == num_layers` 表示这是最后一个阶段。
+
+```python
+class PipelineGPTHybrid(nn.Module):
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: Callable = F.gelu,
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ first: bool = False,
+ last: bool = False):
+ super().__init__()
+ self.embedding = None
+ self.norm = None
+ self.head = None
+ if first:
+ self.embedding = col_gpt.GPTEmbedding(
+ hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype)
+ self.blocks = nn.ModuleList([
+ col_gpt.GPTBlock(hidden_size, num_attention_heads, mlp_ratio=mlp_ratio, attention_dropout=attn_drop_rate,
+ dropout=drop_rate, dtype=dtype, checkpoint=checkpoint, activation=act_func)
+ for _ in range(num_layers)
+ ])
+ if last:
+ self.norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ self.head = col_gpt.GPTLMHead(vocab_size=vocab_size,
+ dim=hidden_size,
+ dtype=dtype,
+ bias=False)
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ batch_size = hidden_states.shape[0]
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = attention_mask[:, None, None, :]
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
+
+
+def build_gpt_pipeline(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ rank = gpc.get_global_rank()
+ wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+ for start, end in parts:
+ kwargs['num_layers'] = end - start
+ kwargs['first'] = start == 0
+ kwargs['last'] = end == num_layers
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = PipelineGPTHybrid(**kwargs).to(device)
+ if start == 0:
+ wrapper.register_module(chunk.embedding.word_embeddings)
+ elif end == num_layers:
+ wrapper.register_module(chunk.head)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+ return model
+
+
+def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
+ cfg = dict(hidden_size=1600, num_attention_heads=32, checkpoint=checkpoint, dtype=dtype)
+ return build_gpt_pipeline(48, num_chunks, **cfg)
+
+
+def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float):
+ cfg = dict(hidden_size=12288, num_attention_heads=96,
+ checkpoint=checkpoint, max_position_embeddings=2048, dtype=dtype)
+ return build_gpt_pipeline(96, num_chunks, **cfg)
+```
+
+## 处理数据集
+
+我们在这里提供了一个小型 GPT web-text 数据集。 原始格式是 loose JSON, 我们将保存处理后的数据集。
+
+```python
+class WebtextDataset(Dataset):
+ def __init__(self, path, seq_len=1024) -> None:
+ super().__init__()
+ root = os.path.dirname(path)
+ encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+ if os.path.isfile(encoded_data_cache_path):
+ seq_len_, data, attention_mask = torch.load(
+ encoded_data_cache_path)
+ if seq_len_ == seq_len:
+ self.data = data
+ self.attention_mask = attention_mask
+ return
+ raw_data = []
+ with open(path) as f:
+ for line in f.readlines():
+ raw_data.append(json.loads(line)['text'])
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.unk_token
+ encoded_data = tokenizer(
+ raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+ self.data = encoded_data['input_ids']
+ self.attention_mask = encoded_data['attention_mask']
+ torch.save((seq_len, self.data, self.attention_mask),
+ encoded_data_cache_path)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return {
+ 'input_ids': self.data[index],
+ 'attention_mask': self.attention_mask[index]
+ }, self.data[index]
+```
+
+## 使用混合并行训练 GPT
+
+在上一个教程中,我们解释了一些流水并行的参数含义。在本例中,我们可以确定在流水阶段之间交换的每个输出张量的形状。对于 GPT,该形状为
+`(MICRO BATCH SIZE, SEQUENCE LEN, HIDDEN SIZE)`。通过设置该参数,我们可以避免交换每个阶段的张量形状。当你不确定张量的形状时,你可以把它保留为
+`None`, 形状会被自动推测。请确保你的模型的 `dtype` 是正确的:当你使用 `fp16`,模型的 `dtype` 必须是 `torch.half`;否则,`dtype` 必须是 `torch.float`。对于流水并行,仅支持 `AMP_TYPE.NAIVE`。
+
+你可以通过在 `CONFIG` 里使用 `parallel` 来轻松使用张量并行。数据并行的大小是根据 GPU 的数量自动设置的。
+
+```python
+NUM_EPOCHS = 60
+SEQ_LEN = 1024
+BATCH_SIZE = 192
+NUM_CHUNKS = None
+TENSOR_SHAPE = (1, 1024, 1600)
+# only pipeline parallel
+# CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2), fp16=dict(mode=AMP_TYPE.NAIVE))
+# pipeline + 1D model parallel
+CONFIG = dict(NUM_MICRO_BATCHES = 192, parallel=dict(pipeline=2, tensor=dict(mode='1d', size=2)), fp16=dict(mode=AMP_TYPE.NAIVE))
+
+
+def train():
+ disable_existing_loggers()
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch_from_torch(config=CONFIG, backend=args.backend)
+ logger = get_dist_logger()
+
+ train_ds = WebtextDataset(os.environ['DATA'], seq_len=SEQ_LEN)
+ train_dataloader = utils.get_dataloader(train_ds,
+ seed=42,
+ batch_size=BATCH_SIZE,
+ pin_memory=True,
+ shuffle=True,
+ drop_last=True)
+
+ use_interleaved = NUM_CHUNKS is not None
+ num_chunks = 1 if not use_interleaved else NUM_CHUNKS
+ model = GPT2_exlarge_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
+ # model = GPT3_pipeline_hybrid(num_chunks=num_chunks, checkpoint=True, dtype=torch.half)
+ if use_interleaved and not isinstance(model, nn.ModuleList):
+ model = nn.ModuleList([model])
+
+ criterion = GPTLMLoss()
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.00015, weight_decay=1e-2,)
+
+ engine, train_dataloader, _, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader=train_dataloader)
+ global_batch_size = BATCH_SIZE * \
+ gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
+ logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
+
+ timer = MultiTimer()
+
+ trainer = Trainer(
+ engine=engine,
+ logger=logger,
+ timer=timer
+ )
+
+ hook_list = [
+ hooks.LossHook(),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.ThroughputHook(),
+ hooks.LogMetricByStepHook(),
+ ]
+
+ trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True,
+ return_output_label=False,
+ )
+```
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
new file mode 100644
index 000000000000..495c7fa36cc1
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_using_pipeline_parallelism.md
@@ -0,0 +1,246 @@
+# 使用流水并行训练 ViT
+
+作者: Hongxin Liu, Yongbin Li
+
+**示例代码**
+- [ColossalAI-Examples Pipeline Parallel ViT](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer/pipeline_parallel)
+
+**相关论文**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473)
+
+## 引言
+
+在本教程中,你将学习如何使用流水并行从头开始训练用于图像分类的 Vision Transformer (ViT)。流水并行是一种模型并行,主要针对 GPU 内存不能满足模型容量的情况。
+通过使用流水并行,我们将原始模型分割成多个阶段,每个阶段保留原始模型的一部分。我们假设你的 GPU 内存不能容纳 ViT/L-16,而你的内存可以容纳这个模型。
+
+## 目录
+
+在本教程中,我们将介绍:
+
+1. 基于 [TIMM](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) 定义 ViT 模型
+2. 处理数据集
+3. 使用流水并行训练 ViT
+
+## 导入依赖库
+
+```python
+import os
+from collections import OrderedDict
+from functools import partial
+
+import colossalai
+import colossalai.nn as col_nn
+import torch
+import torch.nn as nn
+from colossalai.builder import build_pipeline_model
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import MultiTimer, get_dataloader
+from timm.models import vision_transformer as vit
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+
+## 定义 Vision Transformer 模型
+
+总的来说, 我们提供3种方法来建立一个流水并行的模型:
+
+1. `colossalai.builder.build_pipeline_model_from_cfg`
+2. `colossalai.builder.build_pipeline_model`
+3. 自己按阶段拆分模型
+
+当你的内存能够容纳模型时,你可以使用前两种方法来建立你的模型,否则你必须自己分割模型。前两种方法首先在 CPU 上建立整个模型,然后分割模型,最后你可以直接把模型的相应部分移到 GPU 上。
+
+`colossalai.builder.build_pipeline_model_from_cfg()` 接收一个模型的配置文件,它可以均匀地(按层)或平衡地(按参数大小)分割模型。
+
+如果你熟悉 `PyTorch`, 你可以使用 `colossalai.builder.build_pipeline_model()` 它接收一个 `torch.nn.Sequential` 模型并按层均匀分割。
+
+在本教程中,我们将修改 [TIMM/ViT](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py) to `torch.nn.Sequential`,然后使用 `colossalai.builder.build_pipeline_model()` 来建立流水线模型。
+
+当数据是 **一个** `Tensor`, 你可以使用你的模型 `forward()` 中的位置参数来获得数据张量。对于流水线的第一阶段,`forward()` 的第一个位置参数是从数据加载器加载的数据张量。对于其他阶段,`forward()` 的第一个位置参数是上一阶段的输出张量。注意,如果该阶段不是最后一个阶段,则 `forward()` 的返回必须是一个 `Tensor`。
+
+当数据是一个 `Tensor` 的 `dict`, 你可以使用你模型 `forward()` 的命名关键字参数来获得数据的 `dict`。
+
+```python
+class ViTEmbedding(nn.Module):
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, embed_layer=vit.PatchEmbed, drop_rate=0., distilled=False):
+ super().__init__()
+ self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ self.init_weights()
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ return x
+
+ def init_weights(self):
+ vit.trunc_normal_(self.pos_embed, std=.02)
+ if self.dist_token is not None:
+ vit.trunc_normal_(self.dist_token, std=.02)
+ vit.trunc_normal_(self.cls_token, std=.02)
+ self.apply(vit._init_vit_weights)
+
+
+class ViTHead(nn.Module):
+ def __init__(self, embed_dim=768, num_classes=1000, norm_layer=None, distilled=False, representation_size=None):
+ super().__init__()
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ self.norm = norm_layer(embed_dim)
+ self.num_classes = num_classes
+ self.distilled = distilled
+ self.num_features = embed_dim
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.init_weights()
+
+ def forward(self, x):
+ x = self.norm(x)
+ if self.distilled:
+ x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
+ if self.training and not torch.jit.is_scripting():
+ # during inference, return the average of both classifier predictions
+ return x, x_dist
+ else:
+ return (x + x_dist) / 2
+ else:
+ x = self.pre_logits(x[:, 0])
+ x = self.head(x)
+ return x
+
+ def init_weights(self):
+ self.apply(vit._init_vit_weights)
+
+
+def sequential_vit(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=vit.PatchEmbed, norm_layer=None,
+ act_layer=None):
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+ embedding = ViTEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dim, embed_layer=embed_layer, drop_rate=drop_rate, distilled=distilled)
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ blocks = [vit.Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)]
+ for block in blocks:
+ block.apply(vit._init_vit_weights)
+ head = ViTHead(embed_dim=embed_dim, num_classes=num_classes, norm_layer=norm_layer,
+ distilled=distilled, representation_size=representation_size)
+ return nn.Sequential(embedding, *blocks, head)
+
+
+def vit_large_patch16_224(**kwargs):
+ model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ return sequential_vit(**model_kwargs)
+```
+
+## 处理数据集
+
+一般来说, 我们在大型数据集如 ImageNet 上训练 ViT。为了简单期间,我们在这里只使用 CIFAR-10, 因为本教程只是用于流水并行训练。
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+```
+
+## 使用流水并行训练 ViT
+
+你可以在配置文件中设置流水并行的大小。`NUM_CHUNKS` 在使用交错流水线时很有用 (更多细节见 [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://arxiv.org/abs/2104.04473) )。
+原始 batch 将会被分割为 `num_microbatches`, 每个阶段每次将加载一个 micro batch。如果你确定性地知道每个阶段输出张量的形状,你可以在配置文件中设置 `tensor_shape` 来减少通信。
+我们的仓库会自动为用户生成合适的schedule来支持流水并行训练。如果你不需要模型的输出和标签,你可以在调用 `trainer.fit()` 时,将 `return_output_label` 设置为 `False`,这样能进一步减少 GPU 显存使用。
+
+你应当使用 `export DATA=/path/to/cifar`。
+
+```python
+BATCH_SIZE = 16
+NUM_EPOCHS = 60
+NUM_CHUNKS = 1
+CONFIG = dict(NUM_MICRO_BATCHES=4, parallel=dict(pipeline=2))
+
+
+def train():
+ disable_existing_loggers()
+ parser = colossalai.get_default_parser()
+ args = parser.parse_args()
+ colossalai.launch_from_torch(backend=args.backend, config=CONFIG)
+ logger = get_dist_logger()
+
+ # build model
+ model = vit_large_patch16_224()
+ model = build_pipeline_model(model, num_chunks=NUM_CHUNKS, verbose=True)
+
+ # build criterion
+ criterion = nn.CrossEntropyLoss()
+
+ # optimizer
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
+
+ # build dataloader
+ train_dataloader, test_dataloader = build_cifar(BATCH_SIZE)
+
+ engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model, optimizer, criterion,
+ train_dataloader, test_dataloader)
+ timer = MultiTimer()
+
+ trainer = Trainer(engine=engine, timer=timer, logger=logger)
+
+ hook_list = [
+ hooks.LossHook(),
+ hooks.AccuracyHook(col_nn.metric.Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ ]
+
+ trainer.fit(train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True)
+```
diff --git a/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
new file mode 100644
index 000000000000..6dc5eccf4421
--- /dev/null
+++ b/docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md
@@ -0,0 +1,591 @@
+# 使用 Colossal-AI (从数据并行到异构并行)加速 ViT 训练详解
+
+作者:Yuxuan Lou
+
+**示例代码**
+
+- [Colossal-AI Examples ViT on Cifar10](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/image/vision_transformer)
+
+**相关文献**
+- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf)
+
+
+## 引言
+
+在这个ViT模型的样例中,Colossal-AI 提供了三种不同的并行技术来加速模型训练:数据并行,流水线并行和张量并行。我们将展示如何使用这三种并行技术在 CIFAR-10 数据集上训练 ViT。为了运行项目,需要2-4个 GPU。
+
+
+## 目录
+1. Colossal-AI 安装方法
+2. 使用数据并行训练 ViT 步骤
+3. 使用数据流水线并行训练 ViT 步骤
+4. 使用张量并行或异构并行训练 ViT 步骤
+
+## Colossal-AI 安装
+可以通过 Python 的官方索引来安装 Colossal-AI 软件包。
+```bash
+pip install colossalai
+```
+
+
+
+## 数据并行
+数据并行是实现加速模型训练的基本方法。通过两步可以实现训练的数据并行:
+1. 构建一个配置文件
+2. 在训练脚本中修改很少的几行代码
+
+### 构建配置文件 (`data_parallel/config.py`)
+为了使用 Colossal-AI,第一步是构建配置文件。并且,在这里有两种变量:
+
+1. **Colossal-AI 功能配置**
+
+Colossal-AI 提供了一系列的功能来加快训练速度(包括模型并行,混合精度,零冗余优化器等)。每个功能都是由配置文件中的相应字段定义的。如果我们只用到数据并行,那么我们只需要具体说明并行模式。在本例中,我们使用 PyTorch 最初提出的混合精度训练,只需要定义混合精度配置 `fp16 = dict(mode=AMP_TYPE.TORCH)` 。
+
+2. **全局超参数**
+
+全局超参数包括特定于模型的超参数、训练设置、数据集信息等。
+
+```python
+from colossalai.amp import AMP_TYPE
+# ViT Base
+BATCH_SIZE = 256
+DROP_RATE = 0.1
+NUM_EPOCHS = 300
+# mix precision
+fp16 = dict(
+ mode=AMP_TYPE.TORCH,
+)
+gradient_accumulation = 16
+clip_grad_norm = 1.0
+dali = dict(
+ gpu_aug=True,
+ mixup_alpha=0.2
+)
+```
+
+### 修改训练脚本 (`/data_parallel/train_with_cifar10.py`)
+
+#### 导入模块
+- Colossal-AI 相关模块
+```python
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.lr_scheduler import LinearWarmupLR
+from colossalai.nn.metric import Accuracy
+from colossalai.trainer import Trainer, hooks
+```
+
+- 其他模块
+```python
+import os
+import torch
+from timm.models import vit_base_patch16_224
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+#### 启动 Colossal-AI
+
+在训练脚本中,在构建好配置文件后,需要为 Colossal-AI 初始化分布式环境。我们将此过程称为 `launch` 。在 Colossal-AI 中,我们提供了几种启动方法来初始化分布式后端。在大多数情况下,您可以使用 `colossalai.launch` 和 `colossalai.get_default_parser ` 来实现使用命令行传递参数。此外,Colossal-AI 可以利用 PyTorch 提供的现有启动工具,正如许多用户通过使用熟知的 `colossalai.launch_from_torch` 那样。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/docs/basics/launch_colossalai)。
+
+
+```python
+# initialize distributed setting
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+colossalai.launch_from_torch(config=args.config)
+disable_existing_loggers()
+logger = get_dist_logger()
+```
+
+初始化后,您可以使用 `colossalai.core.global_context` 访问配置文件中的变量。
+
+```python
+#access parameters
+print(gpc.config.BATCH_SIZE)
+```
+
+#### 构建模型
+
+如果只需要数据并行性,则无需对模型代码进行任何更改。这里,我们使用 `timm` 中的 `vit_base_patch16_224`。
+
+```python
+# build model
+model = vit_base_patch16_224(drop_rate=0.1, num_classes=gpc.config.NUM_CLASSES)
+```
+
+#### 构建 CIFAR-10 数据加载器
+`colossalai.utils.get_dataloader` 可以帮助您轻松构建数据加载器。
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+# build dataloader
+train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE)
+```
+
+#### 定义优化器,损失函数和学习率调度器
+
+Colossal-AI 提供了自己的优化器、损失函数和学习率调度器。PyTorch 的这些组件与Colossal-AI也兼容。
+
+```python
+# build optimizer
+optimizer = colossalai.nn.Lamb(model.parameters(), lr=1.8e-2, weight_decay=0.1)
+# build loss
+criterion = torch.nn.CrossEntropyLoss()
+# lr_scheduelr
+lr_scheduler = LinearWarmupLR(optimizer, warmup_steps=50, total_steps=gpc.config.NUM_EPOCHS)
+```
+
+#### 启动用于训练的 Colossal-AI 引擎
+
+Engine 本质上是对模型、优化器和损失函数的封装类。当我们使用 `colossalai.initialize` ,将返回一个 engine 对象,并且它已经按照配置文件中的指定内容,配置了梯度剪裁、梯度累积和零冗余优化器等功能。之后,基于 Colossal-AI 的 engine 我们可以进行模型训练。
+
+```python
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
+ model, optimizer, criterion, train_dataloader, test_dataloader
+ )
+```
+
+#### 训练:Trainer 应用程序编程接口
+Trainer 是一个更高级的封装类,用户可以用更少的代码就可以实现训练。通过传递 engine 对象很容易创建 trainer 对象。
+
+此外,在 trainer 中,用户可以自定义一些挂钩,并将这些挂钩连接到 trainer 对象。钩子对象将根据训练方案定期执行生命周期方法。例如,`LRSchedulerHook` 将执行`lr_scheduler.step()` 在 `after_train_iter` 或 `after_train_epoch` 阶段更新模型的学习速率。
+
+```python
+# build trainer
+trainer = Trainer(engine=engine, logger=logger)
+# build hooks
+hook_list = [
+ hooks.LossHook(),
+ hooks.AccuracyHook(accuracy_func=MixupAccuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
+ # comment if you do not need to use the hooks below
+ hooks.SaveCheckpointHook(interval=1, checkpoint_dir='./ckpt'),
+ hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
+]
+```
+
+使用 `trainer.fit` 进行训练:
+
+```python
+# start training
+trainer.fit(
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ hooks=hook_list,
+ display_progress=True,
+ test_interval=1
+)
+```
+
+### 开始训练
+`DATA` 是自动下载和存储 CIFAR-10 数据集的文件路径。
+
+`` 是要用于使用 CIFAR-10 数据集,以数据并行方式训练 ViT 的 GPU 数。
+
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_dp.py --config ./configs/config_data_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_dp.py --config ./configs/config_data_parallel.py
+# Otherwise
+# python -m torch.distributed.launch --nproc_per_node --master_addr --master_port 29500 train_dp.py --config ./configs/config.py
+```
+
+
+
+## 流水线并行
+除了数据并行性,Colossal-AI 还支持流水线并行。具体而言,Colossal-AI 使用 NVIDIA 引入的 1F1B 流水线。更多详细信息,您可以查看相关[文档](https://www.colossalai.org/tutorials/features/pipeline_parallel)。
+
+### 构建配置文件(`hybrid_parallel/configs/vit_pipeline.py`)
+要在数据并行的基础上应用流水线并行,只需添加一个 **parallel dict**
+```python
+from colossalai.amp import AMP_TYPE
+parallel = dict(
+ pipeline=2
+)
+# pipeline config
+NUM_MICRO_BATCHES = parallel['pipeline']
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)
+fp16 = dict(mode=AMP_TYPE.NAIVE)
+clip_grad_norm = 1.0
+```
+
+其他配置:
+```python
+# hyperparameters
+# BATCH_SIZE is as per GPU
+# global batch size = BATCH_SIZE x data parallel size
+BATCH_SIZE = 256
+LEARNING_RATE = 3e-3
+WEIGHT_DECAY = 0.3
+NUM_EPOCHS = 300
+WARMUP_EPOCHS = 32
+# model config
+IMG_SIZE = 224
+PATCH_SIZE = 16
+HIDDEN_SIZE = 768
+DEPTH = 12
+NUM_HEADS = 12
+MLP_RATIO = 4
+NUM_CLASSES = 10
+CHECKPOINT = True
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
+```
+
+### 构建流水线模型 (`/hybrid_parallel/model/vit.py`)
+Colossal-AI 提供了两种从现有模型构建流水线模型的方法。
+- `colossalai.builder.build_pipeline_model_from_cfg`
+- `colossalai.builder.build_pipeline_model`
+
+此外,您还可以使用 Colossal-AI 从头开始构建流水线模型。
+```python
+import math
+from typing import Callable
+import inspect
+import torch
+from colossalai import nn as col_nn
+from colossalai.registry import LAYERS, MODELS
+from colossalai.logging import get_dist_logger
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from colossalai.builder.pipeline import partition_uniform
+from torch import dtype, nn
+from model_zoo.vit.vit import ViTBlock, ViTEmbedding, ViTHead
+@MODELS.register_module
+class PipelineVisionTransformer(nn.Module):
+ def __init__(self,
+ img_size: int = 224,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ depth: int = 12,
+ num_heads: int = 12,
+ dim: int = 768,
+ mlp_ratio: int = 4,
+ attention_dropout: float = 0.,
+ dropout: float = 0.1,
+ drop_path: float = 0.,
+ layernorm_epsilon: float = 1e-6,
+ activation: Callable = nn.functional.gelu,
+ representation_size: int = None,
+ dtype: dtype = None,
+ bias: bool = True,
+ checkpoint: bool = False,
+ init_method: str = 'torch',
+ first_stage=True,
+ last_stage=True,
+ start_idx=None,
+ end_idx=None,):
+ super().__init__()
+ layers = []
+ if first_stage:
+ embed = ViTEmbedding(img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embedding_dim=dim,
+ dropout=dropout,
+ dtype=dtype,
+ init_method=init_method)
+ layers.append(embed)
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path, depth)]
+ if start_idx is None and end_idx is None:
+ start_idx = 0
+ end_idx = depth
+ blocks = [
+ ViTBlock(
+ dim=dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ attention_dropout=attention_dropout,
+ dropout=dropout,
+ drop_path=dpr[i],
+ activation=activation,
+ dtype=dtype,
+ bias=bias,
+ checkpoint=checkpoint,
+ init_method=init_method,
+ ) for i in range(start_idx, end_idx)
+ ]
+ layers.extend(blocks)
+ if last_stage:
+ norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
+ head = ViTHead(dim=dim,
+ num_classes=num_classes,
+ representation_size=representation_size,
+ dtype=dtype,
+ bias=bias,
+ init_method=init_method)
+ layers.extend([norm, head])
+ self.layers = nn.Sequential(
+ *layers
+ )
+ def forward(self, x):
+ x = self.layers(x)
+ return x
+def _filter_kwargs(func, kwargs):
+ sig = inspect.signature(func)
+ return {k: v for k, v in kwargs.items() if k in sig.parameters}
+def _build_pipeline_vit(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ else:
+ pipeline_size = 1
+ pipeline_rank = 0
+ rank = gpc.get_global_rank()
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+ for start, end in parts:
+ kwargs['first_stage'] = start == 0
+ kwargs['last_stage'] = end == num_layers
+ kwargs['start_idx'] = start
+ kwargs['end_idx'] = end
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+ return model
+def build_pipeline_vit(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ return _build_pipeline_vit(PipelineVisionTransformer, num_layers, num_chunks, device, **kwargs)
+```
+
+### 修改训练脚本 (`/hybrid_parallel/train_with_cifar10.py`)
+
+#### 导入模块
+```python
+from colossalai.engine.schedule import (InterleavedPipelineSchedule,
+ PipelineSchedule)
+from colossalai.utils import MultiTimer
+import os
+import colossalai
+import torch
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn import CrossEntropyLoss
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.utils import is_using_pp, get_dataloader
+from model.vit import build_pipeline_vit
+from model_zoo.vit.vit import _create_vit_model
+from tqdm import tqdm
+from torchvision import transforms
+from torchvision.datasets import CIFAR10
+```
+
+#### 启动 Colossal-AI
+`colossalai.utils.is_using_pp` 可以帮您检查配置文件是否满足流水线并行的要求。
+
+```python
+# initialize distributed setting
+parser = colossalai.get_default_parser()
+args = parser.parse_args()
+# launch from torch
+colossalai.launch_from_torch(config=args.config)
+# get logger
+logger = get_dist_logger()
+logger.info("initialized distributed environment", ranks=[0])
+if hasattr(gpc.config, 'LOG_PATH'):
+ if gpc.get_global_rank() == 0:
+ log_path = gpc.config.LOG_PATH
+ if not os.path.exists(log_path):
+ os.mkdir(log_path)
+ logger.log_to_file(log_path)
+use_pipeline = is_using_pp()
+```
+
+#### 定义模型
+
+```python
+# create model
+model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
+ patch_size=gpc.config.PATCH_SIZE,
+ dim=gpc.config.HIDDEN_SIZE,
+ depth=gpc.config.DEPTH,
+ num_heads=gpc.config.NUM_HEADS,
+ mlp_ratio=gpc.config.MLP_RATIO,
+ num_classes=gpc.config.NUM_CLASSES,
+ init_method='jax',
+ checkpoint=gpc.config.CHECKPOINT)
+if use_pipeline:
+ model = build_pipeline_vit(num_layers=model_kwargs['depth'], num_chunks=1, **model_kwargs)
+else:
+ model = _create_vit_model(**model_kwargs)
+```
+
+#### 计算参数个数
+
+您可以轻松计算不同流水线阶段上的模型参数个数。
+
+```
+# count number of parameters
+total_numel = 0
+for p in model.parameters():
+ total_numel += p.numel()
+if not gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_stage = 0
+else:
+ pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
+logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
+```
+
+#### 构建数据加载器,优化器等组件
+
+```python
+def build_cifar(batch_size):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(224, pad_if_needed=True),
+ transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.Resize(224),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ train_dataset = CIFAR10(root=os.environ['DATA'], train=True, download=True, transform=transform_train)
+ test_dataset = CIFAR10(root=os.environ['DATA'], train=False, transform=transform_test)
+ train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True)
+ test_dataloader = get_dataloader(dataset=test_dataset, batch_size=batch_size, pin_memory=True)
+ return train_dataloader, test_dataloader
+
+
+# craete dataloaders
+train_dataloader , test_dataloader = build_cifar()
+# create loss function
+criterion = CrossEntropyLoss(label_smoothing=0.1)
+# create optimizer
+optimizer = torch.optim.AdamW(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
+# create lr scheduler
+lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
+ total_steps=gpc.config.NUM_EPOCHS,
+ warmup_steps=gpc.config.WARMUP_EPOCHS)
+```
+
+#### 启动 Colossal-AI 引擎
+
+```python
+# intiailize
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model=model,
+ optimizer=optimizer,
+ criterion=criterion,
+ train_dataloader=train_dataloader,
+ test_dataloader=test_dataloader)
+logger.info("Engine is built", ranks=[0])
+```
+
+#### 训练:基于engine
+
+在数据并行示例中,我们展示了如何使用 Trainer API 训练模型。我们还可以直接训练基于 engine 的模型。通过这种方式,您可以使用更多功能自定义训练方法。
+
+```python
+data_iter = iter(train_dataloader)
+for epoch in range(gpc.config.NUM_EPOCHS):
+ # training
+ engine.train()
+ if gpc.get_global_rank() == 0:
+ description = 'Epoch {} / {}'.format(
+ epoch,
+ gpc.config.NUM_EPOCHS
+ )
+ progress = tqdm(range(len(train_dataloader)), desc=description)
+ else:
+ progress = range(len(train_dataloader))
+ for _ in progress:
+ engine.zero_grad()
+ engine.execute_schedule(data_iter, return_output_label=False)
+ engine.step()
+ lr_scheduler.step()
+```
+
+### 开始训练
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_pipeline_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_pipeline_parallel.py
+```
+
+
+
+
+## 张量并行和异构并行
+张量并行将每个权重参数跨多个设备进行分区,以减少内存负载。Colossal-AI 支持 1D、2D、2.5D 和 3D 张量并行。此外,还可以将张量并行、流水线并行和数据并行结合起来,实现混合并行。Colossal-AI 还提供了一种简单的方法来应用张量并行和混合并行。只需在配置文件中更改几行代码即可实现流水线并行。
+
+### 构造您的配置文件 (`/hybrid_parallel/configs/vit_1d_tp2_pp2.py`)
+使用张量并行,只需将相关信息添加到 **parallel dict**。具体而言,`TENSOR_PARALLEL_MODE` 可以是“1d”、“2d”、“2.5d”、“3d”。不同并行度的大小应满足:`#GPUs = pipeline parallel size x tensor parallel size x data parallel size`。在指定 GPU 数量、流水线并行大小和张量并行大小后 `data parallel size` 会自动计算。
+
+```python
+from colossalai.amp import AMP_TYPE
+# parallel setting
+TENSOR_PARALLEL_SIZE = 2
+TENSOR_PARALLEL_MODE = '1d'
+parallel = dict(
+ pipeline=2,
+ tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE)
+)
+fp16 = dict(mode=AMP_TYPE.NAIVE)
+clip_grad_norm = 1.0
+# pipeline config
+NUM_MICRO_BATCHES = parallel['pipeline']
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LENGTH, HIDDEN_SIZE)
+```
+
+其他配置:
+```python
+# hyperparameters
+# BATCH_SIZE is as per GPU
+# global batch size = BATCH_SIZE x data parallel size
+BATCH_SIZE = 256
+LEARNING_RATE = 3e-3
+WEIGHT_DECAY = 0.3
+NUM_EPOCHS = 300
+WARMUP_EPOCHS = 32
+# model config
+IMG_SIZE = 224
+PATCH_SIZE = 16
+HIDDEN_SIZE = 768
+DEPTH = 12
+NUM_HEADS = 12
+MLP_RATIO = 4
+NUM_CLASSES = 10
+CHECKPOINT = True
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE) ** 2 + 1 # add 1 for cls token
+```
+
+### 开始训练
+```bash
+export DATA=
+# If your torch >= 1.10.0
+torchrun --standalone --nproc_per_node train_hybrid.py --config ./configs/config_hybrid_parallel.py
+# If your torch >= 1.9.0
+# python -m torch.distributed.run --standalone --nproc_per_node= train_hybrid.py --config ./configs/config_hybrid_parallel.py
+```
diff --git a/docs/source/zh-Hans/basics/colotensor_concept.md b/docs/source/zh-Hans/basics/colotensor_concept.md
new file mode 100644
index 000000000000..cac5b9a4b40d
--- /dev/null
+++ b/docs/source/zh-Hans/basics/colotensor_concept.md
@@ -0,0 +1,98 @@
+# ColoTensor Concepts
+
+Author: [Jiarui Fang](https://github.com/feifeibear), [Hongxin Liu](https://github.com/ver217) and [Haichen Huang](https://github.com/1SAA)
+
+**Prerequisite:**
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+- [Distributed Training](../concepts/distributed_training.md)
+- [Paradigms of Parallelism](../concepts/paradigms_of_parallelism.md)
+
+## Introduction
+
+在ColossalAI 0.1.8 版本之后,[ColoTensor](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ColoTensor) 成为 ColossalAI 中张量的基本数据结构。 它是 torch.Tensor 的子类,可以当做 PyTorch Tensor使用。 此外,一些独特的功能使其能够表示一个payload分布在多个 GPU 设备上的Global Tensor,并提供一些列方式操作这个Global Tensor。 在 ColoTensor 的帮助下,用户可以以类似编写串行程序方式,编写的分布式 DNN 训练程序。
+
+ColoTensor 包含额外的属性[ColoTensorSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.tensor_spec.html#colossalai.tensor.tensor_spec.ColoTensorSpec)
+来描述张量的payload分布和计算模式。
+
+- ProcessGroup:如何将进程组织为通信组。
+- Distributed Spec:张量如何在进程组之间分布。
+- Compute Spec:计算过程中如何使用张量。
+
+我们一一详述。
+
+## ProcessGroup
+
+[ProcessGroup](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.html#colossalai.tensor.ProcessGroup) 类的一个实例描述了如何在进程组中组织进程。进程组内的进程可以一起参与同一个集合通信,比如allgather, allreduce等。进程组组织方式被张量的并行策略支配。比如,如果用户定义了Tensor的张量并行(TP),数据并行(DP)方式,那么进程组的进程组织方式将被自动推导出来。 进程组设置可能因不同的张量而异。 因此,它使我们能够支持更复杂的混合并行。流水线并行(PP)定义不在ProcessGroup中描述,它需要另一套机制,我们将在未来补充ColoTensor应用于PP的相关内容。
+
+目前,ColoTensor 的一个进程组由 tp_degree 和 dp_degree 两种配置定义。 在 DP+TP 混合并行的情况下,可以将设备视为 2D 网格。 我们将 TP 通信组放置在设备网格的前导低维上,然后将数据并行组放置在设备网格的高维上。 原因是张量并行比数据并行具有更大的通信开销。 相邻设备放置在一个 TP 进程组内,并且通常放置在同一个节点中。
+
+考虑到8个进程配置为tp_degree=4,dp_degree=2,布局如下图。 进程组 tp0 包含 gpu 0,1,2,3。 进程 dp1 包含 gpu 1 和 5。
+
+
+
+Process Group using tp_degree=4, dp_degree=2
+
+
+## Distributed Spec
+
+[Distributed Spec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html)描述了 ColoTensor 如何在 ProcessGroup 中分布。
+
+张量在 DP 进程组之间的分布方式是自动导出的,不需要用户手动指定。 如果这个张量是一个模型参数,它会在 DP 进程组中被复制。 如果是activation张量,则沿tensor最高维度在DP进程组中进行平均分割。
+
+因此,在使用 Distributed Spec 时,我们只需要描述张量在 TP 进程组之间的分布方式即可。 TP 进程组目前有两种分布式规范,即 [ShardSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ShardSpec)和[ReplicaSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.distspec.html#colossalai.tensor.distspec.ReplicaSpec)。 ShardSpec 需要指定分区的维度索引 dim 和分区个数 num_partitions。 目前,我们仅支持在单个dim上进行拆分。 TP进程组上不同的dist spec可以通过set_dist_spec()接口相互转换。这些转化操作可以被记录在PyTorch的自动求导机制中,并在反向传播时候触发对应的反向操作。
+
+## Compute Spec
+
+[ComputeSpec](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.compute_spec.html#colossalai.tensor.compute_spec.ComputeSpec)类描述Tensor如何参与计算。目前,我们将作为module parameter的ColoTensor设置正确的Compute Pattern。可以触发正取的计算模式。具体应用方式我们会在接下来的文档中展示。
+
+## ColoParameter
+
+[ColoParameter](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.tensor.colo_parameter.html#colossalai.tensor.colo_parameter.ColoParameter)是ColoTensor的子类。用来声明Parameter。他和ColoTensor关系和Torch.Tensor和torch.Parameter一致。后者可以让tensor出现在module的parameters()和name_parameters() 的返回值中。
+
+## Example
+
+让我们看一个例子。 使用 tp_degree=4, dp_dgree=2 在 8 个 GPU 上初始化并Shard一个ColoTensor。 然后tensor被沿着 TP 进程组中的最后一个维度进行分片。 最后,我们沿着 TP 进程组中的第一个维度(dim 0)对其进行重新Shard。 我们鼓励用户运行代码并观察每个张量的形状。
+
+
+```python
+import torch
+import torch.multiprocessing as mp
+from colossalai.utils import free_port, print_rank_0
+from functools import partial
+
+import colossalai
+from colossalai.tensor import ProcessGroup, ColoTensor, ColoTensorSpec, ShardSpec, ComputeSpec, ComputePattern
+from colossalai.utils import free_port
+
+import torch
+
+def run_dist_tests(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ pg = ProcessGroup(tp_degree=2, dp_degree=2)
+
+ torch.manual_seed(0)
+ local_tensor = torch.randn(2, 3, 1).cuda()
+ print_rank_0(f"shape {local_tensor.shape}, {local_tensor.data}")
+
+ spec = ColoTensorSpec(pg, ShardSpec(dims=[-1], num_partitions=[pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ t1 = ColoTensor.from_torch_tensor(local_tensor, spec)
+ t1 = t1.to_replicate()
+ print_rank_0(f"shape {t1.shape}, {t1.data}")
+
+ spec2 = ShardSpec([0], [pg.tp_world_size()])
+ t1.set_dist_spec(spec2)
+ print_rank_0(f"shape {t1.shape}, {t1.data}")
+
+def test_dist_cases(world_size):
+ run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+if __name__ == '__main__':
+ test_dist_cases(4)
+```
+
+:::caution
+
+The ColoTensor is an experimental feature and may be updated.
+
+:::
diff --git a/docs/source/zh-Hans/basics/command_line_tool.md b/docs/source/zh-Hans/basics/command_line_tool.md
new file mode 100644
index 000000000000..9b0275a6cedd
--- /dev/null
+++ b/docs/source/zh-Hans/basics/command_line_tool.md
@@ -0,0 +1,47 @@
+# 命令行工具
+
+作者: Shenggui Li
+
+**预备知识:**
+- [Distributed Training](../concepts/distributed_training.md)
+- [Colossal-AI Overview](../concepts/colossalai_overview.md)
+
+## 简介
+
+Colossal-AI给用户提供了命令行工具,目前命令行工具可以用来支持以下功能。
+- 检查Colossal-AI是否安装正确
+- 启动分布式训练
+- 张量并行基准测试
+
+## 安装检查
+
+用户可以使用`colossalai check -i`这个命令来检查目前环境里的版本兼容性以及CUDA Extension的状态。
+
+
+
+Check Installation Demo
+
+
+## 启动分布式训练
+
+在分布式训练时,我们可以使用`colossalai run`来启动单节点或者多节点的多进程,详细的内容可以参考[启动 Colossal-AI](./launch_colossalai.md)。
+
+## 张量并行基准测试
+
+Colossal-AI提供了多种张量并行,想要充分理解这些方法需要一定的学习成本,对于新手来说很难靠经验选择一个并行方式。
+所以我们提供了一个简单的基准测试,能够让用户在自己的机器上测试不同张量并行的性能。这个基准测试跑一个并行的MLP模型,
+输入数据的维度为`(批大小,序列长度,隐藏层维度)`。通过指定GPU的数量,Colossal-AI会搜索所有可行的并行配置。用户可以通过查看`colossalai benchmark --help`来自定义相关的测试参数。
+
+```shell
+# 使用4个GPU
+colossalai benchmark --gpus 4
+
+# 使用8个GPU
+colossalai benchmark --gpus 8
+```
+
+:::caution
+
+目前仅支持单节点的基准测试。
+
+:::
diff --git a/docs/source/zh-Hans/basics/configure_parallelization.md b/docs/source/zh-Hans/basics/configure_parallelization.md
new file mode 100644
index 000000000000..eb4b38f48ddb
--- /dev/null
+++ b/docs/source/zh-Hans/basics/configure_parallelization.md
@@ -0,0 +1,136 @@
+# 并行配置
+
+作者: Shenggui Li, Siqi Mai
+
+**预备知识:**
+- [分布式训练](../concepts/distributed_training.md)
+- [并行技术](../concepts/paradigms_of_parallelism.md)
+- [构建配置文件](./define_your_config.md)
+
+
+## 简介
+
+我们在 Colossal-AI 中支持多种并行技术。代码库中的混合并行是指您可以轻松地结合数据并行、流水线并行和张量并行(1D、2D、2.5D、3D)的优势共同来进行并行训练。
+
+每种并行方式需要不同的网络拓扑结构,因此要初始化不同的进程组。您可以通过在配置文件中设置 `parallel` 来初始化相应的进程组。 `parallel` 的配置必须遵从以下格式。数据并行度的大小将被根据您对流水线并行和张量并行的输入自动推断。`colossalai.launch` 将根据您的配置自动初始化这些分布式进程组。
+
+我们为您提供了一些配置的例子以供参考。
+
+```python
+# sampler format
+parallel = dict(
+ pipeline=dict("size": int),
+ tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any)
+)
+
+# this is ok
+parallel = dict(
+ pipeline=dict(size=2),
+ tensor=dict(size=4, mode='2d')
+)
+
+# this is ok
+parallel = dict(
+ pipeline=2,
+ tensor=dict(size=4, mode='2d')
+)
+
+# this is not ok
+# as you need to specify the mode for tensor parallelism
+parallel = dict(
+ pipeline=2,
+ tensor=4
+)
+
+# this is ok as well as tensor will be default to size 1
+# and mode None
+parallel = dict(
+ pipeline=2
+)
+
+# this is ok as well as pipeline will default to size 1
+parallel = dict(
+ tensor=dict(size=4, mode='2d')
+)
+
+```
+
+关键字 `size` 指的是并行维度的并行大小。 例如,流水线大小为2意味着有
+将有2个流水线阶段。张量并行配置中的关键字 `mode` 意味着相应的张量并行技术
+将被初始化,如1D、2D、2.5D、3D。
+
+**您也可以选择不在您的配置中使用 "并行",此时流水线和张量的并行度都将默认为大小1。**
+
+**GPU的总数量必须等于` 数据并行大小 x 张量并行大小 x 流水线并行大小` 。**
+
+## 数据并行
+
+数据并行是最常见的分布式训练方式。它将数据分割成几个碎片分别在每个设备上进行训练。数据并行的配置会自动检测并为您设置。您不需要在您的配置中明确地设置它们。在Colossal-AI 中,有两种方法来处理数据并行的 all-reduce。
+
+1. 如果您设置了梯度handler,梯度handler将会all-reduce梯度。
+2. 若没有指定相应的配置,Colossal-AI 将会使用 PyTorch 的 DistributedDataParallel。
+
+在大多数情况下,若您对梯度没有复杂的处理的需求,您将会使用第二种模式。
+
+## 1D, 2D, 2.5D 和 3D 并行
+
+为了实现混合并行,我们提供了一系列张量并行方法。您可以阅读相应的学术论文进行深入的了解。这些并行模式需要和 Colossal-AI 提供的分布式层一同工作。
+
+- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+
+- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
+ 2D 并行基于 SUMMA 矩阵乘法,它将输入数据、模型权重和层输出切分成两个不同的维度。 这些张量块分布在 `P = N^2` 设备的二维网格上,其中 `N` 是单一维度上张量块的数量。
+
+- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
+ 在 2.5D 矩阵乘法的启发下,2.5D 并行引入了一种新的张量并行,进一步将2D张量并行化。其中,`P = N^2 ∗ d` 个处理器被分配到 `d` 层, 每层独立进行矩阵乘法运算,维度为 `N`。
+
+- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
+ 我们还介绍了一种 3D 张量并行方法,在三维处理器立方体上并行化神经网络。这种方法在数量为 `P` 的处理器上实现了最佳的 `O(P^{1/3})` 通信开销,而计算和内存的使用都是通过优化的参数和激活的负载平衡来实现的。同时,通过优化参数和 activations 的负载平衡,计算和内存的使用都是均匀分布的。
+
+```python
+# 1D parallel
+parallel = dict(
+ tensor=dict(size=4, mode='1d')
+)
+
+# 2D parallel
+parallel = dict(
+ tensor=dict(size=4, mode='2d')
+)
+
+# 2.5D parallel
+parallel = dict(
+ tensor=dict(size=8, mode='2.5d', depth=2)
+)
+
+# 3D parallel
+parallel = dict(
+ tensor=dict(size=8, mode='3d')
+)
+```
+
+当您在配置中指定了张量并行模式,您就可以使用其相应的分布式算子。例如,若您设置模式为 `2d`,那么在模型构建中就能使用 `colossalai.nn.Linear2D` 了。
+
+
+## 流水线并行
+
+流水线并行是将模型按层分成几个部分。例如,假设我们有一个简单的模型,它由两个线性层组成。我们有两个 GPU,我们可以将第一个线性层分配给第一个 GPU 而第二层则分配给第二个 GPU。
+
+您可以在您的配置文件中设置流水线并行度的大小。当流水线并行度大于1,Colossal-AI 将会自动地创建流水线并行的 schedule,这将会为您定义好模型训练的 `forward` 和 `backward`。
+
+```python
+parallel = dict(
+ pipeline=dict(size=4), # number of pipeline stages
+)
+```
+
+## 序列并行
+
+针对处理大图片、视频、长文本、长时间医疗监控等数据的需要,Colossal-AI 还提供了序列并行的方法。该方法是在论文[Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120)中提出的。您可以指定模式为 `sequence` 来初始化进程组。
+
+
+```python
+parallel = dict(
+ tensor=dict(size=4, mode='sequence')
+)
+```
diff --git a/docs/source/zh-Hans/basics/define_your_config.md b/docs/source/zh-Hans/basics/define_your_config.md
new file mode 100644
index 000000000000..d7e49cbf23de
--- /dev/null
+++ b/docs/source/zh-Hans/basics/define_your_config.md
@@ -0,0 +1,71 @@
+# 构建配置文件
+
+作者: Guangyang Lu, Shenggui Li, Siqi Mai
+
+**预备知识:**
+- [分布式训练](../concepts/distributed_training.md)
+- [Colossal-AI 总览](../concepts/colossalai_overview.md)
+
+
+## 简介
+
+在 Colossal-AI 中,我们需要一个配置文件来指定系统在训练过程中要注入的特征。在本教程中,我们将向您介绍如何构建您的配置文件以及如何使用这个配置文件。使用配置文件有以下一些好处:
+
+1. 您可以在不同的配置文件中存储您的特征配置和训练超参数。
+2. 对于我们未来发布的新功能,您亦可以在配置中指定,而无需改变训练脚本的代码。
+
+在本教程中,我们将向您介绍如何构建您的配置文件。
+
+## 配置定义
+
+在一个配置文件中,有两种类型的变量。一种是作为特征说明,另一种是作为超参数。所有与特征相关的变量都是保留关键字。例如,如果您想使用混合精度训练,需要在 config 文件中使用变量名`fp16`,并遵循预先定义的格式。
+
+### 功能配置
+
+Colossal-AI 提供了一系列的功能来加快训练速度。每个功能都是由配置文件中的相应字段定义的。在本教程中,我们不会给出所有功能的配置细节,而是提供一个如何指定一个功能的说明。**每个功能的细节可以在其各自的教程中找到。**
+
+为了说明配置文件的使用,我们在这里使用混合精度训练作为例子。您需要遵循以下步骤。
+
+1. 创建一个配置文件(例如 `config.py`,您可以指定任意的文件名)。
+2. 在配置文件中定义混合精度的配置。例如,为了使用 PyTorch 提供的原始混合精度训练,您只需将下面这几行代码写入您的配置文件中。
+
+ ```python
+ from colossalai.amp import AMP_TYPE
+
+ fp16 = dict(
+ mode=AMP_TYPE.TORCH
+ )
+ ```
+
+3. 当启动分布式环境时,向 Colossal-AI 指定您的配置文件的位置。比如下面的例子是配置文件在当前目录下。
+
+ ```python
+ import colossalai
+
+ colossalai.launch(config='./config.py', ...)
+ ```
+
+这样,Colossal-AI 便知道您想使用什么功能,并会在 `colossalai.initialize` 期间注入您所需要的功能。
+
+### 全局超参数
+
+除了功能的配置,您还可以在配置文件中定义训练的超参数。当您想进行多个实验时,这将会变得非常方便。每个实验的细节都可以放在独立的配置文件中,以避免混乱。这些参数将被存储在全局并行环境中,可以在训练脚本中访问。
+
+例如,您可以在配置文件中指定批量大小。
+
+```python
+BATCH_SIZE = 32
+```
+
+启动后,您能够通过全局并行上下文访问您的超参数。
+
+```python
+import colossalai
+from colossalai.core import global_context as gpc
+
+colossalai.launch(config='./config.py', ...)
+
+# access your parameter
+print(gpc.config.BATCH_SIZE)
+
+```
diff --git a/docs/source/zh-Hans/basics/engine_trainer.md b/docs/source/zh-Hans/basics/engine_trainer.md
new file mode 100644
index 000000000000..a7519bfca14f
--- /dev/null
+++ b/docs/source/zh-Hans/basics/engine_trainer.md
@@ -0,0 +1,384 @@
+# 如何在训练中使用 Engine 和 Trainer
+
+作者: Shenggui Li, Siqi Mai
+
+**预备知识:**
+- [初始化功能](./initialize_features.md)
+
+## 简介
+
+在本教程中,您将学习如何使用 Colossal-AI 中提供的 Engine 和 Trainer 来训练您的模型。在深入研究细节之前,我们想先解释一下 Engine 和 Trainer 的概念。
+
+### Engine
+
+Engine 本质上是一个模型、优化器和损失函数的封装类。当我们调用 `colossalai.initialize` 时,一个 Engine 对象将被返回,并且配备了在您的配置文件中指定的梯度剪裁、梯度累计和 ZeRO 优化器等功能。
+
+Engine 将使用与 PyTorch 训练组件类似的 API,因此您只需对代码进行微小的修改即可。
+
+下表展示了Engine的常用API。
+
+| 组件 | 功能 | PyTorch | Colossal-AI |
+| ------------------------------------- | --------------------------------------------- | ------------------------------- | -------------------------------------- |
+| optimizer | 迭代前将所有梯度设置为零 | optimizer.zero_grad() | engine.zero_grad() |
+| optimizer | 更新参数 | optimizer.step() | engine.step() |
+| model | 进行一次前向计算 | outputs = model(inputs) | outputs = engine(inputs) |
+| criterion | 计算loss值 | loss = criterion(output, label) | loss = engine.criterion(output, label) |
+| criterion | 反向计算 | loss.backward() | engine.backward(loss) |
+
+我们需要这样一个 Engine 类的原因是,我们可以添加更多的功能,同时将实现隐藏在
+`colossalai.initialize` 函数中实现。
+假如我们要添加一个新的功能,我们可以在 `colossalai.initialize` 函数中完成对于模型、优化器、数据加载器和损失函数的功能诠释。不管中间的过程有多复杂,最终我们呈现的以及用户需要使用的只有一个 Engine 类,这将十分便捷。
+用户只需要在最小范围内修改他们的代码,将普通的 PyTorch APIs 调整为 Colossal-AI
+Engine 的 API。通过这种方式,他们可以享受更多的功能来进行有效的训练。
+
+以下是一个简单的例子:
+
+```python
+import colossalai
+
+# build your model, optimizer, criterion, dataloaders
+...
+
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader)
+for img, label in train_dataloader:
+ engine.zero_grad()
+ output = engine(img)
+ loss = engine.criterion(output, label)
+ engine.backward(loss)
+ engine.step()
+```
+
+### Trainer
+
+Trainer 是一个更高级的封装器,用户可以用更少的代码行来执行训练。 由于 Trainer 的使用会更加简单,相较于 Engine,它会缺少一点灵活性。 Trainer 被设计为进行前向和反向计算来进行模型权重的更新。通过传递 Engine 对象,我们可以很容易地创建一个 Trainer。
+Trainer 的参数 `schedule` 默认值是 `None` 。在大多数情况下,除非我们想使用流水线并行,否则我们把这个值设为 `None`。如果您想探索更多关于这个参数的内容,您可以前往流水线并行的相关教程。
+
+```python
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import Trainer, hooks
+
+# build components and initialize with colossalai.initialize
+...
+
+# create a logger so that trainer can log on the console
+logger = get_dist_logger()
+
+# create a trainer object
+trainer = Trainer(
+ engine=engine,
+ logger=logger
+)
+```
+
+在 Trainer 中,用户可以定制一些 hooks,并将这些 hooks 附加到 Trainer 上。hook 将根据训练方案定期地执行生命周期函数。例如,基于用户是想在每次训练迭代后还是只在整个训练周期后更新学习率,
+`LRSchedulerHook` 将会在 `after_train_iter` 或 `after_train_epoch` 阶段执行 `lr_scheduler.step()` 去为用户更新学习率。您可以将 hook 存储在一个列表中并将其传递给 `trainer.fit` 方法。`trainer.fit` 方法将根据您的参数执行训练和测试。如果 `display_process` 为 True,将在您的控制台显示一个进度条,以显示训练的过程。
+
+
+```python
+# define the hooks to attach to the trainer
+hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.AccuracyHook(accuracy_func=Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+]
+
+# start training
+trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True
+)
+```
+
+如果您想定制您的 hook 类,您可以继承 `hooks.BaseHook` 并重写您想要的生命周期方法。下面提供了一个例子来演示如何创建一个简单的关于日志信息的 hook,以供您参考。
+
+```python
+from colossalai.logging import get_dist_logger
+from colossalai.trainer import hooks
+
+class LogMessageHook(hooks.BaseHook):
+
+ def __init__(self, priority=10):
+ self._logger = get_dist_logger()
+
+ def before_train(self, trainer):
+ self._logger.info('training starts')
+
+ def after_train(self, trainer):
+ self._logger.info('training finished')
+
+
+...
+
+# then in your training script
+hook_list.append(LogMessageHook())
+```
+
+
+
+在下面的章节中,您将会详细地了解到如何用 Engine 和 Trainer 来训练 ResNet 模型。
+
+
+## ResNet
+
+### 总览
+
+在本节中,我们将介绍:
+
+1. 使用一个 Engine 在 CIFAR10 数据集上训练 ResNet34 模型
+2. 使用一个 Trainer 在 CIFAR10 数据集上训练 ResNet34 模型
+
+项目结构如下:
+
+```bash
+-- config.py
+-- run_resnet_cifar10_with_engine.py
+-- run_resnet_cifar10_with_trainer.py
+```
+
+对于使用 Engine 或 Trainer,步骤 1-4 是通用的。 因此,步骤 1-4 + 步骤 5 将会是对应 `run_resnet_cifar10_with_engine.py` 而 步骤 1-4 + 步骤6 则对应 `run_resnet_cifar10_with_trainer.py`。
+
+### 牛刀小试
+
+#### 步骤 1. 创建配置文件
+
+在你的项目文件夹中,创建一个 `config.py`。这个文件是用来指定一些您可能想用来训练您的模型的特征。下面是一个配置文件的例子。
+
+```python
+from colossalai.amp import AMP_TYPE
+
+BATCH_SIZE = 128
+NUM_EPOCHS = 200
+
+fp16=dict(
+ mode=AMP_TYPE.TORCH
+)
+```
+
+在这个配置文件中,我们指定要在每个 GPU 上使用批大小为128,并运行200个 epoch。这两个参数是在 `gpc.config` 中体现的。例如,您可以使用 `gpc.config.BATCH_SIZE` 来访问您存储在配置文件中的批大小值。而 `fp16` 配置则会告诉 `colossalai.initialize` 使用 PyTorch 提供的混合精度训练,以更好的速度和更低的内存消耗来训练模型。
+
+#### 步骤 2. 初始化分布式环境
+
+我们需要初始化分布式训练环境。这在 [启动 Colossal-AI](./launch_colossalai.md) 中有相应的教程。在当前的演示中,我们使用 `launch_from_torch` 和 PyTorch 启用工具。
+
+```python
+import colossalai
+
+# ./config.py refers to the config file we just created in step 1
+colossalai.launch_from_torch(config='./config.py')
+```
+
+#### 步骤 3. 创建所有的训练组件
+
+这时,我们可以创建用于训练的所有组件,包括:
+
+1. 模型
+2. 优化器
+3. 损失函数
+4. 训练/测试数据加载器
+5. 学习率调度器
+6. 日志记录器
+
+
+
+为了构建这些组件,您需要导入以下模块。
+
+```python
+from pathlib import Path
+from colossalai.logging import get_dist_logger
+import torch
+import os
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_dataloader
+from torchvision import transforms
+from colossalai.nn.lr_scheduler import CosineAnnealingLR
+from torchvision.datasets import CIFAR10
+from torchvision.models import resnet34
+```
+
+
+
+然后按照通常在PyTorch脚本中构建组件的方式来构建组件。在下面的脚本中,我们将CIFAR10数据集的根路径设置为环境变量 `DATA`。您可以把它改为您想要的任何路径,例如,您可以把 `root=Path(os.environ['DATA'])` 改为 `root='./data'` ,这样就不需要设置环境变量。
+
+```python
+# build logger
+logger = get_dist_logger()
+
+# build resnet
+model = resnet34(num_classes=10)
+
+# build datasets
+train_dataset = CIFAR10(
+ root='./data',
+ download=True,
+ transform=transforms.Compose(
+ [
+ transforms.RandomCrop(size=32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+)
+
+test_dataset = CIFAR10(
+ root='./data',
+ train=False,
+ transform=transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[
+ 0.2023, 0.1994, 0.2010]),
+ ]
+ )
+)
+
+# build dataloaders
+train_dataloader = get_dataloader(dataset=train_dataset,
+ shuffle=True,
+ batch_size=gpc.config.BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+test_dataloader = get_dataloader(dataset=test_dataset,
+ add_sampler=False,
+ batch_size=gpc.config.BATCH_SIZE,
+ num_workers=1,
+ pin_memory=True,
+ )
+
+# build criterion
+criterion = torch.nn.CrossEntropyLoss()
+
+# optimizer
+optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+
+# lr_scheduler
+lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
+```
+
+#### 步骤 4. 用 Colossal-AI 进行初始化
+
+接下来,重要的一步是通过调用 `colossalai.initialize` 获得 Engine。正如 `config.py` 中所述,我们将使用混合精度训练来训练 ResNet34 模型。`colossalai.initialize` 将自动检查您的配置文件,并将相关特征分配给您的训练组件。这样一来,我们的 Engine 已经能够进行混合精度训练,而您不需要进行额外的处理。
+
+```python
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader,
+ )
+```
+
+
+
+#### 步骤 5. 用 Engine 进行训练
+
+当所有的训练组件都准备好后,我们就可以像使用 PyTorch 一样训练 ResNet34 了。
+
+```python
+for epoch in range(gpc.config.NUM_EPOCHS):
+ # execute a training iteration
+ engine.train()
+ for img, label in train_dataloader:
+ img = img.cuda()
+ label = label.cuda()
+
+ # set gradients to zero
+ engine.zero_grad()
+
+ # run forward pass
+ output = engine(img)
+
+ # compute loss value and run backward pass
+ train_loss = engine.criterion(output, label)
+ engine.backward(train_loss)
+
+ # update parameters
+ engine.step()
+
+ # update learning rate
+ lr_scheduler.step()
+
+ # execute a testing iteration
+ engine.eval()
+ correct = 0
+ total = 0
+ for img, label in test_dataloader:
+ img = img.cuda()
+ label = label.cuda()
+
+ # run prediction without back-propagation
+ with torch.no_grad():
+ output = engine(img)
+ test_loss = engine.criterion(output, label)
+
+ # compute the number of correct prediction
+ pred = torch.argmax(output, dim=-1)
+ correct += torch.sum(pred == label)
+ total += img.size(0)
+
+ logger.info(
+ f"Epoch {epoch} - train loss: {train_loss:.5}, test loss: {test_loss:.5}, acc: {correct / total:.5}, lr: {lr_scheduler.get_last_lr()[0]:.5g}", ranks=[0])
+```
+
+#### 步骤 6. 用 Trainer 进行训练
+
+如果您想用 Trainer 进行训练,您可以参考下面的代码进行您的实验。
+
+
+```python
+from colossalai.nn.metric import Accuracy
+from colossalai.trainer import Trainer, hooks
+
+
+# create a trainer object
+trainer = Trainer(
+ engine=engine,
+ logger=logger
+)
+
+# define the hooks to attach to the trainer
+hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.AccuracyHook(accuracy_func=Accuracy()),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.LogMemoryByEpochHook(logger)
+]
+
+# start training
+# run testing every 1 epoch
+trainer.fit(
+ train_dataloader=train_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ test_dataloader=test_dataloader,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True
+)
+```
+
+
+
+#### 步骤 7. 开始分布式训练
+
+最后,我们可以使用 PyTorch 提供的分布式启动器来调用脚本,因为我们在步骤2中使用了 `launch_from_torch`。您需要把`` 替换成您机器上可用的GPU数量。如果您只想使用一个 GPU,您可以把这个数字设为1。如果您想使用其他的启动器,请您参考如何启动 Colossal-AI 的教程。
+
+
+```bash
+# with engine
+python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_engine.py
+# with trainer
+python -m torch.distributed.launch --nproc_per_node --master_addr localhost --master_port 29500 run_resnet_cifar10_with_trainer.py
+```
diff --git a/docs/source/zh-Hans/basics/initialize_features.md b/docs/source/zh-Hans/basics/initialize_features.md
new file mode 100644
index 000000000000..67ea114b42b2
--- /dev/null
+++ b/docs/source/zh-Hans/basics/initialize_features.md
@@ -0,0 +1,46 @@
+# 初始化功能
+
+作者: Shenggui Li, Siqi Mai
+
+**预备知识:**
+- [分布式训练](../concepts/distributed_training.md)
+- [Colossal-AI 总览](../concepts/colossalai_overview.md)
+
+## 简介
+
+在本教程中,我们将介绍 `colossalai.initialize` 的使用。 它包含了如何将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 调用 `colossalai.initialize` 是您进入训练循环前的基本操作。
+
+在下面一节中,我们将介绍 `colossalai.initialize` 是如何工作的以及使用中我们要注意的细节。
+
+## 使用
+
+在一个典型的工作流程中,我们将在训练脚本的开始启动分布式环境。
+之后,我们将实例化我们的对象,如模型、优化器、损失函数、数据加载器等。此时,我们可以使用 `colossalai.initialize` 便捷地为这些对象注入特征。
+具体细节请看以下的伪代码例子。
+
+```python
+import colossalai
+import torch
+...
+
+
+# launch distributed environment
+colossalai.launch(config='./config.py', ...)
+
+# create your objects
+model = MyModel()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+criterion = torch.nn.CrossEntropyLoss()
+train_dataloader = MyTrainDataloader()
+test_dataloader = MyTrainDataloader()
+
+# initialize features
+engine, train_dataloader, test_dataloader, _ = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader,
+ test_dataloader)
+```
+
+ `colossalai.initialize` 将返回一个 `Engine` 对象。 该对象把模型、优化器和损失函数封装起来。 **`Engine` 对象会以配置文件中指定的特征运行。**
+关于 `Engine` 的更多使用细节可以在 [在训练中使用Engine和Trainer](./engine_trainer.md) 中获取。
diff --git a/docs/source/zh-Hans/basics/launch_colossalai.md b/docs/source/zh-Hans/basics/launch_colossalai.md
new file mode 100644
index 000000000000..ca927de578d5
--- /dev/null
+++ b/docs/source/zh-Hans/basics/launch_colossalai.md
@@ -0,0 +1,212 @@
+# 启动 Colossal-AI
+
+作者: Chuanrui Wang, Shenggui Li, Siqi Mai
+
+**预备知识:**
+- [分布式训练](../concepts/distributed_training.md)
+- [Colossal-AI 总览](../concepts/colossalai_overview.md)
+
+
+## 简介
+
+正如我们在前面的教程中所提到的,在您的配置文件准备好后,您需要为 Colossal-AI 初始化分布式环境。我们把这个过程称为 `launch`。在本教程中,您将学习如何在您的服务器上启动 Colossal-AI,不管是小型的还是大型的。
+
+在 Colossal-AI 中,我们提供了几种启动方法来初始化分布式后端。
+在大多数情况下,您可以使用 `colossalai.launch` 和 `colossalai.get_default_parser` 来通过命令行传递参数。如果您想使用 SLURM、OpenMPI 和 PyTorch 等启动工具,我们也提供了几个启动的辅助方法以便您的使用。您可以直接从这些启动工具设置的环境变量中访问 rank 和 world size 大小。
+
+在本教程中,我们将介绍如何启动 Colossal-AI 来初始化分布式后端:
+- 用 colossalai.launch 启动
+- 用 Colossal-AI命令行 启动
+- 用 SLURM 启动
+- 用 OpenMPI 启动
+
+## 启动分布式环境
+
+为了启动 Colossal-AI,我们需要两类参数:
+1. 配置文件
+2. 分布式设置
+
+无论我们使用何种启动方式,配置文件是必须要求的,而分布式设置有可能依情况而定。配置文件可以是配置文件的路径或 Python dictionary 的形式。分布式设置可以通过命令行或多进程启动器传递。
+
+### 命令行解析器
+
+在使用 `launch` 之前, 我们首先需要了解我们需要哪些参数来进行初始化。
+如[分布式训练](../concepts/distributed_training.md) 中 `基本概念` 一节所述 ,涉及的重要参数是:
+
+1. host
+2. port
+3. rank
+4. world_size
+5. backend
+
+在 Colossal-AI 中,我们提供了一个命令行解析器,它已经提前添加了这些参数。您可以通过调用 `colossalai.get_default_parser()` 来获得这个解析器。这个解析器通常与 `colossalai.launch` 一起使用。
+
+```python
+# add these lines in your train.py
+import colossalai
+
+# get default parser
+parser = colossalai.get_default_parser()
+
+# if you want to add your own arguments
+parser.add_argument(...)
+
+# parse arguments
+args = parser.parse_args()
+```
+
+您可以在您的终端传入以下这些参数。
+```shell
+
+python train.py --host --rank --world_size --port --backend
+```
+
+`backend` 是用户可选的,默认值是 nccl。
+
+### 本地启动
+
+为了初始化分布式环境,我们提供了一个通用的 `colossalai.launch` API。`colossalai.launch` 函数接收上面列出的参数,并在通信网络中创建一个默认的进程组。方便起见,这个函数通常与默认解析器一起使用。
+
+```python
+import colossalai
+
+# parse arguments
+args = colossalai.get_default_parser().parse_args()
+
+# launch distributed environment
+colossalai.launch(config=,
+ rank=args.rank,
+ world_size=args.world_size,
+ host=args.host,
+ port=args.port,
+ backend=args.backend
+)
+
+```
+
+
+### 用 Colossal-AI命令行工具 启动
+
+为了更好地支持单节点以及多节点的训练,我们通过封装PyTorch的启动器实现了一个更加方便的启动器。
+PyTorch自带的启动器需要在每个节点上都启动命令才能启动多节点训练,而我们的启动器只需要一次调用即可启动训练。
+
+首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。
+分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。
+
+```python
+import colossalai
+
+colossalai.launch_from_torch(
+ config=,
+)
+```
+
+接下来,我们可以轻松地在终端使用`colossalai run`来启动训练。下面的命令可以在当前机器上启动一个4卡的训练任务。
+你可以通过设置`nproc_per_node`来调整使用的GPU的数量,也可以改变`master_port`的参数来选择通信的端口。
+
+```shell
+# 在当前节点上启动4卡训练 (默认使用29500端口)
+colossalai run --nproc_per_node 4 train.py
+
+# 在当前节点上启动4卡训练,并使用一个不同的端口
+colossalai run --nproc_per_node 4 --master_port 29505 test.py
+```
+
+如果你在使用一个集群,并且想进行多节点的训练,你需要使用Colossal-AI的命令行工具进行一键启动。我们提供了两种方式来启动多节点任务
+
+- 通过`--hosts`来启动
+
+这个方式适合节点数不多的情况。假设我们有两个节点,分别为`host`和`host2`。我们可以用以下命令进行多节点训练。
+比起单节点训练,多节点训练需要手动设置`--master_addr` (在单节点训练中`master_addr`默认为`127.0.0.1`)。
+
+:::caution
+
+多节点训练时,`master_addr`不能为`localhost`或者`127.0.0.1`,它应该是一个节点的名字或者IP地址。
+
+:::
+
+```shell
+# 在两个节点上训练
+colossalai run --nproc_per_node 4 --host host1,host2 --master_addr host1 test.py
+```
+
+
+- 通过`--hostfile`来启动
+
+这个方式适用于节点数很大的情况。host file是一个简单的文本文件,这个文件里列出了可以使用的节点的名字。
+在一个集群中,可用节点的列表一般由SLURM或者PBS Pro这样的集群资源管理器来提供。比如,在SLURM中,
+你可以从`SLURM_NODELIST`这个环境变量中获取到当前分配列表。在PBS Pro中,这个环境变量为`PBS_NODEFILE`。
+可以通过`echo $SLURM_NODELIST` 或者 `cat $PBS_NODEFILE` 来尝试一下。如果你没有这样的集群管理器,
+那么你可以自己手动写一个这样的文本文件即可。
+
+提供给Colossal-AI的host file需要遵循以下格式,每一行都是一个节点的名字。
+
+```text
+host1
+host2
+```
+
+如果host file准备好了,那么我们就可以用以下命令开始多节点训练了。和使用`--host`一样,你也需要指定一个`master_addr`。
+当使用host file时,我们可以使用一些额外的参数:
+- `--include`: 设置你想要启动训练的节点。比如,你的host file里有8个节点,但是你只想用其中的6个节点进行训练,
+ 你可以添加`--include host1,host2,host3,...,host6`,这样训练任务只会在这6个节点上启动。
+
+- `--exclude`: 设置你想排除在训练之外的节点。当你的某一些节点坏掉时,这个参数会比较有用。比如假如host1的GPU有一些问题,无法正常使用,
+ 那么你就可以使用`--exclude host1`来将其排除在外,这样你就可以训练任务就只会在剩余的节点上启动。
+
+```shell
+# 使用hostfile启动
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 test.py
+
+# 只使用部分节点进行训练
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --include host1 test.py
+
+# 不使用某些节点进行训练
+colossalai run --nproc_per_node 4 --hostfile ./hostfile --master_addr host1 --exclude host2 test.py
+```
+
+
+### 用 SLURM 启动
+
+如果您是在一个由 SLURM 调度器管理的系统上, 您也可以使用 `srun` 启动器来启动您的 Colossal-AI 脚本。我们提供了辅助函数 `launch_from_slurm` 来与 SLURM 调度器兼容。
+`launch_from_slurm` 会自动从环境变量 `SLURM_PROCID` 和 `SLURM_NPROCS` 中分别读取 rank 和 world size ,并使用它们来启动分布式后端。
+
+您可以在您的训练脚本中尝试以下操作。
+
+```python
+import colossalai
+
+colossalai.launch_from_slurm(
+ config=,
+ host=args.host,
+ port=args.port
+)
+```
+
+您可以通过在终端使用这个命令来初始化分布式环境。
+
+```bash
+srun python train.py --host --port 29500
+```
+
+### 用 OpenMPI 启动
+如果您对OpenMPI比较熟悉,您也可以使用 `launch_from_openmpi` 。
+`launch_from_openmpi` 会自动从环境变量
+`OMPI_COMM_WORLD_LOCAL_RANK`, `MPI_COMM_WORLD_RANK` 和 `OMPI_COMM_WORLD_SIZE` 中分别读取local rank、global rank 和 world size,并利用它们来启动分布式后端。
+
+您可以在您的训练脚本中尝试以下操作。
+```python
+colossalai.launch_from_openmpi(
+ config=,
+ host=args.host,
+ port=args.port
+)
+```
+
+以下是用 OpenMPI 启动多个进程的示例命令。
+```bash
+mpirun --hostfile -np python train.py --host --port 29500
+```
+
+- --hostfile: 指定一个要运行的主机列表。
+- --np: 设置总共要启动的进程(GPU)的数量。例如,如果 --np 4,4个 python 进程将被初始化以运行 train.py。
diff --git a/docs/source/zh-Hans/basics/model_checkpoint.md b/docs/source/zh-Hans/basics/model_checkpoint.md
new file mode 100644
index 000000000000..cec12d451989
--- /dev/null
+++ b/docs/source/zh-Hans/basics/model_checkpoint.md
@@ -0,0 +1,61 @@
+# 模型检查点
+
+作者 : Guangyang Lu
+
+**预备知识:**
+- [Launch Colossal-AI](./launch_colossalai.md)
+- [Initialize Colossal-AI](./initialize_features.md)
+
+**示例代码:**
+- [ColossalAI-Examples Model Checkpoint](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/utils/checkpoint)
+
+**函数是经验函数.**
+
+## 简介
+
+本教程将介绍如何保存和加载模型检查点。
+
+为了充分利用Colossal-AI的强大并行策略,我们需要修改模型和张量,可以直接使用 `torch.save` 或者 `torch.load` 保存或加载模型检查点。在Colossal-AI中,我们提供了应用程序接口实现上述同样的效果。
+
+但是,在加载时,你不需要使用与存储相同的保存策略。
+
+## 使用方法
+
+### 保存
+
+有两种方法可以使用Colossal-AI训练模型,即使用engine或使用trainer。
+**注意我们只保存 `state_dict`.** 因此,在加载检查点时,需要首先定义模型。
+
+#### 同 engine 保存
+
+```python
+from colossalai.utils import save_checkpoint
+model = ...
+engine, _, _, _ = colossalai.initialize(model=model, ...)
+for epoch in range(num_epochs):
+ ... # do some training
+ save_checkpoint('xxx.pt', epoch, model)
+```
+
+#### 用 trainer 保存
+```python
+from colossalai.trainer import Trainer, hooks
+model = ...
+engine, _, _, _ = colossalai.initialize(model=model, ...)
+trainer = Trainer(engine, ...)
+hook_list = [
+ hooks.SaveCheckpointHook(1, 'xxx.pt', model)
+ ...]
+
+trainer.fit(...
+ hook=hook_list)
+```
+
+### 加载
+
+```python
+from colossalai.utils import load_checkpoint
+model = ...
+load_checkpoint('xxx.pt', model)
+... # train or test
+```
diff --git a/docs/source/zh-Hans/concepts/colossalai_overview.md b/docs/source/zh-Hans/concepts/colossalai_overview.md
new file mode 100755
index 000000000000..cfb35e59e64a
--- /dev/null
+++ b/docs/source/zh-Hans/concepts/colossalai_overview.md
@@ -0,0 +1,36 @@
+# Colossal-AI 总览
+
+作者: Shenggui Li, Siqi Mai
+
+## 关于 Colossal-AI
+
+随着深度学习模型规模的发展,向新的训练模式转变是非常重要的。没有并行和优化的传统训练方法将成为过去,新的训练方法是使训练大规模模型高效和节省成本的关键。
+
+Colossal-AI 是一个集成的系统,为用户提供一套综合的训练方法。您可以找到常见的训练方法,如混合精度训练和梯度累积。此外,我们提供了一系列的并行技术,包括数据并行、张量并行和流水线并行。我们通过不同的多维分布式矩阵乘法算法来优化张量并行。我们还提供了不同的流水线并行方法,使用户能够有效地跨节点扩展他们的模型。更多的高级功能,如卸载,也可以在这个教程文档中找到详细的内容。
+
+## Colossal-AI 的使用
+
+我们的目标是使 Colossal-AI 易于使用,并且对用户的代码不产生干扰。如果您想使用Colossal-AI,这里有一个简单的一般工作流程。
+
+
+
+Workflow
+
+
+1. 准备一个配置文件,指定您要使用的功能和参数。
+2. 用 `colossalai.launch` 初始化分布式后端。
+3. 用 `colossalai.initialize` 将训练特征注入您的训练组件(如模型、优化器)中。
+4. 进行训练和测试.
+
+我们将在`基本教程`部分介绍整个工作流程。
+
+## 未来计划
+
+Colossal-AI 系统将会进一步拓展和优化,包括但不限于:
+
+1. 分布式操作的优化
+2. 异构系统训练的优化
+3. 从模型大小的维度切入,提升训练速度并维持精度
+4. 拓展现有的并行方法
+
+**我们始终欢迎社区的建议和讨论,如果您遇到任何问题,我们将非常愿意帮助您。您可以在GitHub 提 [issue](https://github.com/hpcaitech/ColossalAI/issues) ,或在[论坛](https://github.com/hpcaitech/ColossalAI/discussions)上创建一个讨论主题。**
diff --git a/docs/source/zh-Hans/concepts/distributed_training.md b/docs/source/zh-Hans/concepts/distributed_training.md
new file mode 100755
index 000000000000..97b3844daa16
--- /dev/null
+++ b/docs/source/zh-Hans/concepts/distributed_training.md
@@ -0,0 +1,88 @@
+# 分布式训练
+
+作者: Shenggui Li, Siqi Mai
+
+## 什么是分布式系统?
+
+
+
+图片来源: Towards Data Science
+
+
+分布式系统由多个软件组件组成,在多台机器上运行。例如,传统的数据库运行在一台机器上。随着数据量的爆发式增长,单台机器已经不能为企业提供理想的性能。特别是在双十一这样的网络狂欢节,网络流量会出乎意料的大。为了应对这种压力,现代高性能数据库被设计成在多台机器上运行,它们共同为用户提供高吞吐量和低延迟。
+
+分布式系统的一个重要评价指标是可扩展性。例如,当我们在4台机器上运行一个应用程序时,我们自然希望该应用程序的运行速度能提高4倍。然而,由于通信开销和硬件性能的差异,很难实现线性提速。因此,当我们实现应用程序时,必须考虑如何使其更快。良好的设计和系统优化的算法可以帮助我们提供良好的性能。有时,甚至有可能实现线性和超线性提速。
+
+
+## 为什么我们需要机器学习的分布式训练?
+
+早在2012年,[AlexNet](https://arxiv.org/abs/1404.5997) 就赢得了ImageNet比赛的冠军,而它是在两张 GTX 580 3GB GPU 上训练的。今天,大多数出现在顶级人工智能会议上的模型都是在多个GPU上训练的。当研究人员和工程师开发人工智能模型时,分布式训练无疑是一种常见的做法。这一趋势背后有几个原因。
+
+1. 模型规模迅速增加。2015年的 [ResNet50](https://arxiv.org/abs/1512.03385) 有2000万的参数,
+2018年的 [BERT-Large](https://arxiv.org/abs/1810.04805)有3.45亿的参数,2018年的
+[GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
+有15亿的参数,而2020年的 [GPT-3](https://arxiv.org/abs/2005.14165) 有1750亿个参数。很明显,模型规模随着时间的推移呈指数级增长。目前最大的模型已经超过了1000多亿个参数。而与较小的模型相比,超大型模型通常能提供更优越的性能。
+
+
+图片来源: HuggingFace
+
+
+
+2. 数据集规模迅速增加。对于大多数机器学习开发者来说,MNIST 和 CIFAR10 数据集往往是他们训练模型的前几个数据集。然而,与著名的 ImageNet 数据集相比,这些数据集非常小。谷歌甚至有自己的(未公布的)JFT-300M 数据集,它有大约3亿张图片,这比 ImageNet-1k 数据集大了近300倍。
+
+
+3. 计算能力越来越强。随着半导体行业的进步,显卡变得越来越强大。由于核的数量增多,GPU是深度学习最常见的算力资源。从2012年的 K10 GPU 到2020年的 A100 GPU,计算能力已经增加了几百倍。这使我们能够更快地执行计算密集型任务,而深度学习正是这样一项任务。
+
+如今,我们接触到的模型可能太大,以致于无法装入一个GPU,而数据集也可能大到足以在一个GPU上训练一百天。这时,只有用不同的并行化技术在多个GPU上训练我们的模型,我们才能完成并加快模型训练,以追求在合理的时间内获得想要的结果。
+
+
+## 分布式训练的基本概念
+
+分布式训练需要多台机器/GPU。在训练期间,这些设备之间会有通信。为了更好地理解分布式训练,有几个重要的术语需要我们了解清楚。
+
+- host: 主机(host)是通信网络中的主要设备。在初始化分布式环境时,经常需要它作为一个参数。
+- port: 这里的端口(port)主要是指主机上用于通信的主端口。
+- rank: 在网络中赋予设备的唯一ID。
+- world size: 网络中设备的数量。
+- process group: 进程组(process group)是一个通信网络,包括设备的一个子集。总是有一个默认的进程组,它包含所有的设备。一个子集的设备可以形成一个进程组,以便它们只在组内的设备之间进行通信。
+
+
+
+一个分布式系统的例子
+
+
+为了说明这些概念,让我们假设我们有2台机器(也称为节点),每台机器有4个 GPU。当我们在这两台机器上初始化分布式环境时,我们基本上启动了8个进程(每台机器上有4个进程),每个进程被绑定到一个 GPU 上。
+
+在初始化分布式环境之前,我们需要指定主机(主地址)和端口(主端口)。在这个例子中,我们可以让主机为节点0,端口为一个数字,如29500。所有的8个进程将寻找地址和端口并相互连接,默认的进程组将被创建。默认进程组的 world size 为8,细节如下。
+
+| process ID | rank | Node index | GPU index |
+| ---------- | ---- | ---------- | --------- |
+| 0 | 0 | 0 | 0 |
+| 1 | 1 | 0 | 1 |
+| 2 | 2 | 0 | 2 |
+| 3 | 3 | 0 | 3 |
+| 4 | 4 | 1 | 0 |
+| 5 | 5 | 1 | 1 |
+| 6 | 6 | 1 | 2 |
+| 7 | 7 | 1 | 3 |
+
+
+我们还可以创建一个新的进程组。这个新的进程组可以包含任何进程的子集。例如,我们可以创建一个只包含偶数进程的组:
+
+| process ID | rank | Node index | GPU index |
+| ---------- | ---- | ---------- | --------- |
+| 0 | 0 | 0 | 0 |
+| 2 | 1 | 0 | 2 |
+| 4 | 2 | 1 | 0 |
+| 6 | 3 | 1 | 2 |
+
+**请注意,rank 是相对于进程组而言的,一个进程在不同的进程组中可以有不同的 rank。最大的 rank 始终是 `world size of the process group - 1`。**
+
+在进程组中,各进程可以通过两种方式进行通信。
+1. peer-to-peer: 一个进程向另一个进程发送数据。
+2. collective: 一组进程一起执行分散、聚集、all-reduce、广播等操作。
+
+
+
+Collective communication, 来源: PyTorch distributed tutorial
+
diff --git a/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md b/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md
new file mode 100755
index 000000000000..8f52d28ecdf4
--- /dev/null
+++ b/docs/source/zh-Hans/concepts/paradigms_of_parallelism.md
@@ -0,0 +1,92 @@
+# 并行技术
+
+作者: Shenggui Li, Siqi Mai
+
+## 简介
+
+随着深度学习的发展,对并行训练的需求越来越大。这是因为模型和数据集越来越大,如果我们坚持使用单 GPU 训练,训练过程的等待将会成为一场噩梦。在本节中,我们将对现有的并行训练方法进行简要介绍。如果您想对这篇文章进行补充,欢迎在[GitHub论坛](https://github.com/hpcaitech/ColossalAI/discussions)上进行讨论。
+
+## 数据并行
+
+数据并行是最常见的并行形式,因为它很简单。在数据并行训练中,数据集被分割成几个碎片,每个碎片被分配到一个设备上。这相当于沿批次维度对训练过程进行并行化。每个设备将持有一个完整的模型副本,并在分配的数据集碎片上进行训练。在反向传播之后,模型的梯度将被全部减少,以便在不同设备上的模型参数能够保持同步。
+
+
+
+数据并行
+
+
+## 模型并行
+
+在数据并行训练中,一个明显的特点是每个 GPU 持有整个模型权重的副本。这就带来了冗余问题。另一种并行模式是模型并行,即模型被分割并分布在一个设备阵列上。通常有两种类型的并行:张量并行和流水线并行。张量并行是在一个操作中进行并行计算,如矩阵-矩阵乘法。流水线并行是在各层之间进行并行计算。因此,从另一个角度来看,张量并行可以被看作是层内并行,流水线并行可以被看作是层间并行。
+
+### 张量并行
+
+张量并行训练是将一个张量沿特定维度分成 `N` 块,每个设备只持有整个张量的 `1/N`,同时不影响计算图的正确性。这需要额外的通信来确保结果的正确性。
+
+以一般的矩阵乘法为例,假设我们有 `C = AB`。我们可以将B沿着列分割成 `[B0 B1 B2 ... Bn]`,每个设备持有一列。然后我们将 `A` 与每个设备上 `B` 中的每一列相乘,我们将得到 `[AB0 AB1 AB2 ... ABn]` 。此刻,每个设备仍然持有一部分的结果,例如,设备(rank=0)持有 `AB0`。为了确保结果的正确性,我们需要收集全部的结果,并沿列维串联张量。通过这种方式,我们能够将张量分布在设备上,同时确保计算流程保持正确。
+
+
+
+张量并行
+
+
+在 Colossal-AI 中,我们提供了一系列的张量并行方法,即 1D、2D、2.5D 和 3D 张量并行。我们将在`高级教程`中详细讨论它们。
+
+
+相关文章:
+- [GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding](https://arxiv.org/abs/2006.16668)
+- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
+- [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
+- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
+
+### 流水线并行
+
+流水线并行一般来说很容易理解。请您回忆一下您的计算机结构课程,这确实存在于 CPU 设计中。
+
+
+
+流水线并行
+
+
+流水线并行的核心思想是,模型按层分割成若干块,每块都交给一个设备。在前向传递过程中,每个设备将中间的激活传递给下一个阶段。在后向传递过程中,每个设备将输入张量的梯度传回给前一个流水线阶段。这允许设备同时进行计算,并增加了训练的吞吐量。流水线并行训练的一个缺点是,会有一些设备参与计算的冒泡时间,导致计算资源的浪费。
+
+
+
+Source: GPipe
+
+
+相关文章:
+- [PipeDream: Fast and Efficient Pipeline Parallel DNN Training](https://arxiv.org/abs/1806.03377)
+- [GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism](https://arxiv.org/abs/1811.06965)
+- [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
+- [Chimera: Efficiently Training Large-Scale Neural Networks with Bidirectional Pipelines](https://arxiv.org/abs/2107.06925)
+
+
+## 优化器相关的并行
+
+另一种并行方法和优化器相关,目前这种并行最流行的方法是 `ZeRO`,即[零冗余优化器](https://arxiv.org/abs/1910.02054)。 ZeRO 在三个层面上工作,以消除内存冗余(ZeRO需要进行fp16训练)。
+
+- Level 1: 优化器状态在各进程中被划分。
+- Level 2: 用于更新模型权重的32位梯度也被划分,因此每个进程只存储与其优化器状态划分相对应的梯度。
+- Level 3: 16位模型参数在各进程中被划分。
+
+相关文章:
+- [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054)
+
+
+## 异构系统的并行
+
+上述方法通常需要大量的 GPU 来训练一个大型模型。然而,人们常常忽略的是,与 GPU 相比,CPU 的内存要大得多。在一个典型的服务器上,CPU 可以轻松拥有几百GB的内存,而每个 GPU 通常只有16或32GB的内存。这促使人们思考为什么 CPU 内存没有被用于分布式训练。
+
+最近的进展是依靠 CPU 甚至是 NVMe 磁盘来训练大型模型。主要的想法是,在不使用张量时,将其卸载回 CPU 内存或 NVMe 磁盘。通过使用异构系统架构,有可能在一台机器上容纳一个巨大的模型。
+
+
+
+异构系统
+
+
+相关文章:
+- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
+- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)
+- [PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management](https://arxiv.org/abs/2108.05818)
diff --git a/docs/source/zh-Hans/features/1D_tensor_parallel.md b/docs/source/zh-Hans/features/1D_tensor_parallel.md
new file mode 100644
index 000000000000..8f3a3c6209da
--- /dev/null
+++ b/docs/source/zh-Hans/features/1D_tensor_parallel.md
@@ -0,0 +1,111 @@
+# 1D 张量并行
+
+作者: Zhengda Bian, Yongbin Li
+
+**前置教程**
+- [定义配置文件](../basics/define_your_config.md)
+- [并行配置](../basics/configure_parallelization.md)
+
+**示例代码**
+- [ColossalAI-Examples 1D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_1d.py)
+
+**相关论文**
+- [Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf)
+
+## 引言
+
+张量并行将模型参数划分到多个设备上,以减少内存负荷。
+[Megatron-LM](https://deepakn94.github.io/assets/papers/megatron-sc21.pdf) 介绍了一种高效的一维张量并行化实现。
+
+让我们以一个线性层为例,它包括一个 GEMM $Y = XA$。 给定2个处理器,我们把列 $A$ 划分为 $[A_1 ~ A_2]$, 并在每个处理器上计算 $Y_i = XA_i$ , which then forms $[Y_1 ~ Y_2] = [XA_1 ~ XA_2]$. This is called a column-parallel fashion.
+
+当第二个线性层 $Z=YB$ 跟随上述列并行层的时候, 我们把 $B$ 划分为 $\left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]$,
+这就是所谓的行并行方式.
+为了计算 $Z = [Y_1 ~ Y_2] \left[\begin{matrix} B_1 \\ B_2 \end{matrix} \right]$, 我们首先在每个处理器上计算 $Y_iB_i$ 然后使用一个all-reduce操作将结果汇总为 $Z=Y_1B_1+Y_2B_2$。
+
+我们还需要注意,在后向计算中,列并行线性层需要聚合输入张量 $X$, 因为在每个处理器 $i$ 上,我们只有 $\dot{X_i}=\dot{Y_i}A_i^T$,因此,我们在各处理器之间进行all-reduce,得到 $\dot{X}=\dot{Y}A^T=\dot{Y_1}A_1^T+\dot{Y_2}A_2^T$。
+
+## 效率
+给定 $P$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的1D张量并行的前向和后向的通信成本。
+
+| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/P)$ | $O(1/P)$ | $O(1)$ | $O(2(P-1)/P)$ | $O(2(P-1))$ |
+
+## 使用
+
+为了使模型能够实现一维张量并行, 如在2个 GPU 上, 我们需要配置如下的并行设置。
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'),
+))
+```
+
+然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用1D张量并行。
+
+让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.transpose(0, 1).shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.transpose(0, 1).shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+
+在2个 GPU 上启动 Colossal-AI 并建立模型。
+
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+我们将会看到 MLP 模型中被划分的参数(如权重)的形状。
+```shell
+Weight of the first linear layer: torch.Size([256, 512])
+Weight of the second linear layer: torch.Size([512, 256])
+```
+第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过列-并行分割,它变成了 `[256, 512]`。
+同样地,第二个行并行层将权重 `[1024, 256]` 划分为 `[512, 256]`。
+
+我们可以用一些随机输入来运行这个模型。
+```python
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+torch.distributed.broadcast(x, src=0) # synchronize input
+
+x = m(x)
+```
+然后我们可以看到 activation 结果的形状。
+```shell
+Output of the first linear layer: torch.Size([16, 512])
+Output of the second linear layer: torch.Size([16, 256])
+```
+第一个线性层的输出被划分成2块 (每个形状为 `[16, 512]`), 而第二层在整个 GPU 上的输出是相同的。
diff --git a/docs/source/zh-Hans/features/2D_tensor_parallel.md b/docs/source/zh-Hans/features/2D_tensor_parallel.md
new file mode 100644
index 000000000000..c942f82bf9d2
--- /dev/null
+++ b/docs/source/zh-Hans/features/2D_tensor_parallel.md
@@ -0,0 +1,141 @@
+# 2D 张量并行
+
+作者: Zhengda Bian, Yongbin Li
+
+**前置教程**
+- [定义配置文件](../basics/define_your_config.md)
+- [并行配置](../basics/configure_parallelization.md)
+- [1D 张量并行](./1D_tensor_parallel.md)
+
+**示例代码**
+- [ColossalAI-Examples - 2D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2d.py)
+
+**相关论文**
+- [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/pdf/2104.05343.pdf)
+
+## 引言
+
+1D张量并行没有对 activations 进行划分,就大规模模型而言,这也会消耗大量的内存。
+为了平均分配计算和内存负荷,在 SUMMA(可扩展的通用矩阵乘法算法)的基础上, [2D张量并行](https://arxiv.org/pdf/2104.05343.pdf) 被引入。
+
+我们还是以线性层 $Y = XA$ 为例。
+给定 $P=q\times q$ 个处理器(必要条件), 如 $q=2$, 我们把输入 $X$ 和权重A $A$ 都划分为
+
+$$
+\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right]
+\text{~and~}
+\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right]。
+$$
+
+该计算包括 $q$ 步。 当 $t=1$ 时, $X_{i0}$ 在其行中被广播, 而 $A_{0j}$ 在其列中被广播。因此,我们有
+
+$$
+\left[\begin{matrix} X_{10},A_{00} & X_{10},A_{01} \\ X_{00},A_{00} & X_{00},A_{01} \end{matrix} \right]。
+$$
+
+然后我们在每个处理器 $(i, j)$ 上将 $X_{i0}$ 和 $A_{0j}$ 相乘为
+
+$$
+\left[\begin{matrix} X_{10}A_{00} & X_{10}A_{01} \\ X_{00}A_{00} & X_{00}A_{01} \end{matrix} \right] (1)。
+$$
+
+同样,当 $t=2$ 时, $X_{i1}$ 在其行中被广播, $A_{1j}$ 在其列中被广播, 我们将它们相乘为
+
+$$
+\left[\begin{matrix} X_{11}A_{10} & X_{11}A_{11} \\ X_{01}A_{10} & X_{01}A_{11} \end{matrix} \right] (2)。
+$$
+
+通过将 $(1)$ 和 $(2)$ 相加,我们有
+
+$$
+Y = XA = \left[\begin{matrix} X_{10}A_{00}+X_{11}A_{10} & X_{10}A_{01}+X_{11}A_{11} \\ X_{00}A_{00}+X_{01}A_{10} & X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]。
+$$
+
+## 效率
+给定 $P=q\times q$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的2D张量并行的前向和后向的通信成本。
+
+| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/q^2)$ | $O(1/q^2)$ | $O(1/q^2)$ | $O(6(q-1)/q)$ | $O(6(q-1))$ |
+
+## 使用
+
+为了使我们的模型能够实现二维张量并行,例如在4个 GPU 上,我们需要配置如下的并行设置。
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=4, mode='2d'),
+))
+```
+然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2D张量并行。
+
+让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+在4个 GPU 上启动 Colossal-AI 并建立模型。
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+我们将会看到 MLP 模型中被划分的参数(如权重)的形状。
+```shell
+Weight of the first linear layer: torch.Size([128, 512])
+Weight of the second linear layer: torch.Size([512, 128])
+```
+第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。
+同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`.
+
+我们可以用一些随机输入来运行这个模型。
+```python
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+# partition input
+torch.distributed.broadcast(x, src=0)
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
+x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
+print_rank_0(f'Input: {x.shape}')
+
+x = m(x)
+```
+然后我们可以看到 activation 结果的形状。
+```shell
+Input: torch.Size([8, 128])
+Output of the first linear layer: torch.Size([8, 512])
+Output of the second linear layer: torch.Size([8, 128])
+```
+2D并行中的 activation 张量都是同时在行和列分割的。例如,第一个线性层的输出是 `[8, 512]`, 而第二层的输出为 `[8, 128]`。
diff --git a/docs/source/zh-Hans/features/2p5D_tensor_parallel.md b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md
new file mode 100644
index 000000000000..59a4be02ce47
--- /dev/null
+++ b/docs/source/zh-Hans/features/2p5D_tensor_parallel.md
@@ -0,0 +1,145 @@
+# 2.5D 张量并行
+
+作者: Zhengda Bian, Yongbin Li
+
+**前置教程**
+- [定义配置文件](../basics/define_your_config.md)
+- [并行配置](../basics/configure_parallelization.md)
+- [1D 张量并行](./1D_tensor_parallel.md)
+- [2D 张量并行](./2D_tensor_parallel.md)
+
+**示例代码**
+- [ColossalAI-Examples - 2.5D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_2p5d.py)
+
+**相关论文**
+- [2.5-dimensional distributed model training](https://arxiv.org/pdf/2105.14500.pdf)
+
+## 引言
+
+与一维张量并行相比,二维并行降低了内存成本,但可能引入更多的通信。因此,[2.5D张量并行](https://arxiv.org/pdf/2105.14500.pdf) 在 2.5D SUMMA 的基础上被提出,它通过使用更多的设备来减少通信。
+
+我们还是以线性层 $Y = XA$ 为例。
+给定 $P=q \times q \times d$ 个处理器(必要条件), 如 $q=d=2$, 我们把输入 $X$ 划分为 $d\times q$ 行和 $q$ 列
+
+$$
+\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \\ X_{10} & X_{11} \\ X_{00} & X_{01}\end{matrix} \right],
+$$
+它可以被重塑为 $d$ 层
+
+$$
+\left[\begin{matrix} X_{10} & X_{11} \\ X_{00} & X_{01} \end{matrix} \right] \text{~and~}\left[\begin{matrix} X_{30} & X_{31} \\ X_{20} & X_{21} \end{matrix} \right].
+$$
+
+另外,权重 $A$ 被分割为
+
+$$
+\left[\begin{matrix} A_{10} & A_{11} \\ A_{00} & A_{01} \end{matrix} \right].
+$$
+
+对于 $X$ 相关的每一层, 我们使用SUMMA算法将 $X$ 与 $A$ 相乘。
+然后,我们得到输出
+
+$$
+\left[\begin{matrix} Y_{10}=X_{10}A_{00}+X_{11}A_{10} & Y_{11}=X_{10}A_{01}+X_{11}A_{11} \\ Y_{00}=X_{00}A_{00}+X_{01}A_{10} & Y_{01}=X_{00}A_{01}+X_{01}A_{11} \end{matrix} \right]
+\text{~and~}
+$$
+$$
+\left[\begin{matrix} Y_{30}=X_{30}A_{00}+X_{31}A_{10} & Y_{31}=X_{30}A_{01}+X_{31}A_{11} \\ Y_{20}=X_{20}A_{00}+X_{21}A_{10} & Y_{21}=X_{20}A_{01}+X_{21}A_{11} \end{matrix} \right].
+$$
+
+## 效率
+
+给定 $P=q \times q \times d$ 个处理器, 我们展现理论上的计算和内存成本,以及基于环形算法的2.5D张量并行的前向和后向的通信成本。
+
+| 计算 | 内存 (参数) | 内存 (activations) | 通信 (带宽) | 通信 (时延) |
+| :-: | :-: | :-: | :-: | :-: |
+| $O(1/dq^2)$ | $O(1/q^2)$ | $O(1/dq^2)$ | $\small O(3(q-1)(d+1)/dq)$ | $O(6(q-1))$ |
+
+## 使用
+
+为了使我们的模型能够实现2.5D张量并行,例如在8个 GPU 上,我们需要配置如下的并行设置。
+
+```python
+CONFIG = dict(parallel=dict(
+ data=1,
+ pipeline=1,
+ tensor=dict(size=8, mode='2.5d', depth=2),
+))
+
+```
+
+然后 Colossal-AI 会自动对所有来自 `colossalai.nn` 的层应用2.5D张量并行。
+
+让我们定义一个由两层多层感知器 (MLP) 组成的模型,如下所示。
+
+```python
+import colossalai
+import colossalai.nn as col_nn
+import torch
+from colossalai.utils import print_rank_0
+
+class MLP(torch.nn.Module):
+ def __init__(self, dim: int = 256):
+ super().__init__()
+ intermediate_dim = dim * 4
+ self.dense_1 = col_nn.Linear(dim, intermediate_dim)
+ print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
+ self.activation = torch.nn.GELU()
+ self.dense_2 = col_nn.Linear(intermediate_dim, dim)
+ print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')
+ self.dropout = col_nn.Dropout(0.1)
+
+ def forward(self, x):
+ x = self.dense_1(x)
+ print_rank_0(f'Output of the first linear layer: {x.shape}')
+ x = self.activation(x)
+ x = self.dense_2(x)
+ print_rank_0(f'Output of the second linear layer: {x.shape}')
+ x = self.dropout(x)
+ return x
+```
+在8个 GPU 上启动 Colossal-AI 并建立模型。
+```python
+parser = colossalai.get_default_parser()
+colossalai.launch(config=CONFIG,
+ rank=args.rank,
+ world_size=args.world_size,
+ local_rank=args.local_rank,
+ host=args.host,
+ port=args.port)
+
+m = MLP()
+```
+我们将会看到 MLP 模型中被划分的参数(如权重)的形状。
+```shell
+Weight of the first linear layer: torch.Size([128, 512])
+Weight of the second linear layer: torch.Size([512, 128])
+```
+
+第一个线性层的完整权重形状应该为 `[256, 1024]`. 经过2.5D并行划分后,它在每个 GPU 上变成了 `[128, 512]` 。
+同样地,第二层将权重 `[1024, 256]` 划分为 `[512, 128]`.
+
+我们可以用一些随机输入来运行这个模型。
+```python
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils import get_current_device
+
+x = torch.randn((16, 256), device=get_current_device())
+# partition input
+torch.distributed.broadcast(x, src=0)
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)]
+x = torch.chunk(x, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)]
+x = torch.chunk(x, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)]
+print_rank_0(f'Input: {x.shape}')
+
+x = m(x)
+```
+然后我们可以看到 activation 结果的形状。
+```shell
+Input: torch.Size([4, 128])
+Output of the first linear layer: torch.Size([4, 512])
+Output of the second linear layer: torch.Size([4, 128])
+```
+2.5D并行中的 activation 张量都是同时在$d \times q$行和$q$列分割的。例如,第一个线性层的输出是 `[4, 512]`, 而第二层的输出为 `[4, 128]`。
+注意,2.5D并行使用与2D并行相同的划分方法来处理权重,区别在于对输入的划分。
diff --git a/docs/source/zh-Hans/features/3D_tensor_parallel.md b/docs/source/zh-Hans/features/3D_tensor_parallel.md
new file mode 100644
index 000000000000..440121c94243
--- /dev/null
+++ b/docs/source/zh-Hans/features/3D_tensor_parallel.md
@@ -0,0 +1,154 @@
+# 3D 张量并行
+
+作者: Zhengda Bian, Yongbin Li
+
+**前置教程**
+- [定义配置文件](../basics/define_your_config.md)
+- [并行配置](../basics/configure_parallelization.md)
+- [1D 张量并行](./1D_tensor_parallel.md)
+- [2D 张量并行](./2D_tensor_parallel.md)
+
+**示例代码**
+- [ColossalAI-Examples - 3D Tensor Parallelism](https://github.com/hpcaitech/ColossalAI-Examples/tree/main/features/tensor_parallel/tensor_parallel_3d.py)
+
+**相关论文**
+- [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/pdf/2105.14450.pdf)
+
+## 引言
+
+[3D 张量并行](https://arxiv.org/pdf/2105.14450.pdf) 是一种将神经网络模型的计算并行化,以期望获得最佳通信成本优化的方法。
+
+我们还是以线性层 $Y = XA$ 为例。
+给定 $P=q \times q \times q$ 个处理器(必要条件), 如 $q=2$, 我们把输入 $X$ 和权重 $A$ 划分为
+
+$$
+\left[\begin{matrix}
+ X_{000} & X_{001} \\
+ X_{010} & X_{011} \\
+ X_{100} & X_{101} \\
+ X_{110} & X_{111} \end{matrix}
+\right]
+\text{~and~}
+\left[\begin{matrix}
+ A_{000} & A_{001} & A_{010} & A_{011} \\
+ A_{100} & A_{101} & A_{110} & A_{111} \end{matrix}
+\right]
+\text{~respectively,}$$
+其中每个 $X_{ijl}$ 和 $A_{lji}$ 都被存储在处理器 $(i,j,l)$ 上, 如下图所示。
+
+
-## Example folder description
-
-This folder provides several examples using colossalai. The images folder includes model like diffusion, dreambooth and vit. The language folder includes gpt, opt, palm and roberta. The tutorial folder is for concept illustration, such as auto-parallel, hybrid-parallel and so on.
-
-
-## Integrate Your Example With System Testing
-
-For example code contributor, to meet the expectation and test your code automatically using github workflow function, here are several steps:
-
-
-- (must) Have a test_ci.sh file in the folder like shown below in 'File Structure Chart'
-- The dataset should be located in the company's machine and can be announced using environment variable and thus no need for a separate terminal command.
-- The model parameters should be small to allow fast testing.
-- File Structure Chart
-
- └─examples
- └─images
- └─vit
- └─requirements.txt
- └─test_ci.sh
+- [Colossal-AI Examples](#colossal-ai-examples)
+ - [Table of Contents](#table-of-contents)
+ - [Overview](#overview)
+ - [Folder Structure](#folder-structure)
+ - [Integrate Your Example With Testing](#integrate-your-example-with-testing)
+
+## Overview
+
+This folder provides several examples accelerated by Colossal-AI. The `tutorial` folder is for everyone to quickly try out the different features in Colossal-AI. Other folders such as `images` and `language` include a wide range of deep learning tasks and applications.
+
+## Folder Structure
+
+```text
+└─ examples
+ └─ images
+ └─ vit
+ └─ test_ci.sh
+ └─ train.py
+ └─ README.md
+ └─ ...
+ └─ ...
+```
+## Invitation to open-source contribution
+Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models!
+
+You may contact us or participate in the following ways:
+1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
+2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
+3. Join the Colossal-AI community on
+[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
+and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
+4. Send your official proposal to email contact@hpcaitech.com
+
+Thanks so much to all of our amazing contributors!
+
+## Integrate Your Example With Testing
+
+Regular checks are important to ensure that all examples run without apparent bugs and stay compatible with the latest API.
+Colossal-AI runs workflows to check for examples on a on-pull-request and weekly basis.
+When a new example is added or changed, the workflow will run the example to test whether it can run.
+Moreover, Colossal-AI will run testing for examples every week.
+
+Therefore, it is essential for the example contributors to know how to integrate your example with the testing workflow. Simply, you can follow the steps below.
+
+1. Create a script called `test_ci.sh` in your example folder
+2. Configure your testing parameters such as number steps, batch size in `test_ci.sh`, e.t.c. Keep these parameters small such that each example only takes several minutes.
+3. Export your dataset path with the prefix `/data` and make sure you have a copy of the dataset in the `/data/scratch/examples-data` directory on the CI machine. Community contributors can contact us via slack to request for downloading the dataset on the CI machine.
+4. Implement the logic such as dependency setup and example execution
diff --git a/examples/images/diffusion/README.md b/examples/images/diffusion/README.md
index abb1d24c0262..a70792b9f4a4 100644
--- a/examples/images/diffusion/README.md
+++ b/examples/images/diffusion/README.md
@@ -1,6 +1,5 @@
# ColoDiffusion: Stable Diffusion with Colossal-AI
-
Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1](https://github.com/CompVis/stable-diffusion) and [Stable Diffusion v2](https://github.com/Stability-AI/stablediffusion).
@@ -26,13 +25,22 @@ Acceleration of AIGC (AI-Generated Content) models such as [Stable Diffusion v1]
More details can be found in our [blog of Stable Diffusion v1](https://www.hpc-ai.tech/blog/diffusion-pretraining-and-hardware-fine-tuning-can-be-almost-7x-cheaper) and [blog of Stable Diffusion v2](https://www.hpc-ai.tech/blog/colossal-ai-0-2-0).
+
+## Roadmap
+This project is in rapid development.
+
+- [X] Train a stable diffusion model v1/v2 from scatch
+- [X] Finetune a pretrained Stable diffusion v1 model
+- [X] Inference a pretrained model using PyTorch
+- [ ] Finetune a pretrained Stable diffusion v2 model
+- [ ] Inference a pretrained model using TensoRT
+
## Installation
### Option #1: install from source
#### Step 1: Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
+To begin with, make sure your operating system has the cuda version suitable for this exciting training session, which is cuda11.6/11.8. For your convience, we have set up the rest of packages here. You can create and activate a suitable [conda](https://conda.io/) environment named `ldm` :
```
conda env create -f environment.yaml
@@ -43,23 +51,60 @@ You can also update an existing [latent diffusion](https://github.com/CompVis/la
```
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
-pip install -e .
+pip install transformers diffusers invisible-watermark
+```
+
+#### Step 2: install lightning
+
+Install Lightning version later than 2022.01.04. We suggest you install lightning from source. Notice that the default download path of pip should be within the conda environment, or you may need to specify using 'which pip' and redirect the path into conda environment.
+
+##### From Source
+```
+git clone https://github.com/Lightning-AI/lightning.git
+pip install -r requirements.txt
+python setup.py install
+```
+
+##### From pip
+
+```
+pip install pytorch-lightning
+```
+
+#### Step 3:Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website
+
+You can install the latest version (0.2.7) from our official website or from source. Notice that the suitable version for this training is colossalai(0.2.5), which stands for torch(1.12.1).
+
+##### Download suggested verision for this training
+
+```
+
+pip install colossalai==0.2.5
+
```
-##### Step 2: install lightning
+##### Download the latest version from pip for latest torch version
+
+```
+pip install colossalai
+```
-Install Lightning version later than 2022.01.04. We suggest you install lightning from source.
+##### From source
-https://github.com/Lightning-AI/lightning.git
+```
+git clone https://github.com/hpcaitech/ColossalAI.git
+cd ColossalAI
+# install colossalai
+CUDA_EXT=1 pip install .
+```
-##### Step 3:Install [Colossal-AI](https://colossalai.org/download/) From Our Official Website
+#### Step 4:Accelerate with flash attention by xformers(Optional)
-For example, you can install v0.1.12 from our official website.
+Notice that xformers will accelerate the training process in cost of extra disk space. The suitable version of xformers for this training process is 0.12.0. You can download xformers directly via pip. For more release versions, feel free to check its official website: [XFormers](./https://pypi.org/project/xformers/)
```
-pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
+pip install xformers==0.0.12
```
### Option #2: Use Docker
@@ -75,7 +120,7 @@ docker build -t hpcaitech/diffusion:0.2.0 .
docker pull hpcaitech/diffusion:0.2.0
```
-Once you have the image ready, you can launch the image with the following command:
+Once you have the image ready, you can launch the image with the following command
```bash
########################
@@ -109,12 +154,15 @@ It is important for you to configure your volume mapping in order to get the bes
3. **Optional**, if you encounter any problem stating that shared memory is insufficient inside container, please add `-v /dev/shm:/dev/shm` to your `docker run` command.
-
## Download the model checkpoint from pretrained
-### stable-diffusion-v1-4
+### stable-diffusion-v2-base(Recommand)
-Our default model config use the weight from [CompVis/stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4?text=A+mecha+robot+in+a+favela+in+expressionist+style)
+```
+wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
+```
+
+### stable-diffusion-v1-4
```
git lfs install
@@ -123,8 +171,6 @@ git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
### stable-diffusion-v1-5 from runway
-If you want to useed the Last [stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) wiegh from runwayml
-
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
@@ -137,25 +183,31 @@ you should the change the `data.file_path` in the `config/train_colossalai.yaml`
## Training
-We provide the script `train_colossalai.sh` to run the training task with colossalai,
-and can also use `train_ddp.sh` to run the training task with ddp to compare.
+We provide the script `train_colossalai.sh` to run the training task with colossalai. Meanwhile, we have enlightened other training process such as DDP model in PyTorch. You can also use `train_ddp.sh` to run the training task with ddp to compare the corresponding performance.
+
+In `train_colossalai.sh` the main command is
-In `train_colossalai.sh` the main command is:
```
-python main.py --logdir /tmp/ -t -b configs/train_colossalai.yaml
+python main.py --logdir /tmp/ --train --base configs/train_colossalai.yaml --ckpt 512-base-ema.ckpt
```
-- you can change the `--logdir` to decide where to save the log information and the last checkpoint.
+- You can change the `--logdir` to decide where to save the log information and the last checkpoint.
+ - You will find your ckpt in `logdir/checkpoints` or `logdir/diff_tb/version_0/checkpoints`
+ - You will find your train config yaml in `logdir/configs`
+- You can add the `--ckpt` if you want to load the pretrained model, for example `512-base-ema.ckpt`
+- You can change the `--base` to specify the path of config yaml
### Training config
You can change the trainging config in the yaml file
-- devices: device number used for training, default 8
-- max_epochs: max training epochs, default 2
-- precision: the precision type used in training, default 16 (fp16), you must use fp16 if you want to apply colossalai
+- devices: device number used for training, default = 8
+- max_epochs: max training epochs, default = 2
+- precision: the precision type used in training, default = 16 (fp16), you must use fp16 if you want to apply colossalai
+- placement_policy: the training strategy supported by Colossal AI, defult = 'cuda', which refers to loading all the parameters into cuda memory. On the other hand, 'cpu' refers to 'cpu offload' strategy while 'auto' enables 'Gemini', both featured by Colossal AI.
- more information about the configuration of ColossalAIStrategy can be found [here](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#colossal-ai)
+
## Finetune Example
### Training on Teyvat Datasets
@@ -171,8 +223,8 @@ you can get yout training last.ckpt and train config.yaml in your `--logdir`, an
```
python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
--outdir ./output \
- --config path/to/logdir/checkpoints/last.ckpt \
- --ckpt /path/to/logdir/configs/project.yaml \
+ --ckpt path/to/logdir/checkpoints/last.ckpt \
+ --config /path/to/logdir/configs/project.yaml \
```
```commandline
@@ -211,6 +263,19 @@ optional arguments:
evaluate at this precision
```
+## Invitation to open-source contribution
+Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models!
+
+You may contact us or participate in the following ways:
+1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
+2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
+3. Join the Colossal-AI community on
+[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
+and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
+4. Send your official proposal to email contact@hpcaitech.com
+
+Thanks so much to all of our amazing contributors!
+
## Comments
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
diff --git a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
index d466c1c56259..ff0f4c5a0463 100644
--- a/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
+++ b/examples/images/diffusion/configs/Teyvat/train_colossalai_teyvat.yaml
@@ -6,6 +6,7 @@ model:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
+ ckpt: None # use ckpt path
log_every_t: 200
timesteps: 1000
first_stage_key: image
@@ -16,7 +17,7 @@ model:
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
- use_ema: False # we set this to false because this is an inference only config
+ use_ema: False
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
@@ -110,6 +111,7 @@ lightning:
enable_distributed_storage: True
placement_policy: cuda
force_outputs_fp32: true
+ min_chunk_size: 64
log_every_n_steps: 2
logger: True
diff --git a/examples/images/diffusion/configs/train_colossalai.yaml b/examples/images/diffusion/configs/train_colossalai.yaml
index 0354311f84b6..88432e978a0f 100644
--- a/examples/images/diffusion/configs/train_colossalai.yaml
+++ b/examples/images/diffusion/configs/train_colossalai.yaml
@@ -107,6 +107,7 @@ lightning:
enable_distributed_storage: True
placement_policy: cuda
force_outputs_fp32: true
+ min_chunk_size: 64
log_every_n_steps: 2
logger: True
diff --git a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
index 0273ca862bf8..0ba06f832178 100644
--- a/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
+++ b/examples/images/diffusion/configs/train_colossalai_cifar10.yaml
@@ -111,6 +111,7 @@ lightning:
enable_distributed_storage: True
placement_policy: cuda
force_outputs_fp32: true
+ min_chunk_size: 64
log_every_n_steps: 2
logger: True
diff --git a/examples/images/diffusion/configs/train_pokemon.yaml b/examples/images/diffusion/configs/train_pokemon.yaml
deleted file mode 100644
index aadb5f2a0870..000000000000
--- a/examples/images/diffusion/configs/train_pokemon.yaml
+++ /dev/null
@@ -1,120 +0,0 @@
-model:
- base_learning_rate: 1.0e-4
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- parameterization: "v"
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: txt
- image_size: 64
- channels: 4
- cond_stage_trainable: false
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False # we set this to false because this is an inference only config
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- use_checkpoint: True
- use_fp16: True
- image_size: 32 # unused
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_head_channels: 64 # need to fix for flash-attn
- use_spatial_transformer: True
- use_linear_in_transformer: True
- transformer_depth: 1
- context_dim: 1024
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- monitor: val/rec_loss
- ddconfig:
- #attn_type: "vanilla-xformers"
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
- params:
- freeze: True
- layer: "penultimate"
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 32
- wrap: False
- train:
- target: ldm.data.pokemon.PokemonDataset
- # params:
- # file_path: "/data/scratch/diffuser/laion_part0/"
- # world_size: 1
- # rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 1
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: strategies.ColossalAIStrategy
- params:
- use_chunk: True
- enable_distributed_storage: True
- placement_policy: cuda
- force_outputs_fp32: true
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- # profiler: pytorch
-
- logger_config:
- wandb:
- target: loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
diff --git a/examples/images/diffusion/docker/Dockerfile b/examples/images/diffusion/docker/Dockerfile
index 17cc8bc8bbc7..3b5301b89853 100644
--- a/examples/images/diffusion/docker/Dockerfile
+++ b/examples/images/diffusion/docker/Dockerfile
@@ -15,16 +15,9 @@ RUN git clone https://github.com/NVIDIA/apex && \
# && cd ./ColossalAI \
# && pip install -v --no-cache-dir .
-RUN pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
+RUN pip install colossalai
-# install our lightning, it will be merged to Lightning official repo.
-RUN git clone https://github.com/1SAA/lightning.git && \
- cd lightning && \
- git checkout strategy/colossalai && \
- export PACKAGE_NAME=pytorch && \
- pip install --no-cache-dir .
-
# install titans
RUN pip install --no-cache-dir titans
diff --git a/examples/images/diffusion/environment.yaml b/examples/images/diffusion/environment.yaml
index 69904c72ea73..d1ec69c1a585 100644
--- a/examples/images/diffusion/environment.yaml
+++ b/examples/images/diffusion/environment.yaml
@@ -18,7 +18,7 @@ dependencies:
- test-tube>=0.7.5
- streamlit==1.12.1
- einops==0.3.0
- - transformers==4.19.2
+ - transformers
- webdataset==0.2.5
- kornia==0.6
- open_clip_torch==2.0.2
@@ -27,5 +27,6 @@ dependencies:
- torchmetrics==0.7.0
- prefetch_generator
- datasets
- - colossalai
+ - colossalai==0.2.5
+ - lightning==1.9.0
- -e .
diff --git a/examples/images/diffusion/ldm/models/diffusion/ddpm.py b/examples/images/diffusion/ldm/models/diffusion/ddpm.py
index f7ac0a735f10..b7315b048c66 100644
--- a/examples/images/diffusion/ldm/models/diffusion/ddpm.py
+++ b/examples/images/diffusion/ldm/models/diffusion/ddpm.py
@@ -6,56 +6,41 @@
-- merci
"""
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
+
try:
import lightning.pytorch as pl
- from lightning.pytorch.utilities import rank_zero_only, rank_zero_info
+ from lightning.pytorch.utilities import rank_zero_info, rank_zero_only
except:
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only, rank_zero_info
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
+
+import itertools
from contextlib import contextmanager, nullcontext
from functools import partial
-import itertools
-from tqdm import tqdm
-from torchvision.utils import make_grid
-from omegaconf import ListConfig
-
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
-
-
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from einops import rearrange, repeat
+from ldm.models.autoencoder import *
+from ldm.models.autoencoder import AutoencoderKL, IdentityFirstStage
+from ldm.models.diffusion.ddim import *
from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.modules.diffusionmodules.model import *
+from ldm.modules.diffusionmodules.model import Decoder, Encoder, Model
from ldm.modules.diffusionmodules.openaimodel import *
-
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
-from ldm.modules.encoders.modules import *
-
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule, noise_like
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution, normal_kl
from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-
-
-from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder
-
-from ldm.util import instantiate_from_config
-
+from ldm.modules.encoders.modules import *
+from ldm.util import count_params, default, exists, instantiate_from_config, isimage, ismap, log_txt_as_img, mean_flat
+from omegaconf import ListConfig
+from torch.optim.lr_scheduler import LambdaLR
+from torchvision.utils import make_grid
+from tqdm import tqdm
-__conditioning_keys__ = {'concat': 'c_concat',
- 'crossattn': 'c_crossattn',
- 'adm': 'y'}
+__conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'}
def disabled_train(self, mode=True):
@@ -70,40 +55,41 @@ def uniform_on_device(r1, r2, shape, device):
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
- def __init__(self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- image_size=256,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.,
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.,
- use_fp16 = True,
- make_it_fit=False,
- ucg_training=None,
- reset_ema=False,
- reset_num_ema_updates=False,
- ):
+ def __init__(
+ self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ use_fp16=True,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
super().__init__()
assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
@@ -112,18 +98,18 @@ def __init__(self,
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
- self.image_size = image_size # try conv?
+ self.image_size = image_size
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.unet_config = unet_config
self.conditioning_key = conditioning_key
self.model = DiffusionWrapper(unet_config, conditioning_key)
- count_params(self.model, verbose=True)
+ # count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+ rank_zero_info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
@@ -136,21 +122,26 @@ def __init__(self,
if monitor is not None:
self.monitor = monitor
self.make_it_fit = make_it_fit
- self.ckpt_path = ckpt_path
+ self.ckpt = ckpt
self.ignore_keys = ignore_keys
self.load_only_unet = load_only_unet
self.reset_ema = reset_ema
self.reset_num_ema_updates = reset_num_ema_updates
- if reset_ema: assert exists(ckpt_path)
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
- if reset_ema:
- assert self.use_ema
- print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
- self.model_ema = LitEma(self.model)
+ if reset_ema:
+ assert exists(ckpt)
+ '''
+ Uncomment if you Use DDP Strategy
+ '''
+ # if ckpt is not None:
+ # self.init_from_ckpt(ckpt, ignore_keys=ignore_keys, only_model=load_only_unet)
+ # if reset_ema:
+ # assert self.use_ema
+ # rank_zero_info(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ # self.model_ema = LitEma(self.model)
+
if reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
@@ -160,9 +151,13 @@ def __init__(self,
self.linear_start = linear_start
self.linear_end = linear_end
self.cosine_s = cosine_s
-
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.register_schedule(given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s)
self.loss_type = loss_type
@@ -176,12 +171,20 @@ def __init__(self,
if self.ucg_training:
self.ucg_prng = np.random.RandomState()
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ def register_schedule(self,
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ betas = make_beta_schedule(beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
@@ -208,24 +211,23 @@ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
+ 1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
- self.register_buffer('posterior_mean_coef1', to_torch(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
- self.register_buffer('posterior_mean_coef2', to_torch(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef1',
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2',
+ to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
- lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
elif self.parameterization == "v":
- lvlb_weights = torch.ones_like(self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ lvlb_weights = torch.ones_like(self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) *
+ (1 - self.alphas_cumprod)))
else:
raise NotImplementedError("mu not supported")
lvlb_weights[0] = lvlb_weights[1]
@@ -238,14 +240,14 @@ def ema_scope(self, context=None):
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
- print(f"{context}: Switched to EMA weights")
+ rank_zero_info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
- print(f"{context}: Restored training weights")
+ rank_zero_info(f"{context}: Restored training weights")
@torch.no_grad()
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
@@ -256,18 +258,13 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
+ rank_zero_info("Deleting key {} from state_dict.".format(k))
del sd[k]
if self.make_it_fit:
- n_params = len([name for name, _ in
- itertools.chain(self.named_parameters(),
- self.named_buffers())])
- for name, param in tqdm(
- itertools.chain(self.named_parameters(),
- self.named_buffers()),
- desc="Fitting old weights to new weights",
- total=n_params
- ):
+ n_params = len([name for name, _ in itertools.chain(self.named_parameters(), self.named_buffers())])
+ for name, param in tqdm(itertools.chain(self.named_parameters(), self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params):
if not name in sd:
continue
old_shape = sd[name].shape
@@ -304,11 +301,11 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
- print(f"Missing Keys:\n {missing}")
+ rank_zero_info(f"Missing Keys:\n {missing}")
if len(unexpected) > 0:
- print(f"\nUnexpected Keys:\n {unexpected}")
+ rank_zero_info(f"\nUnexpected Keys:\n {unexpected}")
def q_mean_variance(self, x_start, t):
"""
@@ -323,30 +320,22 @@ def q_mean_variance(self, x_start, t):
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise)
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
- )
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v)
def predict_eps_from_z_and_v(self, x_t, t, v):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
- )
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t)
def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
+ posterior_mean = (extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
@@ -379,7 +368,8 @@ def p_sample_loop(self, shape, return_intermediates=False):
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ img = self.p_sample(img,
+ torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
@@ -400,10 +390,8 @@ def q_sample(self, x_start, t, noise=None):
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_v(self, x, noise, t):
- return (
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
- )
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
@@ -485,11 +473,9 @@ def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
- self.log_dict(loss_dict, prog_bar=True,
- logger=True, on_step=True, on_epoch=True)
+ self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
- self.log("global_step", self.global_step,
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
+ self.log("global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
@@ -580,7 +566,8 @@ def __init__(self,
scale_by_std=False,
use_fp16=True,
force_null_conditioning=False,
- *args, **kwargs):
+ *args,
+ **kwargs):
self.force_null_conditioning = force_null_conditioning
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
@@ -590,7 +577,7 @@ def __init__(self,
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
conditioning_key = None
-
+
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
@@ -599,7 +586,7 @@ def __init__(self,
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
self.num_downs = 0
-
+
if not scale_by_std:
self.scale_factor = scale_factor
else:
@@ -611,40 +598,44 @@ def __init__(self,
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
-
- self.restarted_from_ckpt = False
- if self.ckpt_path is not None:
- self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
- self.restarted_from_ckpt = True
- if self.reset_ema:
- assert self.use_ema
- print(
- f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
- self.model_ema = LitEma(self.model)
+ '''
+ Uncomment if you Use DDP Strategy
+ '''
+ # self.restarted_from_ckpt = False
+ # if self.ckpt is not None:
+ # self.init_from_ckpt(self.ckpt, self.ignore_keys)
+ # self.restarted_from_ckpt = True
+ # if self.reset_ema:
+ # assert self.use_ema
+ # rank_zero_info(
+ # f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ # self.model_ema = LitEma(self.model)
if self.reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ rank_zero_info(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
assert self.use_ema
self.model_ema.reset_num_updates()
def configure_sharded_model(self) -> None:
rank_zero_info("Configure sharded model for LatentDiffusion")
self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
+ count_params(self.model, verbose=True)
if self.use_ema:
self.model_ema = LitEma(self.model)
- if self.ckpt_path is not None:
- self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys, only_model=self.load_only_unet)
+ if self.ckpt is not None:
+ self.init_from_ckpt(self.ckpt, ignore_keys=self.ignore_keys, only_model=self.load_only_unet)
if self.reset_ema:
assert self.use_ema
- print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ rank_zero_info(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
- if self.reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
- assert self.use_ema
- self.model_ema.reset_num_updates()
- self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps,
- linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s)
+ self.register_schedule(given_betas=self.given_betas,
+ beta_schedule=self.beta_schedule,
+ timesteps=self.timesteps,
+ linear_start=self.linear_start,
+ linear_end=self.linear_end,
+ cosine_s=self.cosine_s)
self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
@@ -654,20 +645,16 @@ def configure_sharded_model(self) -> None:
self.instantiate_first_stage(self.first_stage_config)
self.instantiate_cond_stage(self.cond_stage_config)
- if self.ckpt_path is not None:
- self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
+ if self.ckpt is not None:
+ self.init_from_ckpt(self.ckpt, self.ignore_keys)
self.restarted_from_ckpt = True
if self.reset_ema:
assert self.use_ema
- print(
+ rank_zero_info(
f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
self.model_ema = LitEma(self.model)
- if self.reset_num_ema_updates:
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
- assert self.use_ema
- self.model_ema.reset_num_updates()
- def make_cond_schedule(self, ):
+ def make_cond_schedule(self,):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
@@ -679,19 +666,23 @@ def on_train_batch_start(self, batch, batch_idx):
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
- print("### USING STD-RESCALING ###")
+ rank_zero_info("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
- print(f"setting self.scale_factor to {self.scale_factor}")
- print("### USING STD-RESCALING ###")
+ rank_zero_info(f"setting self.scale_factor to {self.scale_factor}")
+ rank_zero_info("### USING STD-RESCALING ###")
def register_schedule(self,
- given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ given_betas=None,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
@@ -708,10 +699,10 @@ def instantiate_first_stage(self, config):
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
+ rank_zero_info("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
+ rank_zero_info(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
@@ -729,10 +720,10 @@ def instantiate_cond_stage(self, config):
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
+ denoise_row.append(
+ self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
@@ -783,21 +774,23 @@ def delta_border(self, h, w):
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
- self.split_input_params["clip_max_weight"], )
+ weighting = torch.clip(
+ weighting,
+ self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"],
+ )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
- L_weighting = torch.clip(L_weighting,
- self.split_input_params["clip_min_tie_weight"],
+ L_weighting = torch.clip(L_weighting, self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
@@ -815,7 +808,7 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
@@ -823,12 +816,13 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
- dilation=1, padding=0,
+ dilation=1,
+ padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
@@ -836,12 +830,13 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
- dilation=1, padding=0,
+ dilation=1,
+ padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
@@ -850,8 +845,15 @@ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once
return fold, unfold, normalization, weighting
@torch.no_grad()
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
- cond_key=None, return_original_cond=False, bs=None, return_x=False):
+ def get_input(self,
+ batch,
+ k,
+ return_first_stage_outputs=False,
+ force_c_encode=False,
+ cond_key=None,
+ return_original_cond=False,
+ bs=None,
+ return_x=False):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
@@ -900,7 +902,7 @@ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=F
out.extend([x])
if return_original_cond:
out.append(xc)
-
+
return out
@torch.no_grad()
@@ -929,7 +931,7 @@ def forward(self, x, c, *args, **kwargs):
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
+ if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
@@ -1007,8 +1009,16 @@ def p_losses(self, x_start, cond, t, noise=None):
return loss, loss_dict
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
- return_x0=False, score_corrector=None, corrector_kwargs=None):
+ def p_mean_variance(self,
+ x,
+ c,
+ t,
+ clip_denoised: bool,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ score_corrector=None,
+ corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
@@ -1039,15 +1049,29 @@ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=Fals
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ def p_sample(self,
+ x,
+ c,
+ t,
+ clip_denoised=False,
+ repeat_noise=False,
+ return_codebook_ids=False,
+ quantize_denoised=False,
+ return_x0=False,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None):
b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ outputs = self.p_mean_variance(x=x,
+ c=c,
+ t=t,
+ clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
@@ -1070,9 +1094,22 @@ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ def progressive_denoising(self,
+ cond,
+ shape,
+ verbose=True,
+ callback=None,
+ quantize_denoised=False,
+ img_callback=None,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ batch_size=None,
+ x_T=None,
+ start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
@@ -1089,16 +1126,17 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti
intermediates = []
if cond is not None:
if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ cond = {
+ key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(
+ map(lambda x: x[:batch_size], cond[key])) for key in cond
+ }
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
- total=timesteps) if verbose else reversed(
- range(0, timesteps))
+ total=timesteps) if verbose else reversed(range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
@@ -1109,11 +1147,16 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
- img, x0_partial = self.p_sample(img, cond, ts,
+ img, x0_partial = self.p_sample(img,
+ cond,
+ ts,
clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised, return_x0=True,
- temperature=temperature[i], noise_dropout=noise_dropout,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ quantize_denoised=quantize_denoised,
+ return_x0=True,
+ temperature=temperature[i],
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
@@ -1121,14 +1164,26 @@ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quanti
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
return img, intermediates
@torch.no_grad()
- def p_sample_loop(self, cond, shape, return_intermediates=False,
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, start_T=None,
+ def p_sample_loop(self,
+ cond,
+ shape,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ start_T=None,
log_every_t=None):
if not log_every_t:
@@ -1151,7 +1206,7 @@ def p_sample_loop(self, cond, shape, return_intermediates=False,
if mask is not None:
assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
@@ -1160,51 +1215,64 @@ def p_sample_loop(self, cond, shape, return_intermediates=False,
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
- img = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised)
+ img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
+ if callback:
+ callback(i)
+ if img_callback:
+ img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
- verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None, **kwargs):
+ def sample(self,
+ cond,
+ batch_size=16,
+ return_intermediates=False,
+ x_T=None,
+ verbose=True,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ shape=None,
+ **kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ cond = {
+ key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(
+ map(lambda x: x[:batch_size], cond[key])) for key in cond
+ }
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
- return_intermediates=return_intermediates, x_T=x_T,
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
- mask=mask, x0=x0)
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ timesteps=timesteps,
+ quantize_denoised=quantize_denoised,
+ mask=mask,
+ x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
- shape, cond, verbose=False, **kwargs)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
else:
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True, **kwargs)
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs)
return samples, intermediates
@@ -1226,7 +1294,7 @@ def get_unconditional_conditioning(self, batch_size, null_label=None):
return self.get_learned_conditioning(xc)
else:
raise NotImplementedError("todo")
- if isinstance(c, list): # in case the encoder gives us a list
+ if isinstance(c, list): # in case the encoder gives us a list
for i in range(len(c)):
c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
else:
@@ -1234,16 +1302,29 @@ def get_unconditional_conditioning(self, batch_size, null_label=None):
return c
@torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ def log_images(self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=50,
+ ddim_eta=0.,
+ return_keys=None,
+ quantize_denoised=True,
+ inpaint=True,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.,
+ unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ z, c, x, xrec, xc = self.get_input(batch,
+ self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
@@ -1283,7 +1364,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
@@ -1292,8 +1373,11 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
if sample:
# get denoise row
with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
+ samples, z_denoise_row = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
@@ -1305,8 +1389,11 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
+ samples, z_denoise_row = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
@@ -1318,11 +1405,15 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
if self.model.conditioning_key == "crossattn-adm":
uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc,
- )
+ samples_cfg, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
@@ -1334,8 +1425,13 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with ema_scope("Plotting Inpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ samples, _ = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ eta=ddim_eta,
+ ddim_steps=ddim_steps,
+ x0=z[:N],
+ mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
@@ -1343,8 +1439,13 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0
# outpaint
mask = 1. - mask
with ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ samples, _ = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ eta=ddim_eta,
+ ddim_steps=ddim_steps,
+ x0=z[:N],
+ mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
@@ -1367,10 +1468,10 @@ def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ rank_zero_info(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
- print('Diffusion model optimizing logvar')
+ rank_zero_info('Diffusion model optimizing logvar')
params.append(self.logvar)
from colossalai.nn.optimizer import HybridAdam
@@ -1381,13 +1482,8 @@ def configure_optimizers(self):
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
+ rank_zero_info("Setting up LambdaLR scheduler...")
+ scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
return [opt], scheduler
return opt
@@ -1402,6 +1498,7 @@ def to_rgb(self, x):
class DiffusionWrapper(pl.LightningModule):
+
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
@@ -1444,6 +1541,7 @@ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=N
class LatentUpscaleDiffusion(LatentDiffusion):
+
def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
super().__init__(*args, **kwargs)
# assumes that neither the cond_stage nor the low_scale_model contain trainable params
@@ -1464,8 +1562,12 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
if not log_mode:
z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
else:
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
+ z, c, x, xrec, xc = super().get_input(batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs)
x_low = batch[self.low_scale_key][:bs]
x_low = rearrange(x_low, 'b h w c -> b c h w')
if self.use_fp16:
@@ -1485,15 +1587,28 @@ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
return z, all_conds
@torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
- unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ def log_images(self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=200,
+ ddim_eta=1.,
+ return_keys=None,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.,
+ unconditional_guidance_label=None,
+ use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
use_ddim = ddim_steps is not None
log = dict()
- z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch,
+ self.first_stage_key,
+ bs=N,
log_mode=True)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
@@ -1528,7 +1643,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
@@ -1537,8 +1652,11 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
if sample:
# get denoise row
with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
+ samples, z_denoise_row = self.sample_log(cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
@@ -1555,7 +1673,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
if k == "c_crossattn":
assert isinstance(c[k], list) and len(c[k]) == 1
uc[k] = [uc_tmp]
- elif k == "c_adm": # todo: only run with text-based guidance?
+ elif k == "c_adm": # todo: only run with text-based guidance?
assert isinstance(c[k], torch.Tensor)
#uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
uc[k] = c[k]
@@ -1565,11 +1683,15 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
uc[k] = c[k]
with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc,
- )
+ samples_cfg, _ = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
@@ -1590,18 +1712,18 @@ class LatentFinetuneDiffusion(LatentDiffusion):
To disable finetuning mode, set finetune_keys to None
"""
- def __init__(self,
- concat_keys: tuple,
- finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
- "model_ema.diffusion_modelinput_blocks00weight"
- ),
- keep_finetune_dims=4,
- # if model was trained without concat mode before and we would like to keep these channels
- c_concat_log_start=None, # to log reconstruction of c_concat codes
- c_concat_log_end=None,
- *args, **kwargs
- ):
- ckpt_path = kwargs.pop("ckpt_path", None)
+ def __init__(
+ self,
+ concat_keys: tuple,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args,
+ **kwargs):
+ ckpt = kwargs.pop("ckpt", None)
ignore_keys = kwargs.pop("ignore_keys", list())
super().__init__(*args, **kwargs)
self.finetune_keys = finetune_keys
@@ -1609,9 +1731,10 @@ def __init__(self,
self.keep_dims = keep_finetune_dims
self.c_concat_log_start = c_concat_log_start
self.c_concat_log_end = c_concat_log_end
- if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
- if exists(ckpt_path):
- self.init_from_ckpt(ckpt_path, ignore_keys)
+ if exists(self.finetune_keys):
+ assert exists(ckpt), 'can only finetune from a given checkpoint'
+ if exists(ckpt):
+ self.init_from_ckpt(ckpt, ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
@@ -1621,7 +1744,7 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
+ rank_zero_info("Deleting key {} from state_dict.".format(k))
del sd[k]
# make it explicit, finetune by including extra input channels
@@ -1629,25 +1752,38 @@ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
new_entry = None
for name, param in self.named_parameters():
if name in self.finetune_keys:
- print(
- f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
- new_entry = torch.zeros_like(param) # zero init
+ rank_zero_info(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only"
+ )
+ new_entry = torch.zeros_like(param) # zero init
assert exists(new_entry), 'did not find matching parameter to modify'
new_entry[:, :self.keep_dims, ...] = sd[k]
sd[k] = new_entry
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ rank_zero_info(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
- print(f"Missing Keys: {missing}")
+ rank_zero_info(f"Missing Keys: {missing}")
if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
+ rank_zero_info(f"Unexpected Keys: {unexpected}")
@torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ def log_images(self,
+ batch,
+ N=8,
+ n_row=4,
+ sample=True,
+ ddim_steps=200,
+ ddim_eta=1.,
+ return_keys=None,
+ quantize_denoised=True,
+ inpaint=True,
+ plot_denoise_rows=False,
+ plot_progressive_rows=True,
+ plot_diffusion_rows=True,
+ unconditional_guidance_scale=1.,
+ unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
ema_scope = self.ema_scope if use_ema_scope else nullcontext
@@ -1690,7 +1826,7 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
@@ -1699,9 +1835,14 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
if sample:
# get denoise row
with ema_scope("Sampling"):
- samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta)
+ samples, z_denoise_row = self.sample_log(cond={
+ "c_concat": [c_cat],
+ "c_crossattn": [c]
+ },
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
@@ -1714,12 +1855,18 @@ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=
uc_cat = c_cat
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
with ema_scope("Sampling with classifier-free guidance"):
- samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
- batch_size=N, ddim=use_ddim,
- ddim_steps=ddim_steps, eta=ddim_eta,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=uc_full,
- )
+ samples_cfg, _ = self.sample_log(
+ cond={
+ "c_concat": [c_cat],
+ "c_crossattn": [c]
+ },
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
@@ -1733,11 +1880,7 @@ class LatentInpaintDiffusion(LatentFinetuneDiffusion):
To disable finetuning mode, set finetune_keys to None
"""
- def __init__(self,
- concat_keys=("mask", "masked_image"),
- masked_image_key="masked_image",
- *args, **kwargs
- ):
+ def __init__(self, concat_keys=("mask", "masked_image"), masked_image_key="masked_image", *args, **kwargs):
super().__init__(concat_keys, *args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
@@ -1746,8 +1889,12 @@ def __init__(self,
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
+ z, c, x, xrec, xc = super().get_input(batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs)
assert exists(self.concat_keys)
c_cat = list()
@@ -1793,8 +1940,12 @@ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwarg
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
+ z, c, x, xrec, xc = super().get_input(batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
@@ -1812,7 +1963,8 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs
align_corners=False,
)
- depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc,
+ dim=[1, 2, 3],
keepdim=True)
cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
c_cat.append(cc)
@@ -1836,13 +1988,19 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
"""
condition on low-res image (and optionally on some spatial noise augmentation)
"""
- def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
- low_scale_config=None, low_scale_key=None, *args, **kwargs):
+
+ def __init__(self,
+ concat_keys=("lr",),
+ reshuffle_patch_size=None,
+ low_scale_config=None,
+ low_scale_key=None,
+ *args,
+ **kwargs):
super().__init__(concat_keys=concat_keys, *args, **kwargs)
self.reshuffle_patch_size = reshuffle_patch_size
self.low_scale_model = None
if low_scale_config is not None:
- print("Initializing a low-scale model")
+ rank_zero_info("Initializing a low-scale model")
assert exists(low_scale_key)
self.instantiate_low_stage(low_scale_config)
self.low_scale_key = low_scale_key
@@ -1858,8 +2016,12 @@ def instantiate_low_stage(self, config):
def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
# note: restricted to non-trainable encoders currently
assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
- force_c_encode=True, return_original_cond=True, bs=bs)
+ z, c, x, xrec, xc = super().get_input(batch,
+ self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=bs)
assert exists(self.concat_keys)
assert len(self.concat_keys) == 1
@@ -1871,8 +2033,10 @@ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs
cc = rearrange(cc, 'b h w c -> b c h w')
if exists(self.reshuffle_patch_size):
assert isinstance(self.reshuffle_patch_size, int)
- cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
- p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+ cc = rearrange(cc,
+ 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=self.reshuffle_patch_size,
+ p2=self.reshuffle_patch_size)
if bs is not None:
cc = cc[:bs]
cc = cc.to(self.device)
diff --git a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py
index 57b9a4b80f4b..fb088db58919 100644
--- a/examples/images/diffusion/ldm/modules/diffusionmodules/model.py
+++ b/examples/images/diffusion/ldm/modules/diffusionmodules/model.py
@@ -1,10 +1,11 @@
# pytorch_diffusion + derived encoder decoder
import math
+from typing import Any, Optional
+
+import numpy as np
import torch
import torch.nn as nn
-import numpy as np
from einops import rearrange
-from typing import Optional, Any
try:
from lightning.pytorch.utilities import rank_zero_info
@@ -38,14 +39,14 @@ def get_timestep_embedding(timesteps, embedding_dim):
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
def nonlinearity(x):
# swish
- return x*torch.sigmoid(x)
+ return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
@@ -53,15 +54,12 @@ def Normalize(in_channels, num_groups=32):
class Upsample(nn.Module):
+
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
@@ -71,20 +69,17 @@ def forward(self, x):
class Downsample(nn.Module):
+
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=2,
- padding=0)
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
- pad = (0,1,0,1)
+ pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
@@ -93,8 +88,8 @@ def forward(self, x):
class ResnetBlock(nn.Module):
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512):
+
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@@ -102,34 +97,17 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels,
- out_channels)
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(out_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0)
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, temb):
h = x
@@ -138,7 +116,7 @@ def forward(self, x, temb):
h = self.conv1(h)
if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
@@ -151,35 +129,20 @@ def forward(self, x, temb):
else:
x = self.nin_shortcut(x)
- return x+h
+ return x + h
class AttnBlock(nn.Module):
+
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
@@ -189,23 +152,24 @@ def forward(self, x):
v = self.v(h_)
# compute attention
- b,c,h,w = q.shape
- q = q.reshape(b,c,h*w)
- q = q.permute(0,2,1) # b,hw,c
- k = k.reshape(b,c,h*w) # b,c,hw
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
- v = v.reshape(b,c,h*w)
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b,c,h,w)
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
- return x+h_
+ return x + h_
+
class MemoryEfficientAttnBlock(nn.Module):
"""
@@ -213,32 +177,17 @@ class MemoryEfficientAttnBlock(nn.Module):
see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
Note: this is a single-head self-attention operation
"""
+
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
@@ -253,27 +202,20 @@ def forward(self, x):
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
q, k, v = map(
- lambda t: t.unsqueeze(3)
- .reshape(B, t.shape[1], 1, C)
- .permute(0, 2, 1, 3)
- .reshape(B * 1, t.shape[1], C)
- .contiguous(),
+ lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(0, 2, 1, 3).reshape(B * 1, t.shape[1], C).
+ contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
- out = (
- out.unsqueeze(0)
- .reshape(B, 1, out.shape[1], C)
- .permute(0, 2, 1, 3)
- .reshape(B, out.shape[1], C)
- )
+ out = (out.unsqueeze(0).reshape(B, 1, out.shape[1], C).permute(0, 2, 1, 3).reshape(B, out.shape[1], C))
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
- return x+out
+ return x + out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x = rearrange(x, 'b c h w -> b (h w) c')
@@ -283,10 +225,10 @@ def forward(self, x, context=None, mask=None):
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
- assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear",
+ "none"], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
attn_type = "vanilla-xformers"
- rank_zero_info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
@@ -303,13 +245,26 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
class Model(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla"):
super().__init__()
- if use_linear_attn: attn_type = "linear"
+ if use_linear_attn:
+ attn_type = "linear"
self.ch = ch
- self.temb_ch = self.ch*4
+ self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
@@ -320,39 +275,34 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
- torch.nn.Linear(self.ch,
- self.temb_ch),
- torch.nn.Linear(self.temb_ch,
- self.temb_ch),
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
])
# downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
+ in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@@ -374,15 +324,16 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- skip_in = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
- skip_in = ch*in_ch_mult[i_level]
- block.append(ResnetBlock(in_channels=block_in+skip_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -392,15 +343,11 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
+ self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
@@ -425,7 +372,7 @@ def forward(self, x, t=None, context=None):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@@ -436,9 +383,8 @@ def forward(self, x, t=None, context=None):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](
- torch.cat([h, hs.pop()], dim=1), temb)
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
@@ -455,12 +401,26 @@ def get_last_layer(self):
class Encoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
**ignore_kwargs):
super().__init__()
- if use_linear_attn: attn_type = "linear"
+ if use_linear_attn:
+ attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@@ -469,33 +429,30 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
self.in_channels = in_channels
# downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
+ in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
@@ -515,7 +472,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
- 2*z_channels if double_z else z_channels,
+ 2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
@@ -532,7 +489,7 @@ def forward(self, x):
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
- if i_level != self.num_resolutions-1:
+ if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
@@ -549,12 +506,27 @@ def forward(self, x):
class Decoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
- attn_type="vanilla", **ignorekwargs):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs):
super().__init__()
- if use_linear_attn: attn_type = "linear"
+ if use_linear_attn:
+ attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@@ -565,19 +537,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,)+tuple(ch_mult)
- block_in = ch*ch_mult[self.num_resolutions-1]
- curr_res = resolution // 2**(self.num_resolutions-1)
- self.z_shape = (1,z_channels,curr_res,curr_res)
- rank_zero_info("Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)))
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ rank_zero_info("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
- self.conv_in = torch.nn.Conv2d(z_channels,
- block_in,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
@@ -596,12 +563,13 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
@@ -611,15 +579,11 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
+ self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
@@ -638,7 +602,7 @@ def forward(self, z):
# upsampling
for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
+ for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
@@ -658,31 +622,24 @@ def forward(self, z):
class SimpleDecoder(nn.Module):
+
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=2 * in_channels,
- out_channels=4 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=4 * in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- nn.Conv2d(2*in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True)])
+ self.model = nn.ModuleList([
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)
+ ])
# end
self.norm_out = Normalize(in_channels)
- self.conv_out = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
- if i in [1,2,3]:
+ if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
@@ -694,25 +651,26 @@ def forward(self, x):
class UpsampleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
- ch_mult=(2,2), dropout=0.0):
+
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2, 2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ curr_res = resolution // 2**(self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
- res_block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
+ res_block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
@@ -721,11 +679,7 @@ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
# end
self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
+ self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
# upsampling
@@ -742,35 +696,35 @@ def forward(self, x):
class LatentRescaler(nn.Module):
+
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
- self.conv_in = nn.Conv2d(in_channels,
- mid_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
+ self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
+ self.res_block1 = nn.ModuleList([
+ ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
+ for _ in range(depth)
+ ])
self.attn = AttnBlock(mid_channels)
- self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
-
- self.conv_out = nn.Conv2d(mid_channels,
- out_channels,
- kernel_size=1,
- )
+ self.res_block2 = nn.ModuleList([
+ ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0)
+ for _ in range(depth)
+ ])
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
- x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = torch.nn.functional.interpolate(x,
+ size=(int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
@@ -779,17 +733,37 @@ def forward(self, x):
class MergedRescaleEncoder(nn.Module):
- def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+
+ def __init__(self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
- self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
- z_channels=intermediate_chn, double_z=False, resolution=resolution,
- attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ self.encoder = Encoder(in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
out_ch=None)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
- mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+ self.rescaler = LatentRescaler(factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
@@ -798,15 +772,36 @@ def forward(self, x):
class MergedRescaleDecoder(nn.Module):
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
- dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+
+ def __init__(self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1):
super().__init__()
- tmp_chn = z_channels*ch_mult[-1]
- self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
- resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
- ch_mult=ch_mult, resolution=resolution, ch=ch)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
- out_channels=tmp_chn, depth=rescale_module_depth)
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
@@ -815,16 +810,26 @@ def forward(self, x):
class Upsampler(nn.Module):
+
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
- num_blocks = int(np.log2(out_size//in_size))+1
- factor_up = 1.+ (out_size % in_size)
- rank_zero_info(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1. + (out_size % in_size)
+ rank_zero_info(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
out_channels=in_channels)
- self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
- attn_resolutions=[], in_channels=None, ch=in_channels,
+ self.decoder = Decoder(out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
@@ -834,23 +839,21 @@ def forward(self, x):
class Resize(nn.Module):
+
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
- rank_zero_info(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ rank_zero_info(
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=4,
- stride=2,
- padding=1)
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x, scale_factor=1.0):
- if scale_factor==1.0:
+ if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
diff --git a/examples/images/diffusion/main.py b/examples/images/diffusion/main.py
index 87d495123714..4dd88a5eca44 100644
--- a/examples/images/diffusion/main.py
+++ b/examples/images/diffusion/main.py
@@ -106,7 +106,20 @@ def str2bool(v):
nargs="?",
help="disable test",
)
- parser.add_argument("-p", "--project", help="name of new or path to existing project")
+ parser.add_argument(
+ "-p",
+ "--project",
+ help="name of new or path to existing project",
+ )
+ parser.add_argument(
+ "-c",
+ "--ckpt",
+ type=str,
+ const=True,
+ default="",
+ nargs="?",
+ help="load pretrained checkpoint from stable AI",
+ )
parser.add_argument(
"-d",
"--debug",
@@ -145,22 +158,7 @@ def str2bool(v):
default=True,
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
- parser.add_argument(
- "--use_fp16",
- type=str2bool,
- nargs="?",
- const=True,
- default=True,
- help="whether to use fp16",
- )
- parser.add_argument(
- "--flash",
- type=str2bool,
- const=True,
- default=False,
- nargs="?",
- help="whether to use flash attention",
- )
+
return parser
@@ -341,6 +339,12 @@ def on_fit_start(self, trainer, pl_module):
except FileNotFoundError:
pass
+ # def on_fit_end(self, trainer, pl_module):
+ # if trainer.global_rank == 0:
+ # ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
+ # rank_zero_info(f"Saving final checkpoint in {ckpt_path}.")
+ # trainer.save_checkpoint(ckpt_path)
+
class ImageLogger(Callback):
@@ -535,7 +539,10 @@ def on_train_epoch_end(self, trainer, pl_module):
raise ValueError("-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint")
+
+ ckpt = None
if opt.resume:
+ rank_zero_info("Resuming from {}".format(opt.resume))
if not os.path.exists(opt.resume):
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
@@ -543,13 +550,13 @@ def on_train_epoch_end(self, trainer, pl_module):
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = "/".join(paths[:-2])
+ rank_zero_info("logdir: {}".format(logdir))
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
- opt.resume_from_checkpoint = ckpt
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split("/")
@@ -558,6 +565,7 @@ def on_train_epoch_end(self, trainer, pl_module):
if opt.name:
name = "_" + opt.name
elif opt.base:
+ rank_zero_info("Using base config {}".format(opt.base))
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = "_" + cfg_name
@@ -566,6 +574,9 @@ def on_train_epoch_end(self, trainer, pl_module):
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
+ if opt.ckpt:
+ ckpt = opt.ckpt
+
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
@@ -582,14 +593,11 @@ def on_train_epoch_end(self, trainer, pl_module):
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
- print(trainer_config)
if not trainer_config["accelerator"] == "gpu":
del trainer_config["accelerator"]
cpu = True
- print("Running on CPU")
else:
cpu = False
- print("Running on GPU")
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@@ -597,10 +605,12 @@ def on_train_epoch_end(self, trainer, pl_module):
use_fp16 = trainer_config.get("precision", 32) == 16
if use_fp16:
config.model["params"].update({"use_fp16": True})
- print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
else:
config.model["params"].update({"use_fp16": False})
- print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
+
+ if ckpt is not None:
+ config.model["params"].update({"ckpt": ckpt})
+ rank_zero_info("Using ckpt_path = {}".format(config.model["params"]["ckpt"]))
model = instantiate_from_config(config.model)
# trainer and callbacks
@@ -639,7 +649,6 @@ def on_train_epoch_end(self, trainer, pl_module):
# config the strategy, defualt is ddp
if "strategy" in trainer_config:
strategy_cfg = trainer_config["strategy"]
- print("Using strategy: {}".format(strategy_cfg["target"]))
strategy_cfg["target"] = LIGHTNING_PACK_NAME + strategy_cfg["target"]
else:
strategy_cfg = {
@@ -648,7 +657,6 @@ def on_train_epoch_end(self, trainer, pl_module):
"find_unused_parameters": False
}
}
- print("Using strategy: DDPStrategy")
trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
@@ -664,7 +672,6 @@ def on_train_epoch_end(self, trainer, pl_module):
}
}
if hasattr(model, "monitor"):
- print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 3
@@ -673,7 +680,6 @@ def on_train_epoch_end(self, trainer, pl_module):
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
- print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
@@ -710,8 +716,6 @@ def on_train_epoch_end(self, trainer, pl_module):
"target": "main.CUDACallback"
},
}
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
- default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
@@ -737,15 +741,11 @@ def on_train_epoch_end(self, trainer, pl_module):
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
- if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
- callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
- elif 'ignore_keys_callback' in callbacks_cfg:
- del callbacks_cfg['ignore_keys_callback']
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
- trainer.logdir = logdir ###
+ trainer.logdir = logdir
# data
data = instantiate_from_config(config.data)
@@ -754,9 +754,9 @@ def on_train_epoch_end(self, trainer, pl_module):
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
- print("#### Data #####")
+
for k in data.datasets:
- print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
+ rank_zero_info(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
@@ -768,17 +768,17 @@ def on_train_epoch_end(self, trainer, pl_module):
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
- print(f"accumulate_grad_batches = {accumulate_grad_batches}")
+ rank_zero_info(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
- print(
+ rank_zero_info(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)"
.format(model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
else:
model.learning_rate = base_lr
- print("++++ NOT USING LR SCALING ++++")
- print(f"Setting learning rate to {model.learning_rate:.2e}")
+ rank_zero_info("++++ NOT USING LR SCALING ++++")
+ rank_zero_info(f"Setting learning rate to {model.learning_rate:.2e}")
# allow checkpointing via USR1
def melk(*args, **kwargs):
diff --git a/examples/images/diffusion/requirements.txt b/examples/images/diffusion/requirements.txt
index 60c4b903e01f..59d027fcf60f 100644
--- a/examples/images/diffusion/requirements.txt
+++ b/examples/images/diffusion/requirements.txt
@@ -1,18 +1,19 @@
albumentations==1.3.0
-opencv-python==4.6.0
+opencv-python==4.6.0.66
pudb==2019.2
prefetch_generator
imageio==2.9.0
imageio-ffmpeg==0.4.2
-torchmetrics==0.6
+torchmetrics==0.7
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
einops==0.3.0
-transformers==4.19.2
+transformers
webdataset==0.2.5
open-clip-torch==2.7.0
gradio==3.11
+lightning==1.9.0
datasets
colossalai
-e .
diff --git a/examples/images/diffusion/scripts/txt2img.sh b/examples/images/diffusion/scripts/txt2img.sh
index 549bb03a6885..bc6480b6bdaa 100755
--- a/examples/images/diffusion/scripts/txt2img.sh
+++ b/examples/images/diffusion/scripts/txt2img.sh
@@ -1,6 +1,5 @@
-python scripts/txt2img.py --prompt "Teyvat, Name:Layla, Element: Cryo, Weapon:Sword, Region:Sumeru, Model type:Medium Female, Description:a woman in a blue outfit holding a sword" --plms \
+python scripts/txt2img.py --prompt "Teyvat, Medium Female, a woman in a blue outfit holding a sword" --plms \
--outdir ./output \
- --config /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/checkpoints/last.ckpt \
- --ckpt /home/lcmql/data2/Genshin/2022-11-18T16-38-46_train_colossalai_teyvattest/configs/2022-11-18T16-38-46-project.yaml \
+ --ckpt checkpoints/last.ckpt \
+ --config configs/2023-02-02T18-06-14-project.yaml \
--n_samples 4
-
diff --git a/examples/images/diffusion/test_ci.sh b/examples/images/diffusion/test_ci.sh
new file mode 100755
index 000000000000..44cf47046684
--- /dev/null
+++ b/examples/images/diffusion/test_ci.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+set -euxo pipefail
+
+conda env create -f environment.yaml
+
+conda activate ldm
+
+conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch
+pip install transformers diffusers invisible-watermark
+
+CUDA_EXT=1 pip install colossalai
+
+wget https://huggingface.co/stabilityai/stable-diffusion-2-base/resolve/main/512-base-ema.ckpt
+
+python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt 512-base-ema.ckpt
diff --git a/examples/images/diffusion/train_colossalai.sh b/examples/images/diffusion/train_colossalai.sh
index 4223a69412fb..c56ed7876e5a 100755
--- a/examples/images/diffusion/train_colossalai.sh
+++ b/examples/images/diffusion/train_colossalai.sh
@@ -1,5 +1,5 @@
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
-DIFFUSERS_OFFLINE=1
+HF_DATASETS_OFFLINE=1
+TRANSFORMERS_OFFLINE=1
+DIFFUSERS_OFFLINE=1
-python main.py --logdir /tmp -t -b /configs/train_colossalai.yaml
+python main.py --logdir /tmp --train --base configs/Teyvat/train_colossalai_teyvat.yaml --ckpt diffuser_root_dir/512-base-ema.ckpt
diff --git a/examples/images/dreambooth/README.md b/examples/images/dreambooth/README.md
index a306a3abfc2c..b067a437c764 100644
--- a/examples/images/dreambooth/README.md
+++ b/examples/images/dreambooth/README.md
@@ -5,18 +5,18 @@ The `train_dreambooth_colossalai.py` script shows how to implement the training
By accommodating model data in CPU and GPU and moving the data to the computing device when necessary, [Gemini](https://www.colossalai.org/docs/advanced_tutorials/meet_gemini), the Heterogeneous Memory Manager of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) can breakthrough the GPU memory wall by using GPU and CPU memory (composed of CPU DRAM or nvme SSD memory) together at the same time. Moreover, the model scale can be further improved by combining heterogeneous training with the other parallel approaches, such as data parallel, tensor parallel and pipeline parallel.
-## Installing the dependencies
+## Installation
-Before running the scripts, make sure to install the library's training dependencies:
+To begin with, make sure your operating system has the cuda version suitable for this exciting training session, which is cuda11.6-11.8. Notice that you may want to make sure the module versions suitable for the whole environment. Before running the scripts, make sure to install the library's training dependencies:
```bash
-pip install -r requirements_colossalai.txt
+pip install -r requirements.txt
```
### Install [colossalai](https://github.com/hpcaitech/ColossalAI.git)
```bash
-pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
+pip install colossalai
```
**From source**
@@ -37,9 +37,7 @@ The `text` include the tag `Teyvat`, `Name`,`Element`, `Weapon`, `Region`, `Mode
## Training
-The arguement `placement` can be `cpu`, `auto`, `cuda`, with `cpu` the GPU RAM required can be minimized to 4GB but will deceleration, with `cuda` you can also reduce GPU memory by half but accelerated training, with `auto` a more balanced solution for speed and memory can be obtained。
-
-**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
+We provide the script `colossalai.sh` to run the training task with colossalai. Meanwhile, we also provided traditional training process of dreambooth, `dreambooth.sh`, for possible comparation. For instance, the script of training process for [stable-diffusion-v1-4] model can be modified into:
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
@@ -59,12 +57,17 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
--max_train_steps=400 \
--placement="cuda"
```
-
+- `MODEL_NAME` refers to the model you are training.
+- `INSTANCE_DIR` refers to personalized path to instance images, you might need to insert information here.
+- `OUTPUT_DIR` refers to local path to save the trained model, you might need to find a path with enough space.
+- `resolution` refers to the corresponding resolution number of your target model. Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.
+- `placement` refers to the training strategy supported by Colossal AI, defult = 'cuda', which refers to loading all the parameters into cuda memory. On the other hand, 'cpu' refers to 'cpu offload' strategy while 'auto' enables 'Gemini', both featured by Colossal AI.
### Training with prior-preservation loss
Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data.
-According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time.
+
+According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. The general script can be then modified as the following.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
@@ -91,7 +94,7 @@ torchrun --nproc_per_node 2 train_dreambooth_colossalai.py \
## Inference
-Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
+Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. `--instance_prompt="a photo of sks dog" ` in the above example) in your prompt.
```python
from diffusers import StableDiffusionPipeline
@@ -105,3 +108,16 @@ image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("dog-bucket.png")
```
+
+## Invitation to open-source contribution
+Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models!
+
+You may contact us or participate in the following ways:
+1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
+2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
+3. Join the Colossal-AI community on
+[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
+and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
+4. Send your official proposal to email contact@hpcaitech.com
+
+Thanks so much to all of our amazing contributors!
diff --git a/examples/images/dreambooth/requirement_colossalai.txt b/examples/images/dreambooth/requirement_colossalai.txt
deleted file mode 100644
index c4a0e91703bb..000000000000
--- a/examples/images/dreambooth/requirement_colossalai.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-diffusers
-torch
-torchvision
-ftfy
-tensorboard
-modelcards
-transformers
-colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
diff --git a/examples/images/dreambooth/requirements.txt b/examples/images/dreambooth/requirements.txt
index 6c4f40fb5dd0..1ec828c630ef 100644
--- a/examples/images/dreambooth/requirements.txt
+++ b/examples/images/dreambooth/requirements.txt
@@ -5,4 +5,3 @@ transformers>=4.21.0
ftfy
tensorboard
modelcards
-colossalai
diff --git a/examples/images/dreambooth/test_ci.sh b/examples/images/dreambooth/test_ci.sh
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py
index b7e24bfe4a15..5c4c86bc7073 100644
--- a/examples/images/dreambooth/train_dreambooth_colossalai.py
+++ b/examples/images/dreambooth/train_dreambooth_colossalai.py
@@ -10,7 +10,7 @@
import torch.utils.checkpoint
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
@@ -133,9 +133,13 @@ def parse_args(input_args=None):
default="cpu",
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
- parser.add_argument("--center_crop",
- action="store_true",
- help="Whether to center crop images before resizing to resolution")
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."),
+ )
parser.add_argument("--train_batch_size",
type=int,
default=4,
@@ -149,12 +153,6 @@ def parse_args(input_args=None):
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
- parser.add_argument(
- "--gradient_accumulation_steps",
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.",
- )
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
@@ -355,10 +353,13 @@ def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
def main(args):
- colossalai.launch_from_torch(config={})
+ if args.seed is None:
+ colossalai.launch_from_torch(config={})
+ else:
+ colossalai.launch_from_torch(config={}, seed=args.seed)
- if args.seed is not None:
- gpc.set_seed(args.seed)
+ local_rank = gpc.get_local_rank(ParallelMode.DATA)
+ world_size = gpc.get_world_size(ParallelMode.DATA)
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
@@ -387,7 +388,7 @@ def main(args):
for example in tqdm(
sample_dataloader,
desc="Generating class images",
- disable=not gpc.get_local_rank(ParallelMode.DATA) == 0,
+ disable=not local_rank == 0,
):
images = pipeline(example["prompt"]).images
@@ -399,13 +400,14 @@ def main(args):
del pipeline
# Handle the repository creation
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ if local_rank == 0:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
- repo = Repository(args.output_dir, clone_from=repo_name)
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
@@ -465,7 +467,7 @@ def main(args):
unet.enable_gradient_checkpointing()
if args.scale_lr:
- args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * gpc.get_world_size(ParallelMode.DATA)
+ args.learning_rate = args.learning_rate * args.train_batch_size * world_size
unet = gemini_zero_dpp(unet, args.placement)
@@ -523,7 +525,7 @@ def collate_fn(examples):
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
@@ -531,8 +533,8 @@ def collate_fn(examples):
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
- num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
)
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
@@ -547,14 +549,14 @@ def collate_fn(examples):
text_encoder.to(get_current_device(), dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# Train!
- total_batch_size = args.train_batch_size * gpc.get_world_size(ParallelMode.DATA) * args.gradient_accumulation_steps
+ total_batch_size = args.train_batch_size * world_size
logger.info("***** Running training *****", ranks=[0])
logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
@@ -562,11 +564,10 @@ def collate_fn(examples):
logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
- logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}", ranks=[0])
logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
# Only show the progress bar once on each machine.
- progress_bar = tqdm(range(args.max_train_steps), disable=not gpc.get_local_rank(ParallelMode.DATA) == 0)
+ progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
progress_bar.set_description("Steps")
global_step = 0
@@ -643,7 +644,7 @@ def collate_fn(examples):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
torch_unet = get_static_torch_model(unet)
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=torch_unet,
@@ -658,7 +659,7 @@ def collate_fn(examples):
torch.cuda.synchronize()
unet = get_static_torch_model(unet)
- if gpc.get_local_rank(ParallelMode.DATA) == 0:
+ if local_rank == 0:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
new file mode 100644
index 000000000000..3d789ae2ce0f
--- /dev/null
+++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py
@@ -0,0 +1,691 @@
+import argparse
+import hashlib
+import math
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
+from diffusers.loaders import AttnProcsLayers
+from diffusers.models.cross_attention import LoRACrossAttnProcessor
+from diffusers.optimization import get_scheduler
+from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from PIL import Image
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import colossalai
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
+from colossalai.nn.parallel.utils import get_static_torch_model
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+disable_existing_loggers()
+logger = get_dist_logger()
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default="a photo of sks dog",
+ required=False,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=("Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="text-inversion-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=("The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"),
+ )
+ parser.add_argument(
+ "--placement",
+ type=str,
+ default="cpu",
+ help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=("Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."),
+ )
+ parser.add_argument("--train_batch_size",
+ type=int,
+ default=4,
+ help="Batch size (per device) for the training dataloader.")
+ parser.add_argument("--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images.")
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'),
+ )
+ parser.add_argument("--lr_warmup_steps",
+ type=int,
+ default=500,
+ help="Number of steps for the warmup in the lr scheduler.")
+ parser.add_argument("--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes.")
+
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ if args.class_data_dir is not None:
+ logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ logger.warning("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and the tokenizes prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ size=512,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose([
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ])
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+ example["instance_prompt_ids"] = self.tokenizer(
+ self.instance_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt_ids"] = self.tokenizer(
+ self.class_prompt,
+ padding="do_not_pad",
+ truncation=True,
+ max_length=self.tokenizer.model_max_length,
+ ).input_ids
+
+ return example
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+ if token is None:
+ token = HfFolder.get_token()
+ if organization is None:
+ username = whoami(token)["name"]
+ return f"{username}/{model_id}"
+ else:
+ return f"{organization}/{model_id}"
+
+
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
+ from colossalai.nn.parallel import GeminiDDP
+
+ model = GeminiDDP(model,
+ device=get_current_device(),
+ placement_policy=placememt_policy,
+ pin_memory=True,
+ search_range_mb=64)
+ return model
+
+
+def main(args):
+ if args.seed is None:
+ colossalai.launch_from_torch(config={})
+ else:
+ colossalai.launch_from_torch(config={}, seed=args.seed)
+
+ local_rank = gpc.get_local_rank(ParallelMode.DATA)
+ world_size = gpc.get_world_size(ParallelMode.DATA)
+
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if get_current_device() == "cuda" else torch.float32
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ pipeline.to(get_current_device())
+
+ for example in tqdm(
+ sample_dataloader,
+ desc="Generating class images",
+ disable=not local_rank == 0,
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+
+ # Handle the repository creation
+ if local_rank == 0:
+ if args.push_to_hub:
+ if args.hub_model_id is None:
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+ else:
+ repo_name = args.hub_model_id
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
+
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+ if "step_*" not in gitignore:
+ gitignore.write("step_*\n")
+ if "epoch_*" not in gitignore:
+ gitignore.write("epoch_*\n")
+ elif args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ logger.info(f"Loading tokenizer from {args.tokenizer_name}", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.tokenizer_name,
+ revision=args.revision,
+ use_fast=False,
+ )
+ elif args.pretrained_model_name_or_path:
+ logger.info("Loading tokenizer from pretrained model", ranks=[0])
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path)
+
+ # Load models and create wrapper for stable diffusion
+
+ logger.info(f"Loading text_encoder from {args.pretrained_model_name_or_path}", ranks=[0])
+
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading AutoencoderKL from {args.pretrained_model_name_or_path}", ranks=[0])
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ )
+
+ logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
+ with ColoInitContext(device=get_current_device()):
+ unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ low_cpu_mem_usage=False)
+ unet.requires_grad_(False)
+
+ # Set correct lora layers
+ lora_attn_procs = {}
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+
+ lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim)
+
+ unet.set_attn_processor(lora_attn_procs)
+ lora_layers = AttnProcsLayers(unet.attn_processors)
+
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ if args.scale_lr:
+ args.learning_rate = args.learning_rate * args.train_batch_size * world_size
+
+ unet = gemini_zero_dpp(unet, args.placement)
+
+ # config optimizer for colossalai zero
+ optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm)
+
+ # load noise_scheduler
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # prepare dataset
+ logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ def collate_fn(examples):
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if args.with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = tokenizer.pad(
+ {
+ "input_ids": input_ids
+ },
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ return_tensors="pt",
+ ).input_ids
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+ return batch
+
+ train_dataloader = torch.utils.data.DataLoader(train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=collate_fn,
+ num_workers=1)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move text_encode and vae to gpu.
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ vae.to(get_current_device(), dtype=weight_dtype)
+ text_encoder.to(get_current_device(), dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader))
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * world_size
+
+ logger.info("***** Running training *****", ranks=[0])
+ logger.info(f" Num examples = {len(train_dataset)}", ranks=[0])
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}", ranks=[0])
+ logger.info(f" Num Epochs = {args.num_train_epochs}", ranks=[0])
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}", ranks=[0])
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}", ranks=[0])
+ logger.info(f" Total optimization steps = {args.max_train_steps}", ranks=[0])
+
+ # Only show the progress bar once on each machine.
+ progress_bar = tqdm(range(args.max_train_steps), disable=not local_rank == 0)
+ progress_bar.set_description("Steps")
+ global_step = 0
+
+ torch.cuda.synchronize()
+ for epoch in range(args.num_train_epochs):
+ unet.train()
+ for step, batch in enumerate(train_dataloader):
+ torch.cuda.reset_peak_memory_stats()
+ # Move batch to gpu
+ for key, value in batch.items():
+ batch[key] = value.to(get_current_device(), non_blocking=True)
+
+ # Convert images to latent space
+ optimizer.zero_grad()
+
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * 0.18215
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+ # Predict the noise residual
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
+
+ # Compute prior loss
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+ optimizer.backward(loss)
+
+ optimizer.step()
+ lr_scheduler.step()
+ logger.info(f"max GPU_mem cost is {torch.cuda.max_memory_allocated()/2**20} MB", ranks=[0])
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ progress_bar.update(1)
+ global_step += 1
+ logs = {
+ "loss": loss.detach().item(),
+ "lr": optimizer.param_groups[0]["lr"],
+ } # lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step % args.save_steps == 0:
+ torch.cuda.synchronize()
+ torch_unet = get_static_torch_model(unet)
+ if local_rank == 0:
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ torch_unet = torch_unet.to(torch.float32)
+ torch_unet.save_attn_procs(save_path)
+ logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
+ if global_step >= args.max_train_steps:
+ break
+
+ torch.cuda.synchronize()
+ torch_unet = get_static_torch_model(unet)
+
+ if local_rank == 0:
+ torch_unet = torch_unet.to(torch.float32)
+ torch_unet.save_attn_procs(save_path)
+ logger.info(f"Saving model checkpoint to {args.output_dir}", ranks=[0])
+
+ if args.push_to_hub:
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/images/vit/configs/vit_1d_tp2_ci.py b/examples/images/vit/configs/vit_1d_tp2_ci.py
new file mode 100644
index 000000000000..e491e4ada45e
--- /dev/null
+++ b/examples/images/vit/configs/vit_1d_tp2_ci.py
@@ -0,0 +1,32 @@
+from colossalai.amp import AMP_TYPE
+
+# hyperparameters
+# BATCH_SIZE is as per GPU
+# global batch size = BATCH_SIZE x data parallel size
+BATCH_SIZE = 8
+LEARNING_RATE = 3e-3
+WEIGHT_DECAY = 0.3
+NUM_EPOCHS = 3
+WARMUP_EPOCHS = 1
+
+# model config
+IMG_SIZE = 224
+PATCH_SIZE = 16
+HIDDEN_SIZE = 32
+DEPTH = 2
+NUM_HEADS = 4
+MLP_RATIO = 4
+NUM_CLASSES = 10
+CHECKPOINT = False
+SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
+
+USE_DDP = True
+TP_WORLD_SIZE = 2
+TP_TYPE = 'row'
+parallel = dict(tensor=dict(mode="1d", size=TP_WORLD_SIZE),)
+
+fp16 = dict(mode=AMP_TYPE.NAIVE)
+clip_grad_norm = 1.0
+gradient_accumulation = 2
+
+LOG_PATH = "./log_ci"
diff --git a/examples/images/vit/requirements.txt b/examples/images/vit/requirements.txt
index 137a69e80498..1f69794ebe70 100644
--- a/examples/images/vit/requirements.txt
+++ b/examples/images/vit/requirements.txt
@@ -1,2 +1,8 @@
colossalai >= 0.1.12
torch >= 1.8.1
+numpy>=1.24.1
+timm>=0.6.12
+titans>=0.0.7
+tqdm>=4.61.2
+transformers>=4.25.1
+nvidia-dali-cuda110>=1.8.0 --extra-index-url https://developer.download.nvidia.com/compute/redist
diff --git a/examples/images/vit/test_ci.sh b/examples/images/vit/test_ci.sh
new file mode 100644
index 000000000000..41d25ee23521
--- /dev/null
+++ b/examples/images/vit/test_ci.sh
@@ -0,0 +1,9 @@
+export OMP_NUM_THREADS=4
+
+pip install -r requirements.txt
+
+# train
+colossalai run \
+--nproc_per_node 4 train.py \
+--config configs/vit_1d_tp2_ci.py \
+--dummy_data
diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py
index de39801c7972..0b4489244368 100644
--- a/examples/images/vit/train.py
+++ b/examples/images/vit/train.py
@@ -7,6 +7,7 @@
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm
+from vit import DummyDataLoader
import colossalai
from colossalai.core import global_context as gpc
@@ -56,8 +57,8 @@ def init_spec_func(model, tp_type):
def train_imagenet():
parser = colossalai.get_default_parser()
- parser.add_argument('--from_torch', default=True, action='store_true')
- parser.add_argument('--resume_from', default=False)
+ parser.add_argument('--resume_from', default=False, action='store_true')
+ parser.add_argument('--dummy_data', default=False, action='store_true')
args = parser.parse_args()
colossalai.launch_from_torch(config=args.config)
@@ -74,10 +75,22 @@ def train_imagenet():
logger.log_to_file(log_path)
logger.info('Build data loader', ranks=[0])
- root = os.environ['DATA']
- train_dataloader, test_dataloader = build_dali_imagenet(root,
- train_batch_size=gpc.config.BATCH_SIZE,
- test_batch_size=gpc.config.BATCH_SIZE)
+ if not args.dummy_data:
+ root = os.environ['DATA']
+ train_dataloader, test_dataloader = build_dali_imagenet(root,
+ train_batch_size=gpc.config.BATCH_SIZE,
+ test_batch_size=gpc.config.BATCH_SIZE)
+ else:
+ train_dataloader = DummyDataLoader(length=10,
+ batch_size=gpc.config.BATCH_SIZE,
+ category=gpc.config.NUM_CLASSES,
+ image_size=gpc.config.IMG_SIZE,
+ return_dict=False)
+ test_dataloader = DummyDataLoader(length=5,
+ batch_size=gpc.config.BATCH_SIZE,
+ category=gpc.config.NUM_CLASSES,
+ image_size=gpc.config.IMG_SIZE,
+ return_dict=False)
logger.info('Build model', ranks=[0])
diff --git a/examples/images/vit/vit.py b/examples/images/vit/vit.py
index 14c870b39268..f22e8ea90cec 100644
--- a/examples/images/vit/vit.py
+++ b/examples/images/vit/vit.py
@@ -32,21 +32,24 @@ def __len__(self):
class DummyDataLoader(DummyDataGenerator):
- batch_size = 4
- channel = 3
- category = 8
- image_size = 224
+
+ def __init__(self, length=10, batch_size=4, channel=3, category=8, image_size=224, return_dict=True):
+ super().__init__(length)
+ self.batch_size = batch_size
+ self.channel = channel
+ self.category = category
+ self.image_size = image_size
+ self.return_dict = return_dict
def generate(self):
image_dict = {}
- image_dict['pixel_values'] = torch.rand(DummyDataLoader.batch_size,
- DummyDataLoader.channel,
- DummyDataLoader.image_size,
- DummyDataLoader.image_size,
- device=get_current_device()) * 2 - 1
- image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
+ image_dict['pixel_values'] = torch.rand(
+ self.batch_size, self.channel, self.image_size, self.image_size, device=get_current_device()) * 2 - 1
+ image_dict['label'] = torch.randint(self.category, (self.batch_size,),
dtype=torch.int64,
device=get_current_device())
+ if not self.return_dict:
+ return image_dict['pixel_values'], image_dict['label']
return image_dict
diff --git a/examples/language/bert/run_gemini.sh b/examples/language/bert/run_gemini.sh
new file mode 100644
index 000000000000..d791334e8c97
--- /dev/null
+++ b/examples/language/bert/run_gemini.sh
@@ -0,0 +1,22 @@
+set -x
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
+
+# The following options only valid when DISTPLAN="colossalai"
+export GPUNUM=${GPUNUM:-1}
+export PLACEMENT=${PLACEMENT:-"cpu"}
+export BATCH_SIZE=${BATCH_SIZE:-16}
+
+# bert | albert
+export MODEL_TYPE=${MODEL_TYPE:-"bert"}
+export TRAIN_STEP=${TRAIN_STEP:-10}
+
+mkdir -p gemini_logs
+
+env CUDA_LAUNCH_BLOCKING=1 torchrun --standalone --nproc_per_node=${GPUNUM} ./train_bert_demo.py \
+--model_type=${MODEL_TYPE} \
+--batch_size=${BATCH_SIZE} \
+--placement=${PLACEMENT} \
+--distplan=${DISTPLAN} \
+--train_step=${TRAIN_STEP} \
+2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_${PLACEMENT}.log
diff --git a/examples/language/bert/test_ci.sh b/examples/language/bert/test_ci.sh
new file mode 100644
index 000000000000..42c63fec50c0
--- /dev/null
+++ b/examples/language/bert/test_ci.sh
@@ -0,0 +1,2 @@
+set -x
+env GPUNUM=1 bash run_gemini.sh
diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py
new file mode 100644
index 000000000000..b690ff787d01
--- /dev/null
+++ b/examples/language/bert/train_bert_demo.py
@@ -0,0 +1,332 @@
+import os
+from functools import partial
+from time import time
+
+import psutil
+import torch
+from packaging import version
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers import AlbertConfig, AlbertForSequenceClassification, BertConfig, BertForSequenceClassification
+
+import colossalai
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
+from colossalai.utils import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+CAI_VERSION = colossalai.__version__
+
+
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+ return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+ from contextlib import nullcontext
+
+ from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
+ if enable_flag:
+ return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+ on_trace_ready=tensorboard_trace_handler(save_dir),
+ record_shapes=True,
+ profile_memory=True)
+ else:
+
+ class DummyProfiler:
+
+ def __init__(self):
+ self.step_number = 0
+
+ def step(self):
+ self.step_number += 1
+
+ return nullcontext(DummyProfiler())
+
+
+def get_time_stamp():
+ import time
+ cur_time = time.strftime("%d-%H:%M", time.localtime())
+ return cur_time
+
+
+def get_bert_data(batch_size: int, sequence_length: int, vacob_size: int, n_class: int, device: torch.device):
+ input = torch.randint(
+ low=0,
+ high=vacob_size,
+ size=(batch_size, sequence_length),
+ device=device,
+ dtype=torch.long,
+ )
+ label = torch.randint(low=0, high=n_class, size=(batch_size,), device=device, dtype=torch.long)
+ return input, label
+
+
+def parse_args():
+ parser = colossalai.get_default_parser()
+ parser.add_argument(
+ "--distplan",
+ type=str,
+ default='CAI_Gemini',
+ help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
+ )
+ parser.add_argument(
+ "--placement",
+ type=str,
+ default='cpu',
+ help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=8,
+ help="batch size per DP group of training.",
+ )
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ default="bert",
+ help="bert or albert",
+ )
+ parser.add_argument(
+ "--train_step",
+ type=int,
+ default=10,
+ help="training iterations for test",
+ )
+
+ args = parser.parse_args()
+ return args
+
+
+SEQ_LEN = 512
+VOCAB_SIZE = 1000
+NUM_LABELS = 10
+
+
+# Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+ spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+ param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+ split_param_single_dim_tp1d(-1, param, pg)
+
+
+def get_cpu_mem():
+ return psutil.Process().memory_info().rss / 1024**2
+
+
+def get_gpu_mem():
+ return torch.cuda.memory_allocated() / 1024**2
+
+
+def get_mem_info(prefix=''):
+ return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
+
+
+def get_model_size(model: nn.Module):
+ total_numel = 0
+ for module in model.modules():
+ for p in module.parameters(recurse=False):
+ total_numel += p.numel()
+ return total_numel
+
+
+def model_builder(args):
+ if args.model_type == "bert":
+ cfg = BertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
+ return BertForSequenceClassification(cfg)
+ elif args.model_type == "albert":
+ cfg = AlbertConfig(vocab_size=VOCAB_SIZE, num_labels=NUM_LABELS)
+ return AlbertForSequenceClassification(cfg)
+ else:
+ raise RuntimeError
+
+
+def model_size_formatter(numel: int) -> str:
+ GB_SIZE = 10**9
+ MB_SIZE = 10**6
+ KB_SIZE = 10**3
+ if numel >= GB_SIZE:
+ return f'{numel / GB_SIZE:.1f}B'
+ elif numel >= MB_SIZE:
+ return f'{numel / MB_SIZE:.1f}M'
+ elif numel >= KB_SIZE:
+ return f'{numel / KB_SIZE:.1f}K'
+ else:
+ return str(numel)
+
+
+def set_cpu_maximum_parallelism():
+ conf_str = torch.__config__.parallel_info()
+ inter_str = conf_str.split("hardware_concurrency() : ")[1]
+ max_concurrency = inter_str.split('\n')[0]
+ os.environ["OMP_NUM_THREADS"] = max_concurrency
+ print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
+
+
+def main():
+ # version check
+ # this example is supposed to work for versions greater than 0.2.0
+ assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
+
+ set_cpu_maximum_parallelism()
+ args = parse_args()
+
+ # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+ if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
+ raise TypeError(f"{args.distplan} is error")
+
+ # batch size per DP degree
+ BATCH_SIZE = args.batch_size
+
+ NUM_STEPS = args.train_step
+
+ WARMUP_STEPS = 1
+ assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
+ assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
+ PROF_FLAG = False # The flag of profiling, False by default
+
+ disable_existing_loggers()
+ colossalai.launch_from_torch(config={})
+
+ logger = get_dist_logger()
+ logger.info(f" {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])
+
+ torch.manual_seed(123)
+ if args.distplan.startswith("CAI"):
+ # all param must use the same process group.
+ world_size = torch.distributed.get_world_size()
+
+ # build a base-bert model
+ with ColoInitContext(device=get_current_device(), dtype=torch.half):
+ model = model_builder(args)
+ # model = BertForSequenceClassification(BertConfig(vocal_size = VOCAB_SIZE))
+
+ # asign running configurations
+ gemini_config = None
+ if args.distplan.startswith("CAI_ZeRO"):
+ optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+ elif args.distplan == "CAI_Gemini":
+ gemini_config = dict(strict_ddp_mode=True,
+ device=get_current_device(),
+ placement_policy=args.placement,
+ pin_memory=True,
+ hidden_dim=model.config.hidden_size,
+ search_range_mb=128)
+ optim_config = dict(gpu_margin_mem_ratio=0.)
+ else:
+ raise RuntimeError
+
+ # build a highly optimized gpu/cpu optimizer
+ optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+ if args.distplan == "CAI_ZeRO1":
+ zero_stage = 1
+ elif args.distplan == "CAI_ZeRO2":
+ zero_stage = 2
+ elif args.distplan == "CAI_Gemini":
+ zero_stage = 3
+ else:
+ raise RuntimeError
+
+ # wrap your model and optimizer
+ model = zero_model_wrapper(model, zero_stage, gemini_config)
+ optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
+
+ logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+ elif args.distplan.startswith("Pytorch"):
+ model = model_builder(args).cuda()
+ model = DDP(model)
+ if args.distplan.endswith("DDP"):
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ elif args.distplan.endswith("ZeRO"):
+ from torch.distributed.optim import ZeroRedundancyOptimizer
+ optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+ else:
+ raise RuntimeError
+
+ # model is shared after TP
+ numel = get_model_size(model)
+ logger.info(f"the size of testing model size is {model_size_formatter(numel)}.")
+ logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
+
+ # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
+ # = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
+ # = batch_per_DP_group * numel * seq_len * 8
+ get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+
+ torch.cuda.synchronize()
+ model.train()
+ tflops_list = []
+
+ def train_step():
+ # we just use randomly generated data here
+ input_ids, labels = get_bert_data(BATCH_SIZE,
+ SEQ_LEN,
+ VOCAB_SIZE,
+ NUM_LABELS,
+ device=torch.cuda.current_device())
+ optimizer.zero_grad()
+
+ start = time()
+ outputs = model(input_ids, labels=labels)
+ loss, logits = outputs[:2]
+ torch.cuda.synchronize()
+ fwd_end = time()
+ fwd_time = fwd_end - start
+ logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
+
+ if args.distplan.startswith("CAI"):
+ optimizer.backward(loss)
+ elif args.distplan.startswith("Pytorch"):
+ loss.backward()
+ else:
+ raise RuntimeError
+
+ torch.cuda.synchronize()
+ bwd_end = time()
+ bwd_time = bwd_end - fwd_end
+ logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
+
+ optimizer.step()
+ torch.cuda.synchronize()
+ optim_time = time() - bwd_end
+ step_time = time() - start
+ logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+
+ step_tflops = get_tflops_func(step_time)
+ logger.info(
+ f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+ ranks=[0],
+ )
+ if n >= WARMUP_STEPS:
+ tflops_list.append(step_tflops)
+
+ demo_profiler = get_profile_context(PROF_FLAG,
+ WARMUP_STEPS,
+ NUM_STEPS - WARMUP_STEPS,
+ save_dir=f"profile/{get_time_stamp()}-demo")
+
+ with demo_profiler as prof:
+ for n in range(NUM_STEPS):
+ train_step()
+ prof.step()
+
+ tflops_list.sort()
+ median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+ logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+ torch.cuda.synchronize()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/gpt/README.md b/examples/language/gpt/README.md
index 8fdf6be3b6d9..10d6c2ddd5d7 100644
--- a/examples/language/gpt/README.md
+++ b/examples/language/gpt/README.md
@@ -19,11 +19,8 @@ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
-### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website
+### [Install Colossal-AI](https://github.com/hpcaitech/ColossalAI#installation)
-```bash
-pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
-```
### Install requirements
@@ -31,31 +28,42 @@ pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
pip install -r requirements.txt
```
-This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.12+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
+This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-AI>=0.1.12.
## Dataset
-For simplicity, the input data is randonly generated here.
+For simplicity, the input data is randomly generated here.
## Training
-We provide two solutions. One utilizes the hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism.
-The other one uses Pipeline Parallelism Only.
-In the future, we are going merge them together and they can be used orthogonally to each other.
+We provide two stable solutions.
+One utilizes the Gemini to implement hybrid parallel strategies of Gemini, DDP/ZeRO, and Tensor Parallelism for a huggingface GPT model.
+The other one use [Titans](https://github.com/hpcaitech/Titans), a distributed executed model zoo maintained by ColossalAI,to implement the hybrid parallel strategies of TP + ZeRO + PP.
+
+We recommend using Gemini to qucikly run your model in a distributed manner.
+It doesn't require significant changes to the model structures, therefore you can apply it on a new model easily.
+And use Titans as an advanced weapon to pursue a more extreme performance.
+Titans has included the some typical models, such as Vit and GPT.
+However, it requires some efforts to start if facing a new model structure.
### GeminiDPP/ZeRO + Tensor Parallelism
```bash
bash run_gemini.sh
```
-The `train_gpt_demo.py` provides three distributed plans, you can choose the plan you want in `run_gemini.sh`. The Colossal-AI leverages Tensor Parallel and Gemini + ZeRO DDP.
+The `train_gpt_demo.py` provides three distributed plans (except ones already provided by PyTorch), you can choose the plan you want in `run_gemini.sh`. The CAI_Gemini leverages Tensor Parallel and Gemini + ZeRO DDP. For their differences, you may check out the answer to issue [here](https://github.com/hpcaitech/ColossalAI/issues/2590#issuecomment-1418766581).
+
+- ZeRO1 (CAI_ZeRO1)
+- ZeRO2 (CAI_ZeRO2)
+- Gemini + ZeRO DDP (CAI_Gemini)
+- Pytorch DDP (Pytorch_DDP)
+- Pytorch ZeRO (Pytorch_ZeRO)
-- Colossal-AI
-- ZeRO1 (Colossal-AI)
-- ZeRO2 (Colossal-AI)
-- Pytorch DDP
-- Pytorch ZeRO
+### Titans (Tensor Parallelism) + ZeRO + Pipeline Parallelism
+Titans provides a customized GPT model, which uses distributed operators as building blocks.
+In [./titans/README.md], we provide a hybrid parallelism of ZeRO, TP and PP.
+You can switch parallel strategies using a config file.
## Performance
diff --git a/examples/language/gpt/experiments/auto_offload/README.md b/examples/language/gpt/experiments/auto_offload/README.md
new file mode 100644
index 000000000000..a0d252119056
--- /dev/null
+++ b/examples/language/gpt/experiments/auto_offload/README.md
@@ -0,0 +1,37 @@
+# Auto-Offload Demo with GPT2
+
+## Requirements
+
+Before you can launch training, you need to install the following requirements.
+
+### Install PyTorch
+
+```bash
+#conda
+conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
+#pip
+pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website
+
+```bash
+pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
+```
+
+### Install transformers
+
+```bash
+pip install transformers
+```
+
+## Dataset
+
+For simplicity, the input data is randonly generated here.
+
+## Training
+
+```bash
+#Run the auto offload on GPT with default setting and a dummy dataset.
+bash run.sh
+```
diff --git a/examples/language/gpt/experiments/auto_offload/model_zoo.py b/examples/language/gpt/experiments/auto_offload/model_zoo.py
new file mode 100644
index 000000000000..35e44608f810
--- /dev/null
+++ b/examples/language/gpt/experiments/auto_offload/model_zoo.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from transformers import GPT2Config, GPT2LMHeadModel
+
+class GPTLMModel(nn.Module):
+
+ def __init__(self,
+ hidden_size=768,
+ num_layers=12,
+ num_attention_heads=12,
+ max_seq_len=1024,
+ vocab_size=50257):
+ super().__init__()
+ self.model = GPT2LMHeadModel(
+ GPT2Config(n_embd=hidden_size,
+ n_layer=num_layers,
+ n_head=num_attention_heads,
+ n_positions=max_seq_len,
+ n_ctx=max_seq_len,
+ vocab_size=vocab_size))
+
+ def forward(self, input_ids, attention_mask):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
+
+
+class GPTLMLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+def get_gpt2_components(model_type: str, batch_size: int):
+ vocab_size = 1024
+ seq_len = 8
+
+ def gpt2_model_builder():
+ if model_type == "gpt2_medium":
+ return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16)
+ elif model_type == "gpt2_xl":
+ return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32)
+ elif model_type == "gpt2_10b":
+ return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16)
+ elif model_type == "gpt2_14b":
+ return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16)
+ elif model_type == "gpt2_20b":
+ return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16)
+ elif model_type == "gpt2_24b":
+ return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16)
+ else:
+ raise TypeError(f"model_builder {model_type}")
+
+ def gpt2_data_gen(device="cuda"):
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+ attention_mask = torch.ones_like(input_ids, device=device)
+ kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+ return kwargs
+
+ return gpt2_model_builder, gpt2_data_gen
\ No newline at end of file
diff --git a/examples/language/gpt/experiments/auto_offload/requirements.txt b/examples/language/gpt/experiments/auto_offload/requirements.txt
new file mode 100644
index 000000000000..3ebde8d460aa
--- /dev/null
+++ b/examples/language/gpt/experiments/auto_offload/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
\ No newline at end of file
diff --git a/examples/language/gpt/experiments/auto_offload/run.sh b/examples/language/gpt/experiments/auto_offload/run.sh
new file mode 100644
index 000000000000..6a272ec442ab
--- /dev/null
+++ b/examples/language/gpt/experiments/auto_offload/run.sh
@@ -0,0 +1,8 @@
+export BATCH_SIZE=${BATCH_SIZE:-64}
+export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
+export MEMORY_BUDGET=${MEMORY_BUDGET:-16}
+export SOLVER_TYPE=${SOLVER_TYPE:-"asyn"}
+
+mkdir -p offload_logs
+
+python train_gpt_offload.py --model_type=${MODEL_TYPE} --memory_budget=${MEMORY_BUDGET} --solver_type=${SOLVER_TYPE} --batch_size=${BATCH_SIZE} 2>&1 | tee ./offload_logs/${MODEL_TYPE}_bs_${BATCH_SIZE}_st_${SOLVER_TYPE}.log
diff --git a/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
new file mode 100644
index 000000000000..729d1ce4456b
--- /dev/null
+++ b/examples/language/gpt/experiments/auto_offload/train_gpt_offload.py
@@ -0,0 +1,94 @@
+import time
+import pytest
+import argparse
+from functools import partial
+
+import torch
+from torch.utils._pytree import tree_map
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.fx.profiler import parameter_size
+from colossalai.utils import free_port, get_current_device
+from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
+from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
+from colossalai.auto_parallel.offload.solver import NOT_NVML
+from model_zoo import get_gpt2_components, GPTLMLoss
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_type', type=str, default="gpt2_medium")
+ parser.add_argument('--batch_size', type=int, default=64)
+ parser.add_argument('--solver_type', type=str, default='asyn')
+ parser.add_argument('--memory_budget', type=float, default=16)
+ return parser.parse_args()
+
+@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+def train_gpt(args):
+ memory_budget = args.memory_budget * 1024 * 1024 * 1024
+ solver_type = args.solver_type
+ model_type = args.model_type
+ batch_size = args.batch_size
+
+ # build model
+ model_builder, data_gen = get_gpt2_components(model_type=model_type, batch_size=batch_size)
+ label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
+ criterion = GPTLMLoss()
+
+ start_time = time.time()
+ model = model_builder()
+ model.train()
+ param_size = parameter_size(model) / 1024 ** 2 / 2
+ init_time = time.time() - start_time
+ print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
+
+ data_args = data_gen(device="cpu")
+ wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
+ data_args = tree_map(wrap_fn, data_args)
+ start_time = time.time()
+ model = memory_optimize(model, data_args, memory_budget, solver_type)
+ solver_time = time.time() - start_time
+ print(f"solver_time={solver_time:.3f} s")
+
+ hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
+ optim = AMPOptimizer(hybrid_optimizer, model)
+
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ time_list = []
+ data_args = data_gen(device="cuda")
+ data_args = tree_map(wrap_fn, data_args)
+ for step in range(10):
+ optim.zero_grad()
+ torch.cuda.synchronize()
+ start_time = time.time()
+ loss = criterion(model(**data_args), label)
+ optim.backward(loss)
+ torch.cuda.synchronize()
+ time_list.append(time.time() - start_time)
+ optim.step()
+
+ torch.cuda.synchronize()
+
+ exec_time = sum(sorted(time_list)[:5]) / 5
+ runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
+ runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
+ print(f'solver_type: {solver_type} | model_type: {model_type}')
+ print(
+ f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
+ f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
+ )
+ print(time_list)
+
+def run(rank, world_size, port, args):
+ config = {}
+ colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ train_gpt(args)
+
+if __name__ == '__main__':
+ args = parse_args()
+ run_func = partial(run, world_size=1, port=free_port(), args=args)
+ mp.spawn(run_func, nprocs=1)
diff --git a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
index 85c8d64d7809..6ceb7fd87c0a 100644
--- a/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
+++ b/examples/language/gpt/experiments/auto_parallel/auto_parallel_with_gpt.py
@@ -16,14 +16,14 @@
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
-BATCH_SIZE = 8
-SEQ_LENGTH = 128
-HIDDEN_DIM = 3072
+BATCH_SIZE = 16
+SEQ_LENGTH = 1024
+HIDDEN_DIM = 4096
NUM_HEADS = 16
-NUM_LAYERS = 1
+NUM_LAYERS = 4
VOCAB_SIZE = 50257
NUM_STEPS = 10
-FP16 = False
+FP16 = True
def get_cpu_mem():
@@ -40,7 +40,7 @@ def get_mem_info(prefix=''):
def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
- return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
+ return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 8
# Randomly Generated Data
@@ -66,13 +66,7 @@ def main():
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}
- # Both device mesh initialization and model initialization will be integrated into autoparallelize
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- # Enable auto-parallel
- gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
+ gm, solution = autoparallelize(model, meta_input_sample, return_solution=True)
# print solution on rank 0
if gpc.get_global_rank() == 0:
diff --git a/examples/language/gpt/experiments/auto_parallel/requirements.txt b/examples/language/gpt/experiments/auto_parallel/requirements.txt
index ff046ad1cae9..1b2561f098d5 100644
--- a/examples/language/gpt/experiments/auto_parallel/requirements.txt
+++ b/examples/language/gpt/experiments/auto_parallel/requirements.txt
@@ -1,4 +1,4 @@
colossalai >= 0.1.12
torch >= 1.8.1
-transformers >= 4.231
+transformers >= 4.23.1
PuLP >= 2.7.0
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt
new file mode 100644
index 000000000000..7b8cd7edd11e
Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_12_layers.pt differ
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt
new file mode 100644
index 000000000000..9b431a45baba
Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_1_layers.pt differ
diff --git a/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt
new file mode 100644
index 000000000000..79a448c1b06f
Binary files /dev/null and b/examples/language/gpt/experiments/auto_parallel/saved_solution/solution_4_layers.pt differ
diff --git a/examples/language/gpt/experiments/pipeline_parallel/requirements.txt b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt
new file mode 100644
index 000000000000..137a69e80498
--- /dev/null
+++ b/examples/language/gpt/experiments/pipeline_parallel/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
index 79efa61b0783..ad69888b8cc8 100644
--- a/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
+++ b/examples/language/gpt/experiments/pipeline_parallel/train_gpt_pp.py
@@ -8,11 +8,16 @@
from tqdm import tqdm
from colossalai.fx import ColoTracer
-from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
+from colossalai.fx.passes.adding_split_node_pass import (
+ avgnode_split_pass,
+ gpipe_dp_split_pass,
+ split_with_split_nodes_pass,
+)
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.middleware.adaptor import get_fx_topology
-from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
+from colossalai.pipeline.rpc._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine
from colossalai.pipeline.rpc.utils import rpc_run
@@ -55,13 +60,25 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
-def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
+# Create annotated model which is noted where to be splitted.
+def get_annotated_model(model, data_kwargs, num_stages, num_microbatches):
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
- annotated_model = avgnode_split_pass(gm, stage_num)
+ interp_meta_args = tuple([v.to('meta') for k, v in data_kwargs.items()])
+ interp = MetaInfoProp(gm)
+ interp.run(*interp_meta_args)
+
+ #annotated_model = avgnode_split_pass(gm, num_stages)
+ annotated_model = gpipe_dp_split_pass(gm, num_stages, num_microbatches, mode='block', block_limit=0.01)
+
+ return annotated_model
+
+
+def create_partition_module(pp_rank: int, num_stages: int, model, data_kwargs, num_microbatches):
+ annotated_model = get_annotated_model(model, data_kwargs, num_stages, num_microbatches)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
topo = get_fx_topology(top_module)
for submodule in split_submodules:
@@ -70,8 +87,8 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
return split_submodules[pp_rank + 1]
-def partition(model, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
- module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
+def partition(model, data_kwargs, num_microbatches, pp_rank: int, chunk: int, stage_num: int):
+ module = create_partition_module(pp_rank, stage_num, model, data_kwargs, num_microbatches)
return module
@@ -103,24 +120,26 @@ def run_master(args):
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
# create model
+ logger.info(f'start model_builder')
model = model_builder(model_type)(checkpoint=False)
+ logger.info(f'end model_builder')
# set 1f1b pipeline engine
- pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs),
- stage_num=stage_num,
- num_microbatches=num_microbatches,
- device=device,
- chunk=1,
- criterion=criterion,
- metric=None,
- checkpoint=False)
+ pp_engine = FillDrainPipelineEngine(partition_fn=partial(partition, model, warmup_data_kwargs, num_microbatches),
+ stage_num=stage_num,
+ num_microbatches=num_microbatches,
+ device=device,
+ chunk=1,
+ criterion=criterion,
+ metric=None,
+ checkpoint=False)
partition_numels = pp_engine.remote_numels()
for rank, numel in partition_numels.items():
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
- pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
+ pp_engine.initialize_optimizer(torch.optim.Adam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):
diff --git a/examples/language/gpt/gemini/benchmark_gemini.sh b/examples/language/gpt/gemini/benchmark_gemini.sh
index 13086666eefd..3a42e13645f6 100644
--- a/examples/language/gpt/gemini/benchmark_gemini.sh
+++ b/examples/language/gpt/gemini/benchmark_gemini.sh
@@ -1,18 +1,20 @@
for MODEL_TYPE in "gpt2_medium"; do
- for BATCH_SIZE in 16; do
- for GPUNUM in 1 2 4 8; do
- for TPDEGREE in 1 2 4 8; do
- if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
- continue
- fi
- for PLACEMENT in "cpu" "auto"; do
- echo "****************** Begin ***************************"
- echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
- MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
- bash ./gemini/run_gemini.sh
- echo "****************** Finished ***************************"
- echo ""
- echo ""
+ for DISTPLAN in "CAI_Gemini"; do
+ for BATCH_SIZE in 16; do
+ for GPUNUM in 1 2 4 8; do
+ for TPDEGREE in 1 2 4 8; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ for PLACEMENT in "cpu" "auto"; do
+ echo "****************** Begin ***************************"
+ echo "+ benchmrking MODEL ${MODEL_TYPE} DISTPLAN ${DISTPLAN} GPU ${GPUNUM} BS ${BATCH_SIZE} TP ${TPDEGREE} POLICY ${PLACEMENT}"
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
+ bash ./run_gemini.sh
+ echo "****************** Finished ***************************"
+ echo ""
+ echo ""
+ done
done
done
done
diff --git a/examples/language/gpt/gemini/commons/model_zoo.py b/examples/language/gpt/gemini/commons/model_zoo.py
index c31b3fa6d103..65124d9e4884 100644
--- a/examples/language/gpt/gemini/commons/model_zoo.py
+++ b/examples/language/gpt/gemini/commons/model_zoo.py
@@ -53,6 +53,14 @@ def gpt2_24b(checkpoint=True):
return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
+def gpt2_30b(checkpoint=True):
+ return GPTLMModel(hidden_size=8192, num_layers=37, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_40b(checkpoint=True):
+ return GPTLMModel(hidden_size=8192, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)
+
+
def model_builder(model_size: str) -> callable:
if model_size == "gpt2_medium":
return gpt2_medium
@@ -66,6 +74,10 @@ def model_builder(model_size: str) -> callable:
return gpt2_20b
elif model_size == "gpt2_24b":
return gpt2_24b
+ elif model_size == "gpt2_30b":
+ return gpt2_30b
+ elif model_size == "gpt2_40b":
+ return gpt2_40b
else:
raise TypeError(f"model_builder {model_size}")
diff --git a/examples/language/gpt/gemini/commons/utils.py b/examples/language/gpt/gemini/commons/utils.py
index 782f546dc26c..7bd098c1927c 100644
--- a/examples/language/gpt/gemini/commons/utils.py
+++ b/examples/language/gpt/gemini/commons/utils.py
@@ -1,4 +1,17 @@
+import time
+from contextlib import nullcontext
+
import torch
+from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
+
+
+class DummyProfiler:
+
+ def __init__(self):
+ self.step_number = 0
+
+ def step(self):
+ self.step_number += 1
# Randomly Generated Data
@@ -10,3 +23,19 @@ def get_data(batch_size, seq_len, vocab_size):
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+ if enable_flag:
+ return profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+ on_trace_ready=tensorboard_trace_handler(save_dir),
+ record_shapes=True,
+ profile_memory=True)
+ else:
+ return nullcontext(DummyProfiler())
+
+
+def get_time_stamp():
+ cur_time = time.strftime("%d-%H:%M", time.localtime())
+ return cur_time
diff --git a/examples/language/gpt/gemini/requirements.txt b/examples/language/gpt/gemini/requirements.txt
new file mode 100644
index 000000000000..137a69e80498
--- /dev/null
+++ b/examples/language/gpt/gemini/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/gpt/gemini/run_gemini.sh b/examples/language/gpt/gemini/run_gemini.sh
index ad577c350d39..ad4e9419c1bd 100644
--- a/examples/language/gpt/gemini/run_gemini.sh
+++ b/examples/language/gpt/gemini/run_gemini.sh
@@ -1,17 +1,23 @@
set -x
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPAN=${DISTPAN:-"colossalai"}
+# distplan in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]
+export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
-# The following options only valid when DISTPAN="colossalai"
+# The following options only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
-
+export TRAIN_STEP=${TRAIN_STEP:-10}
# export PYTHONPATH=$PWD:$PYTHONPATH
+if [ ${USE_SHARD_INIT} = "True" ]; then
+ USE_SHARD_INIT="--shardinit"
+else
+ USE_SHARD_INIT=""
+fi
+
mkdir -p gemini_logs
torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
@@ -19,6 +25,7 @@ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
---shardinit=${USE_SHARD_INIT} \
---distplan=${DISTPAN} \
-2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
+${USE_SHARD_INIT} \
+--distplan=${DISTPLAN} \
+--train_step=${TRAIN_STEP} \
+2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
diff --git a/examples/language/gpt/gemini/test_ci.sh b/examples/language/gpt/gemini/test_ci.sh
new file mode 100644
index 000000000000..6079d5ed615b
--- /dev/null
+++ b/examples/language/gpt/gemini/test_ci.sh
@@ -0,0 +1,35 @@
+set -x
+$(cd `dirname $0`;pwd)
+export TRAIN_STEP=4
+
+for MODEL_TYPE in "gpt2_medium"; do
+ for DISTPLAN in "colossalai"; do
+ for BATCH_SIZE in 2; do
+ for GPUNUM in 1 4; do
+ for TPDEGREE in 1 2; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ for PLACEMENT in "cpu" "auto"; do
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
+ bash ./run_gemini.sh
+ done
+ done
+ done
+ done
+ done
+
+ for DISTPLAN in "zero1" "zero2"; do
+ for BATCH_SIZE in 2; do
+ for GPUNUM in 1 4; do
+ for TPDEGREE in 1; do
+ if [ ${TPDEGREE} -gt ${GPUNUM} ]; then
+ continue
+ fi
+ MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
+ bash ./run_gemini.sh
+ done
+ done
+ done
+ done
+done
diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py
index 29f8c8ef1215..f46226bce2b5 100644
--- a/examples/language/gpt/gemini/train_gpt_demo.py
+++ b/examples/language/gpt/gemini/train_gpt_demo.py
@@ -6,32 +6,27 @@
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
-from commons.utils import get_data, get_tflops
+from commons.utils import get_data, get_profile_context, get_tflops, get_time_stamp
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.parallel import ZeroDDP
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
CAI_VERSION = colossalai.__version__
-if version.parse(CAI_VERSION) > version.parse("0.1.10"):
- # These are added after 0.1.10
- from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
- from colossalai.nn.parallel import GeminiDDP
- from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-
def parse_args():
parser = colossalai.get_default_parser()
parser.add_argument(
"--distplan",
type=str,
- default='colossalai',
+ default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
)
parser.add_argument(
@@ -48,8 +43,7 @@ def parse_args():
)
parser.add_argument(
"--shardinit",
- type=bool,
- default=False,
+ action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
@@ -65,6 +59,13 @@ def parse_args():
default="gpt2_medium",
help="model model scale",
)
+ parser.add_argument(
+ "--train_step",
+ type=int,
+ default=10,
+ help="training iterations for test",
+ )
+
args = parser.parse_args()
return args
@@ -179,56 +180,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
param.visited = True
-# Gemini + ZeRO DDP
-def build_gemini(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
- fp16_init_scale = 2**5
- gpu_margin_mem_ratio_for_auto = 0
-
- if version.parse(CAI_VERSION) > version.parse("0.1.10"):
- model = GeminiDDP(model,
- device=get_current_device(),
- placement_policy=placement_policy,
- pin_memory=True,
- hidden_dim=model.config.n_embd,
- search_range_mb=64)
- # configure the const policy
- if placement_policy == 'const':
- model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
- # build a highly optimized cpu optimizer
- optimizer = GeminiAdamOptimizer(model,
- lr=1e-3,
- initial_scale=fp16_init_scale,
- gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
- elif version.parse("0.1.9") <= version.parse(CAI_VERSION) <= version.parse("0.1.10"):
- from colossalai.gemini import ChunkManager, GeminiManager
- from colossalai.nn.optimizer import HybridAdam
- from colossalai.zero import ZeroOptimizer
- chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 1024, filter_exlarge_params=True)
- chunk_manager = ChunkManager(chunk_size,
- pg,
- enable_distributed_storage=True,
- init_device=GeminiManager.get_default_device(placement_policy))
- gemini_manager = GeminiManager(placement_policy, chunk_manager)
- model = ZeroDDP(model, gemini_manager)
- optimizer = HybridAdam(model.parameters(), lr=1e-3)
- optimizer = ZeroOptimizer(optimizer,
- model,
- initial_scale=fp16_init_scale,
- gpu_margin_mem_ratio=gpu_margin_mem_ratio_for_auto)
- else:
- raise NotImplemented(f"CAI version {CAI_VERSION} is not supported")
- return model, optimizer
-
-
def main():
# version check
- # this example is supposed to work for versions greater than 0.1.9
- assert version.parse(CAI_VERSION) >= version.parse("0.1.9")
+ # this example is supposed to work for versions greater than 0.2.0
+ assert version.parse(CAI_VERSION) >= version.parse("0.2.0")
set_cpu_maximum_parallelism()
args = parse_args()
- if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+ # if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+ if args.distplan not in ["CAI_ZeRO1", "CAI_ZeRO2", "CAI_Gemini", "Pytorch_DDP", "Pytorch_ZeRO"]:
raise TypeError(f"{args.distplan} is error")
# batch size per DP degree
@@ -236,10 +197,12 @@ def main():
SEQ_LEN = 1024
VOCAB_SIZE = 50257
- NUM_STEPS = 10
+ NUM_STEPS = args.train_step
+
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
- assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
+ assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
+ PROF_FLAG = False # The flag of profiling, False by default
disable_existing_loggers()
colossalai.launch_from_torch(config={})
@@ -251,49 +214,71 @@ def main():
criterion = GPTLMLoss()
torch.manual_seed(123)
- if args.distplan == "colossalai":
+ if args.distplan.startswith("CAI"):
# all param must use the same process group.
world_size = torch.distributed.get_world_size()
- shard_pg = ProcessGroup(tp_degree=world_size)
+ shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+ if args.shardinit and args.distplan != "CAI_Gemini":
+ raise RuntimeError("You can only use shardinit with CAI_Gemini")
+
# build GPT model
- if version.parse(CAI_VERSION) > version.parse("0.1.10"):
- with ColoInitContext(device=get_current_device(),
- dtype=torch.half,
- default_dist_spec=default_dist_spec,
- default_pg=shard_pg):
- model = model_builder(args.model_type)(checkpoint=True)
- else:
- with ColoInitContext(device=get_current_device()):
- model = model_builder(args.model_type)(checkpoint=True)
+ with ColoInitContext(device=get_current_device(),
+ dtype=torch.half,
+ default_dist_spec=default_dist_spec,
+ default_pg=shard_pg):
+ model = model_builder(args.model_type)(checkpoint=True)
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
- tensor_parallelize(model, tp_pg)
+ # You should notice that v0.1.10 is not compatible with TP degree > 1
+ if args.tp_degree > 1:
+ tensor_parallelize(model, tp_pg)
+
+ # asign running configurations
+ gemini_config = None
+ if args.distplan.startswith("CAI_ZeRO"):
+ optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+ elif args.distplan == "CAI_Gemini":
+ gemini_config = dict(strict_ddp_mode=args.tp_degree == 1,
+ device=get_current_device(),
+ placement_policy=args.placement,
+ pin_memory=True,
+ hidden_dim=model.config.n_embd,
+ search_range_mb=128)
+ optim_config = dict(gpu_margin_mem_ratio=0.)
+ else:
+ raise RuntimeError
- # build a Gemini model and a highly optimized cpu optimizer
- # Gemini + ZeRO DP, Note it must be used after TP
- model, optimizer = build_gemini(model, tp_pg, args.placement)
+ # build a highly optimized gpu/cpu optimizer
+ optimizer = HybridAdam(model.parameters(), lr=1e-3)
+
+ if args.distplan == "CAI_ZeRO1":
+ zero_stage = 1
+ elif args.distplan == "CAI_ZeRO2":
+ zero_stage = 2
+ elif args.distplan == "CAI_Gemini":
+ zero_stage = 3
+ else:
+ raise RuntimeError
+
+ # wrap your model and optimizer
+ model = zero_model_wrapper(model, zero_stage, gemini_config)
+ optimizer = zero_optim_wrapper(model, optimizer, optim_config=optim_config)
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
- else:
+ elif args.distplan.startswith("Pytorch"):
+ assert args.tp_degree == 1, "The degree of TP should be 1 for DDP examples."
model = model_builder(args.model_type)(checkpoint=True).cuda()
-
- if args.distplan.startswith("torch"):
model = DDP(model)
- if args.distplan.endswith("ddp"):
- optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
- elif args.distplan.endswith("zero"):
+ if args.distplan.endswith("DDP"):
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ elif args.distplan.endswith("ZeRO"):
from torch.distributed.optim import ZeroRedundancyOptimizer
- optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
- elif args.distplan.startswith("zero"):
- partition_flag = args.distplan == "zero2"
- optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
- optimizer = LowLevelZeroOptimizer(optimizer,
- overlap_communication=True,
- partition_grad=partition_flag,
- verbose=True)
+ optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
+ else:
+ raise RuntimeError
# model is shared after TP
numel = get_model_size(model)
@@ -308,7 +293,8 @@ def main():
torch.cuda.synchronize()
model.train()
tflops_list = []
- for n in range(NUM_STEPS):
+
+ def train_step():
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
@@ -321,17 +307,18 @@ def main():
fwd_time = fwd_end - start
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
- if args.distplan in ["colossalai", "zero1", "zero2"]:
+ if args.distplan.startswith("CAI"):
optimizer.backward(loss)
- elif args.distplan in ["torch_ddp", "torch_zero"]:
+ elif args.distplan.startswith("Pytorch"):
loss.backward()
+ else:
+ raise RuntimeError
+
torch.cuda.synchronize()
bwd_end = time()
bwd_time = bwd_end - fwd_end
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
- if args.distplan in ["zero1", "zero2"]:
- optimizer.sync_grad()
optimizer.step()
torch.cuda.synchronize()
optim_time = time() - bwd_end
@@ -346,6 +333,16 @@ def main():
if n >= WARMUP_STEPS:
tflops_list.append(step_tflops)
+ demo_profiler = get_profile_context(PROF_FLAG,
+ WARMUP_STEPS,
+ NUM_STEPS - WARMUP_STEPS,
+ save_dir=f"profile/{get_time_stamp()}-demo")
+
+ with demo_profiler as prof:
+ for n in range(NUM_STEPS):
+ train_step()
+ prof.step()
+
tflops_list.sort()
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
diff --git a/examples/language/gpt/requirements.txt b/examples/language/gpt/requirements.txt
index e1f131468fb8..ef58bb76bfc8 100644
--- a/examples/language/gpt/requirements.txt
+++ b/examples/language/gpt/requirements.txt
@@ -1 +1,2 @@
transformers >= 4.23
+colossalai
diff --git a/examples/language/gpt/test_ci.sh b/examples/language/gpt/test_ci.sh
index ad0cfa325d37..d67c17229e71 100644
--- a/examples/language/gpt/test_ci.sh
+++ b/examples/language/gpt/test_ci.sh
@@ -1,16 +1,2 @@
-pip install -r requirements.txt
-
-# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
-export DISTPAN="colossalai"
-
-# The following options only valid when DISTPAN="colossalai"
-export TPDEGREE=2
-export GPUNUM=4
-export PLACEMENT='cpu'
-export USE_SHARD_INIT=False
-export BATCH_SIZE=8
-export MODEL_TYPE="gpt2_medium"
-
-
-mkdir -p logs
-torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
+set -x
+cd gemini && bash test_ci.sh
diff --git a/examples/language/gpt/titans/LICENSE b/examples/language/gpt/titans/LICENSE
new file mode 100644
index 000000000000..261eeb9e9f8b
--- /dev/null
+++ b/examples/language/gpt/titans/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/examples/language/gpt/titans/README.md b/examples/language/gpt/titans/README.md
new file mode 100644
index 000000000000..e954f35fae0d
--- /dev/null
+++ b/examples/language/gpt/titans/README.md
@@ -0,0 +1,48 @@
+# Run GPT With Colossal-AI
+
+## How to Prepare Webtext Dataset
+
+You can download the preprocessed sample dataset for this demo via our [Google Drive sharing link](https://drive.google.com/file/d/1QKI6k-e2gJ7XgS8yIpgPPiMmwiBP_BPE/view?usp=sharing).
+
+
+You can also avoid dataset preparation by using `--use_dummy_dataset` during running.
+
+## Run this Demo
+
+Use the following commands to install prerequisites.
+
+```bash
+# assuming using cuda 11.3
+pip install -r requirements.txt
+```
+
+Use the following commands to execute training.
+
+```Bash
+#!/usr/bin/env sh
+# if you want to use real dataset, then remove --use_dummy_dataset
+# export DATA=/path/to/small-gpt-dataset.json'
+
+# run on a single node
+colossalai run --nproc_per_node= train_gpt.py --config configs/ --from_torch --use_dummy_dataset
+
+# run on multiple nodes
+colossalai run --nproc_per_node= \
+ --master_addr \
+ --master_port \
+ --hosts \
+ train_gpt.py \
+ --config configs/ \
+ --from_torch \
+ --use_dummy_dataset
+
+# run on multiple nodes with slurm
+srun python \
+ train_gpt.py \
+ --config configs/ \
+ --host \
+ --use_dummy_dataset
+
+```
+
+You can set the `` to any file in the `configs` folder. To simply get it running, you can start with `gpt_small_zero3_pp1d.py` on a single node first. You can view the explanations in the config file regarding how to change the parallel setting.
diff --git a/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
new file mode 100644
index 000000000000..7bf53303948a
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt2_small_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT2_small_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 8
+NUM_EPOCHS = 10
+SEQ_LEN = 1024
+NUM_MICRO_BATCHES = 4
+HIDDEN_SIZE = 768
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do no want zero, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+ optimizer_config=dict(initial_scale=2**5))
+
+optimizer = dict(
+ type=HybridAdam,
+ lr=0.000015,
+ weight_decay=1e-2,
+)
+
+model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'),
+)
diff --git a/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
new file mode 100644
index 000000000000..9f9816b3004f
--- /dev/null
+++ b/examples/language/gpt/titans/configs/gpt3_zero3_pp1d.py
@@ -0,0 +1,31 @@
+from model import GPT3_pipeline_hybrid
+
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.zero.shard_utils import TensorShardStrategy
+
+BATCH_SIZE = 192
+NUM_EPOCHS = 60
+SEQ_LEN = 2048
+NUM_MICRO_BATCHES = 192
+HIDDEN_SIZE = 12288
+TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
+
+# if you do no want zero, just comment out this dictionary
+zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
+ optimizer_config=dict(initial_scale=2**16))
+
+optimizer = dict(
+ type=HybridAdam,
+ lr=0.00015,
+ weight_decay=1e-2,
+)
+
+model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1)
+
+# pipeline parallel: modify integer value for the number of pipeline stages
+# tensor parallel: modify size to set the tensor parallel size, usually the number of GPUs per node
+# for the current model implementation, mode can only be 1D or None
+parallel = dict(
+ pipeline=1,
+ tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None
+)
diff --git a/examples/language/gpt/titans/dataset/webtext.py b/examples/language/gpt/titans/dataset/webtext.py
new file mode 100644
index 000000000000..64f5944a97f9
--- /dev/null
+++ b/examples/language/gpt/titans/dataset/webtext.py
@@ -0,0 +1,43 @@
+import json
+import os
+from typing import Optional
+
+import torch
+from torch.utils.data import Dataset
+from transformers import GPT2Tokenizer
+
+from colossalai.registry import DATASETS
+
+
+@DATASETS.register_module
+class WebtextDataset(Dataset):
+
+ def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
+ super().__init__()
+ if path is not None:
+ root = os.path.dirname(path)
+ encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
+ if os.path.isfile(encoded_data_cache_path):
+ seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
+ if seq_len_ == seq_len:
+ self.data = data
+ self.attention_mask = attention_mask
+ return
+ raw_data = []
+ with open(path) as f:
+ for line in f.readlines():
+ raw_data.append(json.loads(line)['text'])
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ tokenizer.pad_token = tokenizer.unk_token
+ encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
+ self.data = encoded_data['input_ids']
+ self.attention_mask = encoded_data['attention_mask']
+ else:
+ self.data = torch.randint(0, 50257, (10240, seq_len))
+ self.attention_mask = torch.ones_like(self.data)
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, index):
+ return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
diff --git a/examples/language/gpt/titans/model/__init__.py b/examples/language/gpt/titans/model/__init__.py
new file mode 100644
index 000000000000..eec48ef893fb
--- /dev/null
+++ b/examples/language/gpt/titans/model/__init__.py
@@ -0,0 +1,3 @@
+from .embed import vocab_parallel_cross_entropy
+from .gpt1d import *
+from .pipeline_gpt1d import *
diff --git a/examples/language/gpt/titans/model/embed.py b/examples/language/gpt/titans/model/embed.py
new file mode 100644
index 000000000000..6369b9f8c5a1
--- /dev/null
+++ b/examples/language/gpt/titans/model/embed.py
@@ -0,0 +1,599 @@
+import torch
+import torch.nn.init as init
+from torch import Tensor
+from torch import distributed as dist
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.parameter import Parameter
+
+from colossalai.context import ParallelMode, seed
+from colossalai.core import global_context as gpc
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
+from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
+from colossalai.nn.layer.utils import divide
+from colossalai.registry import LAYERS, LOSSES, MODELS
+from colossalai.utils import get_current_device
+
+
+class VocabParallelEmbedding(torch.nn.Module):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ num_tokentypes=0,
+ dtype=torch.float):
+ super(VocabParallelEmbedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.num_tokentypes = num_tokentypes
+
+ # Word embeddings (parallel).
+ self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
+ self._word_embeddings_key = 'word_embeddings'
+
+ # Position embedding (serial).
+ self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
+ self._position_embeddings_key = 'position_embeddings'
+ # Initialize the position embeddings.
+ # self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def zero_parameters(self):
+ """Zero out all parameters in embedding."""
+ self.word_embeddings.weight.data.fill_(0)
+ self.word_embeddings.weight.shared = True
+ self.position_embeddings.weight.data.fill_(0)
+ self.position_embeddings.weight.shared = True
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings.weight.data.fill_(0)
+ self.tokentype_embeddings.weight.shared = True
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+ # Embeddings.
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ words_embeddings = self.word_embeddings(input_ids)
+
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+ if position_ids is None:
+ position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = words_embeddings + position_embeddings
+
+ # Dropout.
+ with seed(ParallelMode.TENSOR):
+ embeddings = self.embedding_dropout(embeddings)
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Position embedding.
+ if self._position_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._position_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'position_embeddings' in key:
+ state_dict_[key.split('position_embeddings.')[1]] \
+ = state_dict[key]
+ self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it',
+ flush=True)
+
+
+class VocabParallelEmbedding1D(torch.nn.Module):
+ """Embedding parallelized in the vocabulary dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
+ super(VocabParallelEmbedding1D, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ # Set the details for compatibility.
+ self.padding_idx = None
+ self.max_norm = None
+ self.norm_type = 2.
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+ self.tensor_model_parallel_size = gpc.tensor_parallel_size
+ # Divide the weight matrix along the vocabulary dimension.
+ self.vocab_start_index, self.vocab_end_index = \
+ VocabUtility.vocab_range_from_global_vocab_size(
+ self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
+ self.tensor_model_parallel_size)
+ self.num_embeddings_per_partition = self.vocab_end_index - \
+ self.vocab_start_index
+
+ # Allocate weights and initialize.
+ factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+ self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
+ init.uniform_(self.weight, -1, 1)
+
+ def forward(self, input_):
+ if self.tensor_model_parallel_size > 1:
+ # Build the mask.
+ input_mask = (input_ < self.vocab_start_index) | \
+ (input_ >= self.vocab_end_index)
+ # Mask the input.
+ masked_input = input_.clone() - self.vocab_start_index
+ masked_input[input_mask] = 0
+ else:
+ masked_input = input_
+ # Get the embeddings.
+ output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.sparse)
+ # Mask the output embedding.
+ if self.tensor_model_parallel_size > 1:
+ output_parallel[input_mask, :] = 0.0
+ # Reduce across all the model parallel GPUs.
+ output = output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
+ return output
+
+
+@LOSSES.register_module
+class vocab_parallel_cross_entropy(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, vocab_parallel_logits, target):
+ """Helper function for the cross entropy."""
+ vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
+ target = target[..., 1:].contiguous()
+ return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
+ target.view(-1))
+
+
+class _VocabParallelCrossEntropy(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, vocab_parallel_logits, target):
+
+ # Maximum value along vocab dimension across all GPUs.
+ logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
+ torch.distributed.all_reduce(logits_max,
+ op=torch.distributed.ReduceOp.MAX,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+ # Subtract the maximum value.
+ vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
+
+ # Get the partition's vocab indices
+ get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
+ partition_vocab_size = vocab_parallel_logits.size()[-1]
+ rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+ world_size = gpc.tensor_parallel_size
+ vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ # Get predicted-logits = logits[target].
+ # For Simplicity, we convert logits to a 2-D tensor with size
+ # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
+ logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+ predicted_logits_1d = predicted_logits_1d.clone().contiguous()
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0
+ # All reduce is needed to get the chunks from other GPUs.
+ torch.distributed.all_reduce(predicted_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+ exp_logits = vocab_parallel_logits
+ torch.exp(vocab_parallel_logits, out=exp_logits)
+ sum_exp_logits = exp_logits.sum(dim=-1)
+ torch.distributed.all_reduce(sum_exp_logits,
+ op=torch.distributed.ReduceOp.SUM,
+ group=gpc.get_group(ParallelMode.PARALLEL_1D))
+
+ # Loss = log(sum(exp(logits))) - predicted-logit.
+ loss = torch.log(sum_exp_logits) - predicted_logits
+ loss = loss.mean()
+ # Store softmax, target-mask and masked-target for backward pass.
+ exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+ ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+
+ # Retreive tensors from the forward path.
+ softmax, target_mask, masked_target_1d = ctx.saved_tensors
+
+ # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+ grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+
+ # Finally elementwise multiplication with the output gradients.
+ grad_input.mul_(grad_output.unsqueeze(dim=-1))
+
+ return grad_input, None
+
+
+class VocabUtility:
+ """Split the vocabulary into `world_size` chunks amd return the
+ first and last index of the vocabulary belonging to the `rank`
+ partition: Note that indices in [fist, last)"""
+
+ @staticmethod
+ def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
+ index_f = rank * per_partition_vocab_size
+ index_l = index_f + per_partition_vocab_size
+ return index_f, index_l
+
+ @staticmethod
+ def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+ per_partition_vocab_size = divide(global_vocab_size, world_size)
+ return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
+
+
+class VocabParallelGPTLMHead1D(ParallelLayer):
+ """
+ Language model head that shares the same parameters with the embedding matrix.
+ """
+
+ def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
+ super().__init__()
+ if embed is not None:
+ self.head = embed
+ else:
+ self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = reduce_grad(x, ParallelMode.PARALLEL_1D)
+ x = F.linear(x, self.head.weight)
+ return x
+
+
+###################################
+
+
+class HiddenParallelEmbedding(torch.nn.Module):
+ """Language model embeddings.
+
+ Arguments:
+ hidden_size: hidden size
+ vocab_size: vocabulary size
+ max_sequence_length: maximum size of sequence. This
+ is used for positional embedding
+ embedding_dropout_prob: dropout probability for embeddings
+ init_method: weight initialization method
+ num_tokentypes: size of the token-type embeddings. 0 value
+ will ignore this embedding
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ vocab_size,
+ max_sequence_length,
+ embedding_dropout_prob,
+ dtype=torch.float,
+ padding_idx: int = 0,
+ num_tokentypes=0,
+ ):
+ super(HiddenParallelEmbedding, self).__init__()
+
+ self.hidden_size = hidden_size
+ self.num_tokentypes = num_tokentypes
+
+ # Word embeddings (parallel).
+ self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+ self._word_embeddings_key = 'word_embeddings'
+
+ # Position embedding (serial).
+ self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
+ self._position_embeddings_key = 'position_embeddings'
+ # Initialize the position embeddings.
+ # self.init_method(self.position_embeddings.weight)
+
+ # Token type embedding.
+ # Add this as an optional field that can be added through
+ # method call so we can load a pretrain model without
+ # token types and add them as needed.
+ self._tokentype_embeddings_key = 'tokentype_embeddings'
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+ else:
+ self.tokentype_embeddings = None
+
+ # Embeddings dropout
+ self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
+
+ def zero_parameters(self):
+ """Zero out all parameters in embedding."""
+ self.word_embeddings.weight.data.fill_(0)
+ self.word_embeddings.weight.shared = True
+ self.position_embeddings.weight.data.fill_(0)
+ self.position_embeddings.weight.shared = True
+ if self.num_tokentypes > 0:
+ self.tokentype_embeddings.weight.data.fill_(0)
+ self.tokentype_embeddings.weight.shared = True
+
+ def add_tokentype_embeddings(self, num_tokentypes):
+ """Add token-type embedding. This function is provided so we can add
+ token-type embeddings in case the pretrained model does not have it.
+ This allows us to load the model normally and then add this embedding.
+ """
+ if self.tokentype_embeddings is not None:
+ raise Exception('tokentype embeddings is already initialized')
+ if torch.distributed.get_rank() == 0:
+ print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+ self.num_tokentypes = num_tokentypes
+ self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
+ # Initialize the token-type embeddings.
+ # self.init_method(self.tokentype_embeddings.weight)
+
+ def forward(self, input_ids, position_ids=None, tokentype_ids=None):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ words_embeddings = self.word_embeddings(input_ids)
+
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+ if position_ids is None:
+ position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+ position_embeddings = self.position_embeddings(position_ids)
+
+ embeddings = words_embeddings + position_embeddings
+
+ # Dropout.
+ with seed(ParallelMode.TENSOR):
+ embeddings = self.embedding_dropout(embeddings)
+ return embeddings
+
+ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+ """For easy load."""
+
+ state_dict_ = {}
+ state_dict_[self._word_embeddings_key] \
+ = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+ state_dict_[self._position_embeddings_key] \
+ = self.position_embeddings.state_dict(
+ destination, prefix, keep_vars)
+ if self.num_tokentypes > 0:
+ state_dict_[self._tokentype_embeddings_key] \
+ = self.tokentype_embeddings.state_dict(
+ destination, prefix, keep_vars)
+
+ return state_dict_
+
+ def load_state_dict(self, state_dict, strict=True):
+ """Customized load."""
+
+ # Word embedding.
+ if self._word_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._word_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'word_embeddings' in key:
+ state_dict_[key.split('word_embeddings.')[1]] \
+ = state_dict[key]
+ self.word_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Position embedding.
+ if self._position_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._position_embeddings_key]
+ else:
+ # for backward compatibility.
+ state_dict_ = {}
+ for key in state_dict.keys():
+ if 'position_embeddings' in key:
+ state_dict_[key.split('position_embeddings.')[1]] \
+ = state_dict[key]
+ self.position_embeddings.load_state_dict(state_dict_, strict=strict)
+
+ # Tokentype embedding.
+ if self.num_tokentypes > 0:
+ state_dict_ = {}
+ if self._tokentype_embeddings_key in state_dict:
+ state_dict_ = state_dict[self._tokentype_embeddings_key]
+ else:
+ # for backward compatibility.
+ for key in state_dict.keys():
+ if 'tokentype_embeddings' in key:
+ state_dict_[key.split('tokentype_embeddings.')[1]] \
+ = state_dict[key]
+ if len(state_dict_.keys()) > 0:
+ self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
+ else:
+ print('***WARNING*** expected tokentype embeddings in the '
+ 'checkpoint but could not find it',
+ flush=True)
+
+
+class HiddenParallelEmbedding1D(torch.nn.Module):
+ """Embedding parallelized in the vocabulary dimension.
+
+ This is mainly adapted from torch.nn.Embedding and all the default
+ values are kept.
+ Arguments:
+ num_embeddings: vocabulary size.
+ embedding_dim: size of hidden state.
+ init_method: method to initialize weights.
+ """
+
+ def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
+ super(HiddenParallelEmbedding1D, self).__init__()
+ # Keep the input dimensions.
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
+ # Set the details for compatibility.
+ self.padding_idx = padding_idx
+ self.max_norm = None
+ self.norm_type = 2.
+ self.scale_grad_by_freq = False
+ self.sparse = False
+ self._weight = None
+
+ # Allocate weights and initialize.
+ factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+ self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
+ init.uniform_(self.weight, -1, 1)
+
+ def forward(self, input_):
+
+ # Get the embeddings.
+ output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
+ self.scale_grad_by_freq, self.sparse)
+
+ # Reduce across all the model parallel GPUs.
+ output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+ return output
+
+
+@LAYERS.register_module
+class HiddenParallelGPTLMHead1D(ParallelLayer):
+ """
+ Language model head that shares the same parameters with the embedding matrix.
+ """
+
+ def __init__(
+ self,
+ embed=None,
+ embed_dim=None,
+ vocab_size=None,
+ dtype=None,
+ ):
+ super().__init__()
+ if embed is not None:
+ self.head = embed
+ self.synced_embed = True
+ else:
+ # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
+ # (hidden_size/q, vocab_size)
+ self.synced_embed = False
+ self.head = Linear1D_Row(in_features=embed_dim,
+ out_features=vocab_size,
+ bias=False,
+ dtype=dtype,
+ parallel_input=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ if self.synced_embed:
+ x = F.linear(x, self.head.weight)
+ else:
+ x = self.head(x)
+
+ return x
diff --git a/examples/language/gpt/titans/model/gpt1d.py b/examples/language/gpt/titans/model/gpt1d.py
new file mode 100644
index 000000000000..2edd03606b7d
--- /dev/null
+++ b/examples/language/gpt/titans/model/gpt1d.py
@@ -0,0 +1,349 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+import math
+
+import torch
+from torch import Tensor
+from torch import nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.core import global_context as gpc
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
+from colossalai.nn.layer.base_layer import ParallelLayer
+from colossalai.nn.layer.utils import ACT2FN, divide
+from colossalai.utils import checkpoint
+from colossalai.utils.activation_checkpoint import checkpoint
+
+__all__ = [
+ 'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
+]
+
+
+class GPTMLP1D(ParallelLayer):
+
+ def __init__(
+ self,
+ in_features: int,
+ mlp_ratio: int,
+ act_func: str = 'gelu',
+ dropout_prob: float = 0.,
+ dtype=None,
+ checkpoint: bool = False,
+ skip_bias_add: bool = False,
+ ):
+ super().__init__()
+
+ self.in_features = in_features
+ self.mlp_ratio = mlp_ratio
+ self.checkpoint = checkpoint
+ self.skip_bias_add = skip_bias_add
+
+ self.act = ACT2FN[act_func]
+ skip_dense_1_add_bias = False
+
+ # Project to mlp_ratio * h.
+ self.dense_1 = Linear1D_Col(
+ self.in_features,
+ int(self.mlp_ratio * self.in_features),
+ dtype=dtype,
+ gather_output=False,
+ skip_bias_add=skip_dense_1_add_bias,
+ )
+
+ # Project back to h.
+ self.dense_2 = Linear1D_Row(
+ int(self.mlp_ratio * self.in_features),
+ self.in_features,
+ dtype=dtype,
+ parallel_input=True,
+ )
+
+ self.dropout = col_nn.Dropout(dropout_prob)
+
+ def _forward(self, hidden_states: Tensor) -> Tensor:
+ intermediate_output = self.dense_1(hidden_states)
+ intermediate_output = self.act(intermediate_output)
+
+ output = self.dense_2(intermediate_output)
+ output = self.dropout(output)
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
+ return checkpoint(self._forward, False, hidden_states)
+
+ def forward(self, hidden_states: Tensor) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states)
+ else:
+ return self._forward(hidden_states)
+
+
+class GenericGPTSelfAttention1D(ParallelLayer):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024,
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.attention_head_size = divide(hidden_size, num_attention_heads)
+ self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
+ self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
+ self.checkpoint = checkpoint
+ self.query_key_value = Linear1D_Col(
+ hidden_size,
+ 3 * hidden_size,
+ dtype=dtype,
+ )
+ self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
+ self.dense = Linear1D_Row(
+ hidden_size,
+ hidden_size,
+ dtype=dtype,
+ parallel_input=True,
+ )
+ self.dropout = col_nn.Dropout(hidden_dropout_prob)
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ raise NotImplementedError
+
+ def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ query_key_value = self.query_key_value(hidden_states)
+ new_qkv_shape = query_key_value.shape[:-1] + \
+ (self.num_attention_heads_per_partition, 3 * self.attention_head_size)
+ query_key_value = query_key_value.view(new_qkv_shape)
+ query_key_value = query_key_value.permute((0, 2, 1, 3))
+ query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
+
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)
+
+ attention_scores = attention_scores.type(value_layer.dtype)
+
+ attention_probs = self.attention_dropout(attention_scores)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.transpose(1, 2)
+ new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+ context_layer = context_layer.reshape(new_context_layer_shape)
+ output = self.dense(context_layer)
+ output = self.dropout(output)
+
+ return output
+
+ def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ return checkpoint(self._forward, False, hidden_states, attention_mask)
+
+ def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
+ if self.checkpoint:
+ return self._checkpoint_forward(hidden_states, attention_mask)
+ else:
+ return self._forward(hidden_states, attention_mask)
+
+
+class GPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024):
+ super().__init__(hidden_size,
+ num_attention_heads,
+ attention_dropout_prob,
+ hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings)
+ self.softmax = nn.Softmax(dim=-1)
+ max_positions = max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions),
+ dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
+ )
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # causal mask
+ query_length, key_length = query_layer.size(-2), key_layer.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
+ attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
+ if attention_mask is not None:
+ # Apply the attention mask
+ attention_scores = attention_scores + attention_mask
+ attention_scores = self.softmax(attention_scores)
+ return attention_scores
+
+
+class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ attention_dropout_prob: float,
+ hidden_dropout_prob: float,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings=1024):
+ super().__init__(hidden_size,
+ num_attention_heads,
+ attention_dropout_prob,
+ hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings)
+ self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
+ input_in_bf16=False,
+ attn_mask_type=AttnMaskType.causal,
+ scaled_masked_softmax_fusion=True,
+ mask_func=None,
+ softmax_in_fp32=True,
+ scale=math.sqrt(self.attention_head_size))
+
+ def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
+ return self.softmax(attention_scores, attention_mask)
+
+
+class GenericGPTTransformerLayer1D(ParallelLayer):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4.0,
+ attention_dropout_prob: float = 0.,
+ hidden_dropout_prob: float = 0.,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ attention=None,
+ layer_norm=None):
+ super().__init__()
+ self.checkpoint = checkpoint
+ self.dtype = dtype
+ self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+ self.apply_post_layer_norm = apply_post_layer_norm
+ self.attention = attention(
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ max_position_embeddings=max_position_embeddings,
+ checkpoint=False,
+ )
+
+ self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
+ self.mlp = GPTMLP1D(
+ in_features=hidden_size,
+ dropout_prob=hidden_dropout_prob,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ dtype=dtype,
+ checkpoint=False,
+ )
+
+ def _forward(self, hidden_states, attention_mask) -> Tensor:
+ if not self.apply_post_layer_norm:
+ residual = hidden_states
+ hidden_states = self.norm1(hidden_states)
+ if self.apply_post_layer_norm:
+ residual = hidden_states
+ attention_output = self.attention(hidden_states, attention_mask)
+ hidden_states = residual + attention_output
+
+ if not self.apply_post_layer_norm:
+ residual = hidden_states
+ hidden_states = self.norm2(hidden_states)
+ if self.apply_post_layer_norm:
+ residual = hidden_states
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + feed_forward_hidden_states
+
+ output = (hidden_states, attention_mask)
+ return output
+
+ def forward(self, hidden_states, attention_mask):
+ if self.checkpoint:
+ return checkpoint(self._forward, False, hidden_states, attention_mask)
+ else:
+ return self._forward(hidden_states, attention_mask)
+
+
+class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4,
+ attention_dropout_prob: float = 0,
+ hidden_dropout_prob: float = 0,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 0.00001,
+ apply_post_layer_norm: bool = False):
+ attention = GPTSelfAttention1D
+ layer_norm = nn.LayerNorm
+ super().__init__(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm,
+ attention=attention,
+ layer_norm=layer_norm)
+
+
+class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
+
+ def __init__(self,
+ hidden_size: int,
+ num_attention_heads: int,
+ act_func: str = 'gelu',
+ mlp_ratio: float = 4,
+ attention_dropout_prob: float = 0,
+ hidden_dropout_prob: float = 0,
+ dtype=None,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 0.00001,
+ apply_post_layer_norm: bool = False):
+ attention = FusedGPTSelfAttention1D
+ layer_norm = kernel.LayerNorm
+ super().__init__(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attention_dropout_prob,
+ hidden_dropout_prob=hidden_dropout_prob,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm,
+ attention=attention,
+ layer_norm=layer_norm)
diff --git a/examples/language/gpt/titans/model/pipeline_gpt1d.py b/examples/language/gpt/titans/model/pipeline_gpt1d.py
new file mode 100644
index 000000000000..30180285bc70
--- /dev/null
+++ b/examples/language/gpt/titans/model/pipeline_gpt1d.py
@@ -0,0 +1,322 @@
+import inspect
+
+# import model_zoo.gpt.gpt as col_gpt
+import titans.model.gpt.gpt as col_gpt
+import torch
+import torch.nn as nn
+
+from colossalai import kernel
+from colossalai import nn as col_nn
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
+from colossalai.pipeline.utils import partition_uniform
+
+from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
+from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
+
+__all__ = [
+ 'GPT2_small_pipeline_1D',
+ 'GPT2_exlarge_pipeline_1D',
+ 'GPT3_pipeline_1D',
+ 'GPT2_exlarge_pipeline_hybrid',
+ 'GPT2_small_pipeline_hybrid',
+ 'GPT3_pipeline_hybrid',
+]
+
+
+class GenericPipelineGPT(nn.Module):
+
+ def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
+ super().__init__()
+ self.embedding = embedding
+ self.blocks = blocks
+ self.norm = norm
+ self.head = head
+ assert blocks is not None
+ if norm is not None or head is not None:
+ assert norm is not None and head is not None
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ batch_size = hidden_states.shape[0]
+ attention_mask = attention_mask.view(batch_size, -1)
+ attention_mask = attention_mask[:, None, None, :]
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
+
+
+class PipelineGPT1D(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: int = 4.0,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ embed_cls = VocabParallelEmbedding
+ head_cls = VocabParallelGPTLMHead1D
+ if embed_split_hidden:
+ embed_cls = HiddenParallelEmbedding
+ head_cls = HiddenParallelGPTLMHead1D
+ if first:
+ embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+ blocks = nn.ModuleList([
+ GPTTransformerLayer1D(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attn_drop_rate,
+ hidden_dropout_prob=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+ ])
+ if last:
+ norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+class FusedPipelineGPT1D(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: int = 4.0,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ embed_cls = VocabParallelEmbedding
+ head_cls = VocabParallelGPTLMHead1D
+ if embed_split_hidden:
+ embed_cls = HiddenParallelEmbedding
+ head_cls = HiddenParallelGPTLMHead1D
+ if first:
+ embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
+ blocks = nn.ModuleList([
+ FusedGPTTransformerLayer1D(hidden_size,
+ num_attention_heads,
+ act_func=act_func,
+ mlp_ratio=mlp_ratio,
+ attention_dropout_prob=attn_drop_rate,
+ hidden_dropout_prob=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ max_position_embeddings=max_position_embeddings,
+ layer_norm_epsilon=layer_norm_epsilon,
+ apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
+ ])
+ if last:
+ norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
+ if self.embedding is not None:
+ hidden_states = self.embedding(input_ids=input_ids)
+ attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
+ for block in self.blocks:
+ hidden_states, attention_mask = block(hidden_states, attention_mask)
+ if self.norm is not None:
+ hidden_states = self.head(self.norm(hidden_states))
+ return hidden_states
+
+
+class PipelineGPTHybrid(GenericPipelineGPT):
+
+ def __init__(self,
+ num_layers: int = 12,
+ hidden_size: int = 768,
+ num_attention_heads: int = 12,
+ vocab_size: int = 50304,
+ embed_drop_rate: float = 0.,
+ act_func: str = 'gelu',
+ mlp_ratio: int = 4,
+ attn_drop_rate: float = 0.,
+ drop_rate: float = 0.,
+ dtype: torch.dtype = torch.float,
+ checkpoint: bool = False,
+ max_position_embeddings: int = 1024,
+ layer_norm_epsilon: float = 1e-5,
+ apply_post_layer_norm: bool = False,
+ first: bool = False,
+ last: bool = False,
+ embed_split_hidden=False):
+ embedding = None
+ norm = None
+ head = None
+ if first:
+ embedding = col_gpt.GPTEmbedding(hidden_size,
+ vocab_size,
+ max_position_embeddings,
+ dropout=embed_drop_rate,
+ dtype=dtype)
+ blocks = nn.ModuleList([
+ col_gpt.GPTBlock(hidden_size,
+ num_attention_heads,
+ mlp_ratio=mlp_ratio,
+ attention_dropout=attn_drop_rate,
+ dropout=drop_rate,
+ dtype=dtype,
+ checkpoint=checkpoint,
+ activation=nn.functional.gelu) for _ in range(num_layers)
+ ])
+ if last:
+ norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
+ # head = col_gpt.GPTLMHead(vocab_size=vocab_size,
+ # hidden_size=hidden_size,
+ # dtype=dtype,
+ # bias=False)
+ head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
+ super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
+
+
+def _filter_kwargs(func, kwargs):
+ sig = inspect.signature(func)
+ return {k: v for k, v in kwargs.items() if k in sig.parameters}
+
+
+def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ logger = get_dist_logger()
+
+ if gpc.is_initialized(ParallelMode.PIPELINE):
+ pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+ pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+ else:
+ pipeline_size = 1
+ pipeline_rank = 0
+ rank = gpc.get_global_rank()
+
+ if pipeline_size > 1:
+ wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
+ else:
+ wrapper = None
+ parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
+ models = []
+ for start, end in parts:
+ kwargs['num_layers'] = end - start
+ kwargs['first'] = start == 0
+ kwargs['last'] = end == num_layers
+ logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+ chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
+
+ if wrapper is not None:
+ if start == 0:
+ wrapper.register_module(chunk.embedding.word_embeddings)
+ elif end == num_layers:
+ wrapper.register_module(chunk.head)
+ models.append(chunk)
+ if len(models) == 1:
+ model = models[0]
+ else:
+ model = nn.ModuleList(models)
+
+ numel = 0
+ for _, param in model.named_parameters(recurse=True):
+ numel += param.numel()
+ logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
+ return model
+
+
+def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
+ model = FusedPipelineGPT1D if fused else PipelineGPT1D
+ return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
+
+
+def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+ return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
+
+
+def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=768,
+ num_attention_heads=12,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=1600,
+ num_attention_heads=32,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
+
+
+def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
+ cfg = dict(hidden_size=12288,
+ num_attention_heads=96,
+ checkpoint=checkpoint,
+ max_position_embeddings=2048,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
+
+
+def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=1600,
+ num_attention_heads=32,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
+
+
+def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=768,
+ num_attention_heads=12,
+ checkpoint=checkpoint,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
+
+
+def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
+ cfg = dict(hidden_size=12288,
+ num_attention_heads=96,
+ checkpoint=checkpoint,
+ max_position_embeddings=2048,
+ dtype=dtype,
+ embed_split_hidden=embed_split_hidden)
+ return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
diff --git a/examples/language/gpt/titans/requirements.txt b/examples/language/gpt/titans/requirements.txt
new file mode 100644
index 000000000000..64ff7a4abcd8
--- /dev/null
+++ b/examples/language/gpt/titans/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.12.1
+titans==0.0.7
+colossalai==0.2.0+torch1.12cu11.3
+-f https://release.colossalai.org
diff --git a/examples/language/gpt/titans/run.sh b/examples/language/gpt/titans/run.sh
new file mode 100644
index 000000000000..a1a7fc737db0
--- /dev/null
+++ b/examples/language/gpt/titans/run.sh
@@ -0,0 +1,3 @@
+export DATA=/data/scratch/gpt_data/small-gpt-dataset.json
+DUMMY_DATA=--use_dummy_dataset
+colossalai run --nproc_per_node=2 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch $DUMMY_DATA
diff --git a/examples/language/gpt/titans/test_ci.sh b/examples/language/gpt/titans/test_ci.sh
new file mode 100644
index 000000000000..7cb24c1a4082
--- /dev/null
+++ b/examples/language/gpt/titans/test_ci.sh
@@ -0,0 +1 @@
+colossalai run --nproc_per_node=4 train_gpt.py --config ./configs/gpt2_small_zero3_pp1d.py --from_torch --use_dummy_dataset
diff --git a/examples/language/gpt/titans/train_gpt.py b/examples/language/gpt/titans/train_gpt.py
new file mode 100644
index 000000000000..66225d6c8044
--- /dev/null
+++ b/examples/language/gpt/titans/train_gpt.py
@@ -0,0 +1,113 @@
+import contextlib
+import os
+
+import torch
+import torch.nn as nn
+from dataset.webtext import WebtextDataset
+from titans.model.gpt import GPTLMLoss
+
+import colossalai
+import colossalai.utils as utils
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn import LinearWarmupLR
+from colossalai.trainer import Trainer, hooks
+from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
+from colossalai.utils.timer import MultiTimer
+from colossalai.zero.init_ctx import ZeroInitContext
+
+
+def calc_local_model_size(model: torch.nn.Module):
+ numel_per_device = 0
+ for p in model.parameters():
+ numel_per_device += p.numel()
+ return numel_per_device
+
+
+VOCAB_SIZE = 50257
+
+
+def main():
+ parser = colossalai.get_default_parser()
+ parser.add_argument('--from_torch', default=False, action='store_true')
+ parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
+ args = parser.parse_args()
+ disable_existing_loggers()
+ if args.from_torch:
+ colossalai.launch_from_torch(config=args.config)
+ else:
+ colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
+ logger = get_dist_logger()
+
+ data_path = None if args.use_dummy_dataset else os.environ['DATA']
+ logger.info(f'Build data loader from path {data_path}', ranks=[0])
+
+ train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
+ train_dataloader = utils.get_dataloader(train_ds,
+ seed=42,
+ batch_size=gpc.config.BATCH_SIZE,
+ pin_memory=True,
+ shuffle=True,
+ drop_last=True)
+
+ logger.info('Build model', ranks=[0])
+ use_pipeline = is_using_pp()
+ use_interleaved = hasattr(gpc.config.model, 'num_chunks')
+ use_zero3 = hasattr(gpc.config, 'zero')
+ ctx = contextlib.nullcontext()
+ if use_zero3:
+ ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
+ shard_strategy=gpc.config.zero.model_config.shard_strategy,
+ shard_param=True)
+ with ctx:
+ model = gpc.config.model.pop('type')(**gpc.config.model)
+ if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
+ model = nn.ModuleList([model])
+
+ if use_zero3:
+ numel = ctx.model_numel_tensor.item()
+ else:
+ numel = calc_local_model_size(model)
+
+ tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
+ * gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
+
+ criterion = getattr(gpc.config, 'loss_fn', None)
+ if criterion is not None:
+ criterion = criterion.type()
+ else:
+ criterion = GPTLMLoss()
+ logger.info('Build optimizer', ranks=[0])
+ optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
+ lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
+ engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
+ optimizer,
+ criterion,
+ train_dataloader=train_dataloader,
+ lr_scheduler=lr_scheduler)
+ global_batch_size = gpc.config.BATCH_SIZE * \
+ gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
+ logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
+ timier = MultiTimer()
+ trainer = Trainer(engine=engine, logger=logger, timer=timier)
+ hook_list = [
+ hooks.LossHook(),
+ hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=True),
+ hooks.LogMetricByEpochHook(logger),
+ hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
+ hooks.LogMetricByStepHook(),
+ hooks.LogMemoryByEpochHook(logger),
+ # hooks.LogMemoryByEpochHook(logger),
+ # hooks.LogTimingByEpochHook(timer, logger),
+ ]
+ trainer.fit(train_dataloader=train_dataloader,
+ epochs=gpc.config.NUM_EPOCHS,
+ test_interval=1,
+ hooks=hook_list,
+ display_progress=True,
+ return_output_label=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/language/opt/requirements.txt b/examples/language/opt/requirements.txt
new file mode 100644
index 000000000000..137a69e80498
--- /dev/null
+++ b/examples/language/opt/requirements.txt
@@ -0,0 +1,2 @@
+colossalai >= 0.1.12
+torch >= 1.8.1
diff --git a/examples/language/opt/run_gemini.sh b/examples/language/opt/run_gemini.sh
index d9625723a1ae..73f231292a13 100644
--- a/examples/language/opt/run_gemini.sh
+++ b/examples/language/opt/run_gemini.sh
@@ -1,13 +1,20 @@
set -x
export BS=${BS:-16}
export MEMCAP=${MEMCAP:-0}
-# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7`, `13b`, `30b`, `66b`. For `175b`
+# Acceptable values include `125m`, `350m`, `1.3b`, `2.7b`, `6.7b`, `13b`, `30b`, `66b`. For `175b`
export MODEL=${MODEL:-"125m"}
export GPUNUM=${GPUNUM:-1}
+export USE_SHARD_INIT=${USE_SHARD_INIT:-"false"}
# make directory for logs
mkdir -p ./logs
+if [ ${USE_SHARD_INIT} = "true" ]; then
+ USE_SHARD_INIT="--shardinit"
+else
+ USE_SHARD_INIT=""
+fi
+
export MODLE_PATH="facebook/opt-${MODEL}"
# HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
@@ -17,4 +24,5 @@ torchrun \
train_gemini_opt.py \
--mem_cap ${MEMCAP} \
--model_name_or_path ${MODLE_PATH} \
+ ${USE_SHARD_INIT} \
--batch_size ${BS} 2>&1 | tee ./logs/colo_${MODEL}_bs_${BS}_cap_${MEMCAP}_gpu_${GPUNUM}.log
diff --git a/examples/language/opt/test_ci.sh b/examples/language/opt/test_ci.sh
new file mode 100644
index 000000000000..317f602cda3c
--- /dev/null
+++ b/examples/language/opt/test_ci.sh
@@ -0,0 +1,4 @@
+for GPUNUM in 2 1
+do
+env BS=2 MODEL="125m" GPUNUM=$GPUNUM bash ./run_gemini.sh
+done
diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py
index 64426ba4285c..4993ce25db17 100755
--- a/examples/language/opt/train_gemini_opt.py
+++ b/examples/language/opt/train_gemini_opt.py
@@ -39,6 +39,8 @@
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.tensor import ProcessGroup, ShardSpec
+
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
@@ -102,6 +104,11 @@ def parse_args():
help="Model type to use if training from scratch.",
choices=MODEL_TYPES,
)
+ parser.add_argument(
+ "--shardinit",
+ action="store_true",
+ help="Initialize the model with tensor parallel",
+ )
parser.add_argument("--mem_cap", type=int, default=0, help="use mem cap")
parser.add_argument("--init_in_cpu", action='store_true', default=False, help="init training model in cpu")
args = parser.parse_args()
@@ -159,16 +166,28 @@ def main():
else:
init_dev = get_current_device()
+ # shard init prameters
+ if args.shardinit:
+ logger.info("Sharding initialization !", ranks=[0])
+ else:
+ logger.info("Skipping sharding initialization", ranks=[0])
+
+ world_size = torch.distributed.get_world_size()
+ shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
+ default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
+
# build model
- if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
- # currently, there has a bug in pretrained opt-13b
- # we can not import it until huggingface fix it
+ if args.model_name_or_path is None:
logger.info("Train a new model from scratch", ranks=[0])
- with ColoInitContext(device=init_dev, dtype=torch.half):
+ with ColoInitContext(device=init_dev, dtype=torch.half,
+ default_dist_spec=default_dist_spec,
+ default_pg=shard_pg):
model = OPTForCausalLM(config)
else:
logger.info("Finetune a pre-trained model", ranks=[0])
- with ColoInitContext(device=init_dev, dtype=torch.half):
+ with ColoInitContext(device=init_dev, dtype=torch.half,
+ default_dist_spec=default_dist_spec,
+ default_pg=shard_pg):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
@@ -179,7 +198,8 @@ def main():
numel = sum([p.numel() for p in model.parameters()])
PLACEMENT_POLICY = 'cpu'
- model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, pin_memory=True)
+ model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY,
+ pin_memory=True, strict_ddp_mode=args.shardinit)
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
SEQ_LEN = 1024
diff --git a/examples/language/palm/run.sh b/examples/language/palm/run.sh
index 4aa868953f7b..7a533509e009 100644
--- a/examples/language/palm/run.sh
+++ b/examples/language/palm/run.sh
@@ -8,4 +8,4 @@ export PLACEMENT='cpu'
export USE_SHARD_INIT=False
export BATCH_SIZE=4
-env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train_new.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
\ No newline at end of file
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --tp_degree=${TPDEGREE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
diff --git a/examples/language/palm/test_ci.sh b/examples/language/palm/test_ci.sh
new file mode 100644
index 000000000000..f21095578077
--- /dev/null
+++ b/examples/language/palm/test_ci.sh
@@ -0,0 +1,9 @@
+$(cd `dirname $0`;pwd)
+
+for BATCH_SIZE in 2
+do
+for GPUNUM in 1 4
+do
+env OMP_NUM_THREADS=12 torchrun --standalone --nproc_per_node=${GPUNUM} --master_port 29501 train.py --dummy_data=True --batch_size=${BATCH_SIZE} 2>&1 | tee run.log
+done
+done
diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py
index 7c080b7f321d..2f012780da77 100644
--- a/examples/language/palm/train.py
+++ b/examples/language/palm/train.py
@@ -1,27 +1,30 @@
import gzip
import random
+from functools import partial
+from time import time
import numpy as np
import torch
+import torch.nn as nn
import torch.optim as optim
import tqdm
from packaging import version
from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
-from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel import GeminiDDP, ZeroDDP
+from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
# constants
-NUM_BATCHES = int(1000)
+NUM_BATCHES = int(10)
+WARMUP_BATCHES = 1
GRADIENT_ACCUMULATE_EVERY = 1
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
@@ -63,9 +66,16 @@ def parse_args():
default=8,
help="batch size per DP group of training.",
)
+ parser.add_argument(
+ "--dummy_data",
+ type=bool,
+ default=False,
+ help="use dummy dataset.",
+ )
args = parser.parse_args()
return args
+
# helpers
def cycle(loader):
while True:
@@ -77,10 +87,22 @@ def decode_token(token):
return str(chr(max(32, token)))
+def get_tflops(model_numel, batch_size, seq_len, step_time):
+ return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
+
+
def decode_tokens(tokens):
return "".join(list(map(decode_token, tokens)))
+def get_model_size(model: nn.Module):
+ total_numel = 0
+ for module in model.modules():
+ for p in module.parameters(recurse=False):
+ total_numel += p.numel()
+ return total_numel
+
+
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
cai_version = colossalai.__version__
@@ -104,6 +126,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
+
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
@@ -117,6 +140,7 @@ def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
+
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
@@ -143,20 +167,33 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
-
param.visited = True
args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
- raise TypeError(f"{args.distplan} is error")
+ raise TypeError(f"{args.distplan} is error")
disable_existing_loggers()
colossalai.launch_from_torch(config={})
+logger = get_dist_logger()
+
+
+def generate_dataset(dummy_data: bool = False):
+ if not dummy_data:
+ with gzip.open("./data/enwik8.gz") as file:
+ X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
+ trX, vaX = np.split(X, [int(90e6)])
+ data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+ # print(f"data_train {data_train.shape} {data_train.dtype} {max(data_train)} {min(data_train)}")
+ # print(f"data_val {data_val.shape} {data_val.dtype} {max(data_val)} {min(data_val)}")
+ return data_train, data_val
+ else:
+ return torch.randint(0, 100, (90000000,)), torch.randint(0, 100, (5000000,))
+
+
+data_train, data_val = generate_dataset(args.dummy_data)
-with gzip.open("./data/enwik8.gz") as file:
- X = np.fromstring(file.read(int(95e6)), dtype=np.uint8)
- trX, vaX = np.split(X, [int(90e6)])
- data_train, data_val = torch.from_numpy(trX), torch.from_numpy(vaX)
+print("generate dataset ready!")
class TextSamplerDataset(Dataset):
@@ -188,7 +225,7 @@ def __len__(self):
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx:
- model = PaLM(num_tokens=256, dim=512, depth=8)
+ model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
@@ -205,25 +242,42 @@ def __len__(self):
model.cuda()
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
-
+# model is shared after TP
+numel = get_model_size(model)
+get_tflops_func = partial(get_tflops, numel, args.batch_size, SEQ_LEN)
# training
model.train()
-
+tflops_list = []
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
if args.distplan == "colossalai":
optimizer.zero_grad()
-
+ start = time()
loss = model(next(train_loader))
+ fwd_end = time()
+ fwd_time = fwd_end - start
# loss.backward()
optimizer.backward(loss)
+ bwd_end = time()
+ bwd_time = bwd_end - fwd_end
- print(f"training loss: {loss.item()}")
+ # print(f"training loss: {loss.item()}")
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
# optim.step()
# optim.zero_grad()
optimizer.step()
+ optim_time = time() - bwd_end
+ step_time = time() - start
+
+ step_tflops = get_tflops_func(step_time)
+ logger.info(
+ f"[{i + 1}/{NUM_BATCHES}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+ ranks=[0],
+ )
+ if i >= WARMUP_BATCHES:
+ tflops_list.append(step_tflops)
+
else:
for __ in range(GRADIENT_ACCUMULATE_EVERY):
loss = model(next(train_loader))
@@ -234,12 +288,16 @@ def __len__(self):
optim.step()
optim.zero_grad()
- # TODO
- # if i % VALIDATE_EVERY == 0:
- # model.eval()
- # with torch.no_grad():
- # loss = model(next(val_loader))
- # print(f"validation loss: {loss.item()}")
+tflops_list.sort()
+median_index = ((NUM_BATCHES - WARMUP_BATCHES) >> 1) + WARMUP_BATCHES
+logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
+
+# TODO
+# if i % VALIDATE_EVERY == 0:
+# model.eval()
+# with torch.no_grad():
+# loss = model(next(val_loader))
+# print(f"validation loss: {loss.item()}")
# if i % GENERATE_EVERY == 0:
# model.eval()
@@ -249,4 +307,4 @@ def __len__(self):
# sample = model.generate(inp[None, ...], GENERATE_LENGTH)
# output_str = decode_tokens(sample[0])
- # print(output_str)
\ No newline at end of file
+ # print(output_str)
diff --git a/examples/tutorial/README.md b/examples/tutorial/README.md
index bef7c8905033..f4843331fd54 100644
--- a/examples/tutorial/README.md
+++ b/examples/tutorial/README.md
@@ -1,8 +1,10 @@
# Colossal-AI Tutorial Hands-on
+> This path is an abbreviated tutorial prepared for specific activities and may not be maintained in real time. For use of Colossal-AI, please refer to other [examples](https://github.com/hpcaitech/ColossalAI/tree/main/examples) and [documents](https://www.colossalai.org/).
+
## Introduction
-Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), etc.
+Welcome to the [Colossal-AI](https://github.com/hpcaitech/ColossalAI) tutorial, which has been accepted as official tutorials by top conference [SC](https://sc22.supercomputing.org/), [AAAI](https://aaai.org/Conferences/AAAI-23/), [PPoPP](https://ppopp23.sigplan.org/), [CVPR](https://cvpr2023.thecvf.com/), [ISC](https://www.isc-hpc.com/), etc.
[Colossal-AI](https://github.com/hpcaitech/ColossalAI), a unified deep learning system for the big model era, integrates
@@ -15,33 +17,18 @@ quickly deploy large AI model training and inference, reducing large AI model tr
[**Colossal-AI**](https://github.com/hpcaitech/ColossalAI) |
[**Paper**](https://arxiv.org/abs/2110.14883) |
[**Documentation**](https://www.colossalai.org/) |
-[**Forum**](https://github.com/hpcaitech/ColossalAI/discussions) |
+[**Issue**](https://github.com/hpcaitech/ColossalAI/issues/new/choose) |
[**Slack**](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
## Table of Content
- - Multi-dimensional Parallelism
- - Know the components and sketch of Colossal-AI
- - Step-by-step from PyTorch to Colossal-AI
- - Try data/pipeline parallelism and 1D/2D/2.5D/3D tensor parallelism using a unified model
- - Sequence Parallelism
- - Try sequence parallelism with BERT
- - Combination of data/pipeline/sequence parallelism
- - Faster training and longer sequence length
- - Large Batch Training Optimization
- - Comparison of small/large batch size with SGD/LARS optimizer
- - Acceleration from a larger batch size
- - Auto-Parallelism
- - Parallelism with normal non-distributed training code
- - Model tracing + solution solving + runtime communication inserting all in one auto-parallelism system
- - Try single program, multiple data (SPMD) parallel with auto-parallelism SPMD solver on ResNet50
- - Fine-tuning and Serving for OPT
- - Try pre-trained OPT model weights with Colossal-AI
- - Fine-tuning OPT with limited hardware using ZeRO, Gemini and parallelism
- - Deploy the fine-tuned model to inference service
- - Acceleration of Stable Diffusion
- - Stable Diffusion with Lightning
- - Try Lightning Colossal-AI strategy to optimize memory and accelerate speed
+ - Multi-dimensional Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/hybrid_parallel) [[video]](https://www.youtube.com/watch?v=OwUQKdA2Icc)
+ - Sequence Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/sequence_parallel) [[video]](https://www.youtube.com/watch?v=HLLVKb7Cszs)
+ - Large Batch Training Optimization [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/large_batch_optimizer) [[video]](https://www.youtube.com/watch?v=9Un0ktxJZbI)
+ - Automatic Parallelism [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/auto_parallel) [[video]](https://www.youtube.com/watch?v=_-2jlyidxqE)
+ - Fine-tuning and Inference for OPT [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/opt) [[video]](https://www.youtube.com/watch?v=jbEFNVzl67Y)
+ - Optimized AlphaFold [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/tutorial/fastfold) [[video]](https://www.youtube.com/watch?v=-zP13LfJP7w)
+ - Optimized Stable Diffusion [[code]](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) [[video]](https://www.youtube.com/watch?v=8KHeUjjc-XQ)
## Discussion
@@ -52,17 +39,8 @@ If you think there is a need to discuss anything, you may jump to our [Slack](ht
If you encounter any problem while running these tutorials, you may want to raise an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose) in this repository.
## 🛠️ Setup environment
-You should use `conda` to create a virtual environment, we recommend **python 3.8**, e.g. `conda create -n colossal python=3.8`. This installation commands are for CUDA 11.3, if you have a different version of CUDA, please download PyTorch and Colossal-AI accordingly.
-
-```
-# install torch
-# visit https://pytorch.org/get-started/locally/ to download other versions
-pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
-
-# install latest ColossalAI
-# visit https://colossalai.org/download to download corresponding version of Colossal-AI
-pip install colossalai==0.1.11rc3+torch1.12cu11.3 -f https://release.colossalai.org
-```
+[[video]](https://www.youtube.com/watch?v=dpMYj974ZIc) You should use `conda` to create a virtual environment, we recommend **python 3.8**, e.g. `conda create -n colossal python=3.8`. This installation commands are for CUDA 11.3, if you have a different version of CUDA, please download PyTorch and Colossal-AI accordingly.
+You can refer to the [Installation](https://github.com/hpcaitech/ColossalAI#installation) to set up your environment.
You can run `colossalai check -i` to verify if you have correctly set up your environment 🕹️.

@@ -74,120 +52,3 @@ Then clone the Colossal-AI repository from GitHub.
git clone https://github.com/hpcaitech/ColossalAI.git
cd ColossalAI/examples/tutorial
```
-
-## 🔥 Multi-dimensional Hybrid Parallel with Vision Transformer
-1. Go to **hybrid_parallel** folder in the **tutorial** directory.
-2. Install our model zoo.
-```bash
-pip install titans
-```
-3. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag.
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
-
-4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
-
-## ☀️ Sequence Parallel with BERT
-1. Go to the **sequence_parallel** folder in the **tutorial** directory.
-2. Run with the following command
-```bash
-export PYTHONPATH=$PWD
-colossalai run --nproc_per_node 4 train.py -s
-```
-3. The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again.
-
-## 📕 Large batch optimization with LARS and LAMB
-1. Go to the **large_batch_optimizer** folder in the **tutorial** directory.
-2. Run with synthetic data
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
-
-## 😀 Auto-Parallel Tutorial
-1. Go to the **auto_parallel** folder in the **tutorial** directory.
-2. Install `pulp` and `coin-or-cbc` for the solver.
-```bash
-pip install pulp
-conda install -c conda-forge coin-or-cbc
-```
-2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset.
-```bash
-colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
-```
-
-You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training.
-
-
-## 🎆 Auto-Checkpoint Tutorial
-1. Stay in the `auto_parallel` folder.
-2. Install the dependencies.
-```bash
-pip install matplotlib transformers
-```
-3. Run a simple resnet50 benchmark to automatically checkpoint the model.
-```bash
-python auto_ckpt_solver_test.py --model resnet50
-```
-
-You should expect the log to be like this
-
-
-This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger.
-```bash
-python auto_ckpt_solver_test.py --model gpt2
-```
-
-4. Run a simple benchmark to find the optimal batch size for checkpointed model.
-```bash
-python auto_ckpt_batchsize_test.py
-```
-
-You can expect the log to be like
-
-
-## 🚀 Run OPT finetuning and inference
-1. Install the dependency
-```bash
-pip install datasets accelerate
-```
-2. Run finetuning with synthetic datasets with one GPU
-```bash
-bash ./run_clm_synthetic.sh
-```
-3. Run finetuning with 4 GPUs
-```bash
-bash ./run_clm_synthetic.sh 16 0 125m 4
-```
-4. Run inference with OPT 125M
-```bash
-docker hpcaitech/tutorial:opt-inference
-docker run -it --rm --gpus all --ipc host -p 7070:7070 hpcaitech/tutorial:opt-inference
-```
-5. Start the http server inside the docker container with tensor parallel size 2
-```bash
-python opt_fastapi.py opt-125m --tp 2 --checkpoint /data/opt-125m
-```
-
-## 🖼️ Accelerate Stable Diffusion with Colossal-AI
-1. Create a new environment for diffusion
-```bash
-conda env create -f environment.yaml
-conda activate ldm
-```
-2. Install Colossal-AI from our official page
-```bash
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-3. Install PyTorch Lightning compatible commit
-```bash
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-cd ..
-```
-
-4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`.
-5. Run training with CIFAR10.
-```bash
-python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml
-```
diff --git a/examples/tutorial/auto_parallel/README.md b/examples/tutorial/auto_parallel/README.md
index e99a018c2da1..bb014b9067b2 100644
--- a/examples/tutorial/auto_parallel/README.md
+++ b/examples/tutorial/auto_parallel/README.md
@@ -1,73 +1,52 @@
-# Auto-Parallelism with ResNet
+# Auto-Parallelism
-## 🚀Quick Start
-### Auto-Parallel Tutorial
-1. Install `pulp` and `coin-or-cbc` for the solver.
-```bash
-pip install pulp
-conda install -c conda-forge coin-or-cbc
-```
-2. Run the auto parallel resnet example with 4 GPUs with synthetic dataset.
-```bash
-colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py -s
-```
+## Table of contents
-You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training.
-
+- [Auto-Parallelism](#auto-parallelism)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+ - [Setup](#setup)
+ - [Auto-Parallel Tutorial](#auto-parallel-tutorial)
+ - [Auto-Checkpoint Tutorial](#auto-checkpoint-tutorial)
-### Auto-Checkpoint Tutorial
-1. Stay in the `auto_parallel` folder.
-2. Install the dependencies.
-```bash
-pip install matplotlib transformers
-```
-3. Run a simple resnet50 benchmark to automatically checkpoint the model.
-```bash
-python auto_ckpt_solver_test.py --model resnet50
-```
+## 📚 Overview
-You should expect the log to be like this
-
+This tutorial folder contains a simple demo to run auto-parallelism with ResNet. Meanwhile, this diretory also contains demo scripts to run automatic activation checkpointing, but both features are still experimental for now and no guarantee that they will work for your version of Colossal-AI.
-This shows that given different memory budgets, the model is automatically injected with activation checkpoint and its time taken per iteration. You can run this benchmark for GPT as well but it can much longer since the model is larger.
-```bash
-python auto_ckpt_solver_test.py --model gpt2
-```
+## 🚀 Quick Start
-4. Run a simple benchmark to find the optimal batch size for checkpointed model.
-```bash
-python auto_ckpt_batchsize_test.py
-```
+### Setup
-You can expect the log to be like
-
-
-
-## Prepare Dataset
-
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
+1. Create a conda environment
```bash
-export DATA=/path/to/data
+conda create -n auto python=3.8
+conda activate auto
```
-## extra requirements to use autoparallel
+2. Install `requirements` and `coin-or-cbc` for the solver.
```bash
-pip install pulp
-conda install coin-or-cbc
+pip install -r requirements.txt
+conda install -c conda-forge coin-or-cbc
```
-## Run on 2*2 device mesh
+
+### Auto-Parallel Tutorial
+
+Run the auto parallel resnet example with 4 GPUs with synthetic dataset.
```bash
colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
```
-## Auto Checkpoint Benchmarking
+You should expect to the log like this. This log shows the edge cost on the computation graph as well as the sharding strategy for an operation. For example, `layer1_0_conv1 S01R = S01R X RR` means that the first dimension (batch) of the input and output is sharded while the weight is not sharded (S means sharded, R means replicated), simply equivalent to data parallel training.
+
+
+
+### Auto-Checkpoint Tutorial
We prepare two bechmarks for you to test the performance of auto checkpoint
@@ -86,21 +65,3 @@ python auto_ckpt_solver_test.py --model resnet50
# tun auto_ckpt_batchsize_test.py
python auto_ckpt_batchsize_test.py
```
-
-There are some results for your reference
-
-## Auto Checkpoint Solver Test
-
-### ResNet 50
-
-
-### GPT2 Medium
-
-
-## Auto Checkpoint Batch Size Test
-```bash
-===============test summary================
-batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s
-batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s
-batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s
-```
diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
index e4aff13e484a..a6a9ad0a312c 100644
--- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -1,37 +1,13 @@
-import argparse
-import os
-from pathlib import Path
-
import torch
-from titans.utils import barrier_context
-from torch.fx import GraphModule
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm
import colossalai
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
-from colossalai.utils import get_dataloader
-
-DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-
-
-def parse_args():
- parser = argparse.ArgumentParser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
- return parser.parse_args()
def synthesize_data():
@@ -41,82 +17,20 @@ def synthesize_data():
def main():
- args = parse_args()
colossalai.launch_from_torch(config='./config.py')
logger = get_dist_logger()
- if not args.synthetic:
- with barrier_context():
- # build dataloaders
- train_dataset = CIFAR10(root=DATA_ROOT,
- download=True,
- transform=transforms.Compose([
- transforms.RandomCrop(size=32, padding=4),
- transforms.RandomHorizontalFlip(),
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
- std=[0.2023, 0.1994, 0.2010]),
- ]))
-
- test_dataset = CIFAR10(root=DATA_ROOT,
- train=False,
- transform=transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
- ]))
-
- train_dataloader = get_dataloader(
- dataset=train_dataset,
- add_sampler=True,
- shuffle=True,
- batch_size=gpc.config.BATCH_SIZE,
- pin_memory=True,
- )
-
- test_dataloader = get_dataloader(
- dataset=test_dataset,
- add_sampler=True,
- batch_size=gpc.config.BATCH_SIZE,
- pin_memory=True,
- )
- else:
- train_dataloader, test_dataloader = None, None
-
- # initialize device mesh
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
# trace the model with meta data
- tracer = ColoTracer()
model = resnet50(num_classes=10).cuda()
- input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
-
- # prepare info for solver
- solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
-
- # solve the solution
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- if gpc.get_global_rank() == 0:
- for index, node in enumerate(graph.nodes):
- print(node.name, node.strategies_vector[solution[index]].name)
- # process the graph for distributed training ability
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
- gm = runtime_apply_pass(gm)
- gm.recompile()
+ input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
+ device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
+ model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)
+ if gpc.get_global_rank() == 0:
+ for node_strategy in solution:
+ print(node_strategy)
# build criterion
criterion = torch.nn.CrossEntropyLoss()
@@ -127,65 +41,46 @@ def main():
lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
for epoch in range(gpc.config.NUM_EPOCHS):
- gm.train()
+ model.train()
- if args.synthetic:
- # if we use synthetic data
- # we assume it only has 30 steps per epoch
- num_steps = range(30)
-
- else:
- # we use the actual number of steps for training
- num_steps = range(len(train_dataloader))
- data_iter = iter(train_dataloader)
+ # if we use synthetic data
+ # we assume it only has 10 steps per epoch
+ num_steps = range(10)
progress = tqdm(num_steps)
for _ in progress:
- if args.synthetic:
- # generate fake data
- img, label = synthesize_data()
- else:
- # get the real data
- img, label = next(data_iter)
+ # generate fake data
+ img, label = synthesize_data()
img = img.cuda()
label = label.cuda()
optimizer.zero_grad()
- output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = model(img)
train_loss = criterion(output, label)
train_loss.backward(train_loss)
+ torch.cuda.synchronize()
optimizer.step()
lr_scheduler.step()
# run evaluation
- gm.eval()
+ model.eval()
correct = 0
total = 0
- if args.synthetic:
- # if we use synthetic data
- # we assume it only has 10 steps for evaluation
- num_steps = range(30)
-
- else:
- # we use the actual number of steps for training
- num_steps = range(len(test_dataloader))
- data_iter = iter(test_dataloader)
+ # if we use synthetic data
+ # we assume it only has 10 steps for evaluation
+ num_steps = range(10)
progress = tqdm(num_steps)
for _ in progress:
- if args.synthetic:
- # generate fake data
- img, label = synthesize_data()
- else:
- # get the real data
- img, label = next(data_iter)
+ # generate fake data
+ img, label = synthesize_data()
img = img.cuda()
label = label.cuda()
with torch.no_grad():
- output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = model(img)
test_loss = criterion(output, label)
pred = torch.argmax(output, dim=-1)
correct += torch.sum(pred == label)
diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py
index fa14eda740f7..52e0abcef698 100644
--- a/examples/tutorial/auto_parallel/config.py
+++ b/examples/tutorial/auto_parallel/config.py
@@ -1,2 +1,2 @@
-BATCH_SIZE = 128
-NUM_EPOCHS = 10
+BATCH_SIZE = 32
+NUM_EPOCHS = 2
diff --git a/examples/tutorial/auto_parallel/requirements.txt b/examples/tutorial/auto_parallel/requirements.txt
index 137a69e80498..ce89e7c80070 100644
--- a/examples/tutorial/auto_parallel/requirements.txt
+++ b/examples/tutorial/auto_parallel/requirements.txt
@@ -1,2 +1,7 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+torch
+colossalai
+titans
+pulp
+datasets
+matplotlib
+transformers
diff --git a/examples/tutorial/stable_diffusion/setup.py b/examples/tutorial/auto_parallel/setup.py
similarity index 68%
rename from examples/tutorial/stable_diffusion/setup.py
rename to examples/tutorial/auto_parallel/setup.py
index a24d54167640..6e6cff32ed23 100644
--- a/examples/tutorial/stable_diffusion/setup.py
+++ b/examples/tutorial/auto_parallel/setup.py
@@ -1,7 +1,7 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
setup(
- name='latent-diffusion',
+ name='auto_parallel',
version='0.0.1',
description='',
packages=find_packages(),
@@ -10,4 +10,4 @@
'numpy',
'tqdm',
],
-)
\ No newline at end of file
+)
diff --git a/examples/tutorial/auto_parallel/test_ci.sh b/examples/tutorial/auto_parallel/test_ci.sh
new file mode 100644
index 000000000000..bf6275b673ff
--- /dev/null
+++ b/examples/tutorial/auto_parallel/test_ci.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+conda install -c conda-forge coin-or-cbc
+colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
diff --git a/examples/tutorial/fastfold/FastFold b/examples/tutorial/fastfold/FastFold
new file mode 160000
index 000000000000..867587b3aa4e
--- /dev/null
+++ b/examples/tutorial/fastfold/FastFold
@@ -0,0 +1 @@
+Subproject commit 867587b3aa4e43bdaf64f9910127842f1dfbfebd
diff --git a/examples/tutorial/fastfold/README.md b/examples/tutorial/fastfold/README.md
new file mode 100644
index 000000000000..434d033b9792
--- /dev/null
+++ b/examples/tutorial/fastfold/README.md
@@ -0,0 +1,49 @@
+# FastFold Inference
+
+## Table of contents
+
+- [FastFold Inference](#fastfold-inference)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+ - [🔍 Dive into FastFold](#-dive-into-fastfold)
+
+## 📚 Overview
+
+This example lets you to try out the inference of [FastFold](https://github.com/hpcaitech/FastFold).
+
+## 🚀 Quick Start
+
+1. Install FastFold
+
+We highly recommend you to install FastFold with conda.
+```
+git clone https://github.com/hpcaitech/FastFold
+cd FastFold
+conda env create --name=fastfold -f environment.yml
+conda activate fastfold
+python setup.py install
+```
+
+2. Download datasets.
+
+It may take ~900GB space to keep datasets.
+```
+./scripts/download_all_data.sh data/
+```
+
+3. Run the inference scripts.
+
+```
+bash inference.sh
+```
+You can find predictions under the `outputs` dir.
+
+## 🔍 Dive into FastFold
+
+There are another features of [FastFold](https://github.com/hpcaitech/FastFold), such as:
++ more excellent kernel based on triton
++ much faster data processing based on ray
++ training supported
+
+More detailed information can be seen [here](https://github.com/hpcaitech/FastFold/).
diff --git a/examples/tutorial/fp8/mnist/README.md b/examples/tutorial/fp8/mnist/README.md
new file mode 100644
index 000000000000..46711f9ebdd8
--- /dev/null
+++ b/examples/tutorial/fp8/mnist/README.md
@@ -0,0 +1,13 @@
+# Basic MNIST Example with optional FP8 of TransformerEngine
+
+[TransformerEngine](https://github.com/NVIDIA/TransformerEngine) is a library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower memory utilization in both training and inference.
+
+Thanks for the contribution to this tutorial from NVIDIA.
+
+```bash
+python main.py
+python main.py --use-te # Linear layers from TransformerEngine
+python main.py --use-fp8 # FP8 + TransformerEngine for Linear layers
+```
+
+> We are working to integrate it with Colossal-AI and will finish it soon.
diff --git a/examples/tutorial/fp8/mnist/main.py b/examples/tutorial/fp8/mnist/main.py
new file mode 100644
index 000000000000..000ded2f111f
--- /dev/null
+++ b/examples/tutorial/fp8/mnist/main.py
@@ -0,0 +1,237 @@
+# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+from torch.optim.lr_scheduler import StepLR
+
+try:
+ from transformer_engine import pytorch as te
+ HAVE_TE = True
+except (ImportError, ModuleNotFoundError):
+ HAVE_TE = False
+
+
+class Net(nn.Module):
+ def __init__(self, use_te=False):
+ super(Net, self).__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3, 1)
+ self.conv2 = nn.Conv2d(32, 64, 3, 1)
+ self.dropout1 = nn.Dropout(0.25)
+ self.dropout2 = nn.Dropout(0.5)
+ if use_te:
+ self.fc1 = te.Linear(9216, 128)
+ self.fc2 = te.Linear(128, 16)
+ else:
+ self.fc1 = nn.Linear(9216, 128)
+ self.fc2 = nn.Linear(128, 16)
+ self.fc3 = nn.Linear(16, 10)
+
+ def forward(self, x):
+ """FWD"""
+ x = self.conv1(x)
+ x = F.relu(x)
+ x = self.conv2(x)
+ x = F.relu(x)
+ x = F.max_pool2d(x, 2)
+ x = self.dropout1(x)
+ x = torch.flatten(x, 1)
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = self.dropout2(x)
+ x = self.fc2(x)
+ x = self.fc3(x)
+ output = F.log_softmax(x, dim=1)
+ return output
+
+
+def train(args, model, device, train_loader, optimizer, epoch, use_fp8):
+ """Training function."""
+ model.train()
+ for batch_idx, (data, target) in enumerate(train_loader):
+ data, target = data.to(device), target.to(device)
+ optimizer.zero_grad()
+ with te.fp8_autocast(enabled=use_fp8):
+ output = model(data)
+ loss = F.nll_loss(output, target)
+ loss.backward()
+ optimizer.step()
+ if batch_idx % args.log_interval == 0:
+ print(
+ f"Train Epoch: {epoch} "
+ f"[{batch_idx * len(data)}/{len(train_loader.dataset)} "
+ f"({100. * batch_idx / len(train_loader):.0f}%)]\t"
+ f"Loss: {loss.item():.6f}"
+ )
+ if args.dry_run:
+ break
+
+
+def calibrate(model, device, test_loader):
+ """Calibration function."""
+ model.eval()
+ test_loss = 0
+ correct = 0
+ with torch.no_grad():
+ for data, target in test_loader:
+ data, target = data.to(device), target.to(device)
+ with te.fp8_autocast(enabled=False, calibrating=True):
+ output = model(data)
+
+def test(model, device, test_loader, use_fp8):
+ """Testing function."""
+ model.eval()
+ test_loss = 0
+ correct = 0
+ with torch.no_grad():
+ for data, target in test_loader:
+ data, target = data.to(device), target.to(device)
+ with te.fp8_autocast(enabled=use_fp8):
+ output = model(data)
+ test_loss += F.nll_loss(
+ output, target, reduction="sum"
+ ).item() # sum up batch loss
+ pred = output.argmax(
+ dim=1, keepdim=True
+ ) # get the index of the max log-probability
+ correct += pred.eq(target.view_as(pred)).sum().item()
+
+ test_loss /= len(test_loader.dataset)
+
+ print(
+ f"\nTest set: Average loss: {test_loss:.4f}, "
+ f"Accuracy: {correct}/{len(test_loader.dataset)} "
+ f"({100. * correct / len(test_loader.dataset):.0f}%)\n"
+ )
+
+
+def main():
+ # Training settings
+ parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=64,
+ metavar="N",
+ help="input batch size for training (default: 64)",
+ )
+ parser.add_argument(
+ "--test-batch-size",
+ type=int,
+ default=1000,
+ metavar="N",
+ help="input batch size for testing (default: 1000)",
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=14,
+ metavar="N",
+ help="number of epochs to train (default: 14)",
+ )
+ parser.add_argument(
+ "--lr",
+ type=float,
+ default=1.0,
+ metavar="LR",
+ help="learning rate (default: 1.0)",
+ )
+ parser.add_argument(
+ "--gamma",
+ type=float,
+ default=0.7,
+ metavar="M",
+ help="Learning rate step gamma (default: 0.7)",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ default=False,
+ help="quickly check a single pass",
+ )
+ parser.add_argument(
+ "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ metavar="N",
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--save-model",
+ action="store_true",
+ default=False,
+ help="For Saving the current Model",
+ )
+ parser.add_argument(
+ "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
+ )
+ parser.add_argument(
+ "--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
+ )
+ parser.add_argument(
+ "--use-te", action="store_true", default=False, help="Use Transformer Engine"
+ )
+ args = parser.parse_args()
+ use_cuda = torch.cuda.is_available()
+
+ if args.use_te or args.use_fp8 or args.use_fp8_infer:
+ assert HAVE_TE, "TransformerEngine not installed."
+
+ if args.use_fp8 or args.use_fp8_infer:
+ args.use_te = True
+
+ if args.use_te:
+ assert use_cuda, "CUDA needed for FP8 execution."
+
+ if args.use_fp8_infer:
+ assert not args.use_fp8, "fp8-infer path currently only supports calibration from a bfloat checkpoint"
+
+ torch.manual_seed(args.seed)
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+
+ train_kwargs = {"batch_size": args.batch_size}
+ test_kwargs = {"batch_size": args.test_batch_size}
+ if use_cuda:
+ cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
+ train_kwargs.update(cuda_kwargs)
+ test_kwargs.update(cuda_kwargs)
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+ )
+ dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
+ dataset2 = datasets.MNIST("../data", train=False, transform=transform)
+ train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+ test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
+
+ model = Net(use_te=args.use_te).to(device)
+ optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+ scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+ for epoch in range(1, args.epochs + 1):
+ train(args, model, device, train_loader, optimizer, epoch, args.use_fp8)
+ test(model, device, test_loader, args.use_fp8)
+ scheduler.step()
+
+ if args.use_fp8_infer:
+ calibrate(model, device, test_loader)
+
+ if args.save_model or args.use_fp8_infer:
+ torch.save(model.state_dict(), "mnist_cnn.pt")
+ print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
+ weights = torch.load("mnist_cnn.pt")
+ model.load_state_dict(weights)
+ test(model, device, test_loader, args.use_fp8_infer)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/tutorial/hybrid_parallel/README.md b/examples/tutorial/hybrid_parallel/README.md
index 6f975e86330a..1b5e54f928d4 100644
--- a/examples/tutorial/hybrid_parallel/README.md
+++ b/examples/tutorial/hybrid_parallel/README.md
@@ -1,45 +1,40 @@
# Multi-dimensional Parallelism with Colossal-AI
+## Table of contents
-## 🚀Quick Start
-1. Install our model zoo.
-```bash
-pip install titans
-```
-2. Run with synthetic data which is of similar shape to CIFAR10 with the `-s` flag.
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
+- [Overview](#-overview)
+- [Quick Start](#-quick-start)
-3. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
+## 📚 Overview
+This example lets you to quickly try out the hybrid parallelism provided by Colossal-AI.
+You can change the parameters below to try out different settings in the `config.py`.
-## Install Titans Model Zoo
+```python
+# parallel setting
+TENSOR_PARALLEL_SIZE = 2
+TENSOR_PARALLEL_MODE = '1d'
-```bash
-pip install titans
+parallel = dict(
+ pipeline=2,
+ tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
+)
```
+## 🚀 Quick Start
-## Prepare Dataset
+1. Install PyTorch
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
+2. Install the dependencies.
```bash
-export DATA=/path/to/data
+pip install -r requirements.txt
```
-
-## Run on 2*2 device mesh
-
-Current configuration setting on `config.py` is TP=2, PP=2.
+3. Run the training scripts with synthetic data.
```bash
-# train with cifar10
colossalai run --nproc_per_node 4 train.py --config config.py
-
-# train with synthetic data
-colossalai run --nproc_per_node 4 train.py --config config.py -s
```
+
+4. Modify the config file to play with different types of tensor parallelism, for example, change tensor parallel size to be 4 and mode to be 2d and run on 8 GPUs.
diff --git a/examples/tutorial/hybrid_parallel/config.py b/examples/tutorial/hybrid_parallel/config.py
index 2450ab1c7a72..fe9abf2f1955 100644
--- a/examples/tutorial/hybrid_parallel/config.py
+++ b/examples/tutorial/hybrid_parallel/config.py
@@ -3,20 +3,20 @@
# hyperparameters
# BATCH_SIZE is as per GPU
# global batch size = BATCH_SIZE x data parallel size
-BATCH_SIZE = 256
+BATCH_SIZE = 4
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 10
-WARMUP_EPOCHS = 3
+NUM_EPOCHS = 2
+WARMUP_EPOCHS = 1
# model config
IMG_SIZE = 224
PATCH_SIZE = 16
-HIDDEN_SIZE = 512
+HIDDEN_SIZE = 128
DEPTH = 4
NUM_HEADS = 4
MLP_RATIO = 2
-NUM_CLASSES = 1000
+NUM_CLASSES = 10
CHECKPOINT = False
SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
diff --git a/examples/tutorial/hybrid_parallel/requirements.txt b/examples/tutorial/hybrid_parallel/requirements.txt
index 137a69e80498..99b7ecfe162e 100644
--- a/examples/tutorial/hybrid_parallel/requirements.txt
+++ b/examples/tutorial/hybrid_parallel/requirements.txt
@@ -1,2 +1,3 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+torch
+colossalai
+titans
diff --git a/examples/tutorial/hybrid_parallel/test_ci.sh b/examples/tutorial/hybrid_parallel/test_ci.sh
new file mode 100644
index 000000000000..e0dbef354e2d
--- /dev/null
+++ b/examples/tutorial/hybrid_parallel/test_ci.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+colossalai run --nproc_per_node 4 train.py --config config.py
diff --git a/examples/tutorial/hybrid_parallel/train.py b/examples/tutorial/hybrid_parallel/train.py
index 0f2a207cb172..4953d5350f31 100644
--- a/examples/tutorial/hybrid_parallel/train.py
+++ b/examples/tutorial/hybrid_parallel/train.py
@@ -1,7 +1,6 @@
import os
import torch
-from titans.dataloader.cifar10 import build_cifar
from titans.model.vit.vit import _create_vit_model
from tqdm import tqdm
@@ -12,7 +11,7 @@
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
+from colossalai.utils import is_using_pp
class DummyDataloader():
@@ -42,12 +41,9 @@ def __len__(self):
def main():
- # initialize distributed setting
+ # launch from torch
parser = colossalai.get_default_parser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
args = parser.parse_args()
-
- # launch from torch
colossalai.launch_from_torch(config=args.config)
# get logger
@@ -94,15 +90,10 @@ def main():
pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
- # create dataloaders
- root = os.environ.get('DATA', '../data')
- if args.synthetic:
- # if we use synthetic dataset
- # we train for 30 steps and eval for 10 steps per epoch
- train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
- test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
- else:
- train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+ # use synthetic dataset
+ # we train for 10 steps and eval for 5 steps per epoch
+ train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+ test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
@@ -139,6 +130,7 @@ def main():
engine.execute_schedule(data_iter, return_output_label=False)
engine.step()
lr_scheduler.step()
+ gpc.destroy()
if __name__ == '__main__':
diff --git a/examples/tutorial/large_batch_optimizer/README.md b/examples/tutorial/large_batch_optimizer/README.md
index 20bddb383434..1a17c2d8740f 100644
--- a/examples/tutorial/large_batch_optimizer/README.md
+++ b/examples/tutorial/large_batch_optimizer/README.md
@@ -1,31 +1,37 @@
-# Comparison of Large Batch Training Optimization
+# Large Batch Training Optimization
-## 🚀Quick Start
-Run with synthetic data
-```bash
-colossalai run --nproc_per_node 4 train.py --config config.py -s
-```
+## Table of contents
+- [Large Batch Training Optimization](#large-batch-training-optimization)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
-## Prepare Dataset
+## 📚 Overview
-We use CIFAR10 dataset in this example. You should invoke the `donwload_cifar10.py` in the tutorial root directory or directly run the `auto_parallel_with_resnet.py`.
-The dataset will be downloaded to `colossalai/examples/tutorials/data` by default.
-If you wish to use customized directory for the dataset. You can set the environment variable `DATA` via the following command.
+This example lets you to quickly try out the large batch training optimization provided by Colossal-AI. We use synthetic dataset to go through the process, thus, you don't need to prepare any dataset. You can try out the `Lamb` and `Lars` optimizers from Colossal-AI with the following code.
-```bash
-export DATA=/path/to/data
+```python
+from colossalai.nn.optimizer import Lamb, Lars
```
-You can also use synthetic data for this tutorial if you don't wish to download the `CIFAR10` dataset by adding the `-s` or `--synthetic` flag to the command.
+## 🚀 Quick Start
+
+1. Install PyTorch
+2. Install the dependencies.
+
+```bash
+pip install -r requirements.txt
+```
-## Run on 2*2 device mesh
+3. Run the training scripts with synthetic data.
```bash
-# run with cifar10
-colossalai run --nproc_per_node 4 train.py --config config.py
+# run on 4 GPUs
+# run with lars
+colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lars
-# run with synthetic dataset
-colossalai run --nproc_per_node 4 train.py --config config.py -s
+# run with lamb
+colossalai run --nproc_per_node 4 train.py --config config.py --optimizer lamb
```
diff --git a/examples/tutorial/large_batch_optimizer/config.py b/examples/tutorial/large_batch_optimizer/config.py
index e019154e4b12..2efa0ffd0556 100644
--- a/examples/tutorial/large_batch_optimizer/config.py
+++ b/examples/tutorial/large_batch_optimizer/config.py
@@ -6,31 +6,11 @@
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
-NUM_EPOCHS = 10
-WARMUP_EPOCHS = 3
+NUM_EPOCHS = 2
+WARMUP_EPOCHS = 1
# model config
-IMG_SIZE = 224
-PATCH_SIZE = 16
-HIDDEN_SIZE = 512
-DEPTH = 4
-NUM_HEADS = 4
-MLP_RATIO = 2
-NUM_CLASSES = 1000
-CHECKPOINT = False
-SEQ_LENGTH = (IMG_SIZE // PATCH_SIZE)**2 + 1 # add 1 for cls token
-
-# parallel setting
-TENSOR_PARALLEL_SIZE = 2
-TENSOR_PARALLEL_MODE = '1d'
-
-parallel = dict(
- pipeline=2,
- tensor=dict(mode=TENSOR_PARALLEL_MODE, size=TENSOR_PARALLEL_SIZE),
-)
+NUM_CLASSES = 10
fp16 = dict(mode=AMP_TYPE.NAIVE)
clip_grad_norm = 1.0
-
-# pipeline config
-NUM_MICRO_BATCHES = parallel['pipeline']
diff --git a/examples/tutorial/large_batch_optimizer/requirements.txt b/examples/tutorial/large_batch_optimizer/requirements.txt
index 137a69e80498..c013287751bf 100644
--- a/examples/tutorial/large_batch_optimizer/requirements.txt
+++ b/examples/tutorial/large_batch_optimizer/requirements.txt
@@ -1,2 +1,3 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+colossalai
+torch
+titans
diff --git a/examples/tutorial/large_batch_optimizer/test_ci.sh b/examples/tutorial/large_batch_optimizer/test_ci.sh
new file mode 100644
index 000000000000..89f426c542b1
--- /dev/null
+++ b/examples/tutorial/large_batch_optimizer/test_ci.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+
+# run test
+colossalai run --nproc_per_node 4 --master_port 29500 train.py --config config.py --optimizer lars
+colossalai run --nproc_per_node 4 --master_port 29501 train.py --config config.py --optimizer lamb
diff --git a/examples/tutorial/large_batch_optimizer/train.py b/examples/tutorial/large_batch_optimizer/train.py
index d403c275d1af..35e54582f494 100644
--- a/examples/tutorial/large_batch_optimizer/train.py
+++ b/examples/tutorial/large_batch_optimizer/train.py
@@ -1,19 +1,13 @@
-import os
-
import torch
-from titans.dataloader.cifar10 import build_cifar
-from titans.model.vit.vit import _create_vit_model
+import torch.nn as nn
+from torchvision.models import resnet18
from tqdm import tqdm
import colossalai
-from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
-from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import Lamb, Lars
-from colossalai.pipeline.pipelinable import PipelinableContext
-from colossalai.utils import get_dataloader, is_using_pp
class DummyDataloader():
@@ -45,7 +39,10 @@ def __len__(self):
def main():
# initialize distributed setting
parser = colossalai.get_default_parser()
- parser.add_argument('-s', '--synthetic', action="store_true", help="whether use synthetic data")
+ parser.add_argument('--optimizer',
+ choices=['lars', 'lamb'],
+ help="Choose your large-batch optimizer",
+ required=True)
args = parser.parse_args()
# launch from torch
@@ -55,59 +52,22 @@ def main():
logger = get_dist_logger()
logger.info("initialized distributed environment", ranks=[0])
- if hasattr(gpc.config, 'LOG_PATH'):
- if gpc.get_global_rank() == 0:
- log_path = gpc.config.LOG_PATH
- if not os.path.exists(log_path):
- os.mkdir(log_path)
- logger.log_to_file(log_path)
-
- use_pipeline = is_using_pp()
-
- # create model
- model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
- patch_size=gpc.config.PATCH_SIZE,
- hidden_size=gpc.config.HIDDEN_SIZE,
- depth=gpc.config.DEPTH,
- num_heads=gpc.config.NUM_HEADS,
- mlp_ratio=gpc.config.MLP_RATIO,
- num_classes=10,
- init_method='jax',
- checkpoint=gpc.config.CHECKPOINT)
-
- if use_pipeline:
- pipelinable = PipelinableContext()
- with pipelinable:
- model = _create_vit_model(**model_kwargs)
- pipelinable.to_layer_list()
- pipelinable.policy = "uniform"
- model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
- else:
- model = _create_vit_model(**model_kwargs)
-
- # count number of parameters
- total_numel = 0
- for p in model.parameters():
- total_numel += p.numel()
- if not gpc.is_initialized(ParallelMode.PIPELINE):
- pipeline_stage = 0
- else:
- pipeline_stage = gpc.get_local_rank(ParallelMode.PIPELINE)
- logger.info(f"number of parameters: {total_numel} on pipeline stage {pipeline_stage}")
-
- # create dataloaders
- root = os.environ.get('DATA', '../data/')
- if args.synthetic:
- train_dataloader = DummyDataloader(length=30, batch_size=gpc.config.BATCH_SIZE)
- test_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
- else:
- train_dataloader, test_dataloader = build_cifar(gpc.config.BATCH_SIZE, root, pad_if_needed=True)
+ # create synthetic dataloaders
+ train_dataloader = DummyDataloader(length=10, batch_size=gpc.config.BATCH_SIZE)
+ test_dataloader = DummyDataloader(length=5, batch_size=gpc.config.BATCH_SIZE)
+
+ # build model
+ model = resnet18(num_classes=gpc.config.NUM_CLASSES)
# create loss function
- criterion = CrossEntropyLoss(label_smoothing=0.1)
+ criterion = nn.CrossEntropyLoss()
# create optimizer
- optimizer = Lars(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
+ if args.optimizer == "lars":
+ optim_cls = Lars
+ elif args.optimizer == "lamb":
+ optim_cls = Lamb
+ optimizer = optim_cls(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
diff --git a/examples/tutorial/opt/inference/README.md b/examples/tutorial/opt/inference/README.md
index 5bacac0d74ad..20ad4a23fdeb 100644
--- a/examples/tutorial/opt/inference/README.md
+++ b/examples/tutorial/opt/inference/README.md
@@ -50,7 +50,7 @@ python opt_fastapi.py --queue_size
```
The `` can be an integer in `[0, MAXINT]`. If it's `0`, the request queue size is infinite. If it's a positive integer, when the request queue is full, incoming requests will be dropped (the HTTP status code of response will be 406).
-### Configure bathcing
+### Configure batching
```shell
python opt_fastapi.py --max_batch_size
```
@@ -85,4 +85,4 @@ Then open the web interface link which is on your console.
See [script/processing_ckpt_66b.py](./script/processing_ckpt_66b.py).
## OPT-175B
-See [script/process-opt-175b](./script/process-opt-175b/).
\ No newline at end of file
+See [script/process-opt-175b](./script/process-opt-175b/).
diff --git a/examples/tutorial/opt/inference/requirements.txt b/examples/tutorial/opt/inference/requirements.txt
index e6e8511e3178..966dff4746f2 100644
--- a/examples/tutorial/opt/inference/requirements.txt
+++ b/examples/tutorial/opt/inference/requirements.txt
@@ -7,3 +7,4 @@ torch>=1.10.0
transformers==4.23.1
uvicorn==0.19.0
colossalai
+git+https://github.com/hpcaitech/EnergonAI@main
diff --git a/examples/tutorial/sequence_parallel/README.md b/examples/tutorial/sequence_parallel/README.md
index 7058f53db8b6..1b7c60e22861 100644
--- a/examples/tutorial/sequence_parallel/README.md
+++ b/examples/tutorial/sequence_parallel/README.md
@@ -1,139 +1,56 @@
-# Sequence Parallelism with BERT
+# Sequence Parallelism
-In this example, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
-activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.
+## Table of contents
-Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
+- [Sequence Parallelism](#sequence-parallelism)
+ - [Table of contents](#table-of-contents)
+ - [📚 Overview](#-overview)
+ - [🚀 Quick Start](#-quick-start)
+ - [🏎 How to Train with Sequence Parallelism](#-how-to-train-with-sequence-parallelism)
+ - [Step 1. Configure your parameters](#step-1-configure-your-parameters)
+ - [Step 2. Invoke parallel training](#step-2-invoke-parallel-training)
-## 🚀Quick Start
-1. Run with the following command
-```bash
-export PYTHONPATH=$PWD
-colossalai run --nproc_per_node 4 train.py -s
-```
-2. The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again.
-
-
-## How to Prepare WikiPedia Dataset
-
-First, let's prepare the WikiPedia dataset from scratch. To generate a preprocessed dataset, we need four items:
-1. raw WikiPedia dataset
-2. wikipedia extractor (extract data from the raw dataset)
-3. vocabulary file
-4. preprocessing scripts (generate final data from extracted data)
-
-For the preprocessing script, we thank Megatron-LM for providing a preprocessing script to generate the corpus file.
-
-```python
-# download raw data
-mkdir data && cd ./data
-wget https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2
-
-# install wiki extractor
-git clone https://github.com/FrankLeeeee/wikiextractor.git
-pip install ./wikiextractor
-
-# extractmodule
-wikiextractor --json enwiki-latest-pages-articles.xml.bz2
-cat text/*/* > ./corpus.json
-cd ..
-
-# download vocab file
-mkdir vocab && cd ./vocab
-wget https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt
-cd ..
-
-# preprocess some data
-git clone https://github.com/NVIDIA/Megatron-LM.git
-cd ./Megatron-LM
-python tools/preprocess_data.py \
- --input ../data/corpus.json \
- --output-prefix my-bert \
- --vocab ../vocab/bert-large-uncased-vocab.txt \
- --dataset-impl mmap \
- --tokenizer-type BertWordPieceLowerCase \
- --split-sentences \
- --workers 24
-```
+## 📚 Overview
-After running the preprocessing scripts, you will obtain two files:
-1. my-bert_text_sentence.bin
-2. my-bert_text_sentence.idx
-
-If you happen to encouter `index out of range` problem when running Megatron's script,
-this is probably because that a sentence starts with a punctuation and cannot be tokenized. A work-around is to update `Encoder.encode` method with the code below:
-
-```python
-class Encoder(object):
- def __init__(self, args):
- ...
-
- def initializer(self):
- ...
-
- def encode(self, json_line):
- data = json.loads(json_line)
- ids = {}
- for key in self.args.json_keys:
- text = data[key]
- doc_ids = []
-
- # lsg: avoid sentences which start with a punctuation
- # as it cannot be tokenized by splitter
- if len(text) > 0 and text[0] in string.punctuation:
- text = text[1:]
-
- for sentence in Encoder.splitter.tokenize(text):
- sentence_ids = Encoder.tokenizer.tokenize(sentence)
- if len(sentence_ids) > 0:
- doc_ids.append(sentence_ids)
- if len(doc_ids) > 0 and self.args.append_eod:
- doc_ids[-1].append(Encoder.tokenizer.eod)
- ids[key] = doc_ids
- return ids, len(json_line)
-```
+In this tutorial, we implemented BERT with sequence parallelism. Sequence parallelism splits the input tensor and intermediate
+activation along the sequence dimension. This method can achieve better memory efficiency and allows us to train with larger batch size and longer sequence length.
-## How to Train with Sequence Parallelism
+Paper: [Sequence Parallelism: Long Sequence Training from System Perspective](https://arxiv.org/abs/2105.13120)
-We provided `train.py` for you to execute training. Before invoking the script, there are several
-steps to perform.
+## 🚀 Quick Start
-### Step 1. Set data path and vocab path
+1. Install PyTorch
-At the top of `config.py`, you can see two global variables `DATA_PATH` and `VOCAB_FILE_PATH`.
+2. Install the dependencies.
-```python
-DATA_PATH =
-VOCAB_FILE_PATH =
+```bash
+pip install -r requirements.txt
```
-`DATA_PATH` refers to the path to the data file generated by Megatron's script. For example, in the section above, you should get two data files (my-bert_text_sentence.bin and my-bert_text_sentence.idx). You just need to `DATA_PATH` to the path to the bin file without the file extension.
+3. Run with the following command
-For example, if your my-bert_text_sentence.bin is /home/Megatron-LM/my-bert_text_sentence.bin, then you should set
+```bash
+export PYTHONPATH=$PWD
-```python
-DATA_PATH = '/home/Megatron-LM/my-bert_text_sentence'
+# run with synthetic dataset
+colossalai run --nproc_per_node 4 train.py
```
-The `VOCAB_FILE_PATH` refers to the path to the vocabulary downloaded when you prepare the dataset
-(e.g. bert-large-uncased-vocab.txt).
+> The default config is sequence parallel size = 2, pipeline size = 1, let’s change pipeline size to be 2 and try it again.
-### Step 3. Make Dataset Helper
-Build BERT dataset helper. Requirements are `CUDA`, `g++`, `pybind11` and `make`.
+## 🏎 How to Train with Sequence Parallelism
-```python
-cd ./data/datasets
-make
-```
+We provided `train.py` for you to execute training. Before invoking the script, there are several
+steps to perform.
-### Step 3. Configure your parameters
+### Step 1. Configure your parameters
In the `config.py` provided, a set of parameters are defined including training scheme, model, etc.
You can also modify the ColossalAI setting. For example, if you wish to parallelize over the
sequence dimension on 8 GPUs. You can change `size=4` to `size=8`. If you wish to use pipeline parallelism, you can set `pipeline=`.
-### Step 4. Invoke parallel training
+### Step 2. Invoke parallel training
Lastly, you can start training with sequence parallelism. How you invoke `train.py` depends on your
machine setting.
diff --git a/examples/tutorial/sequence_parallel/config.py b/examples/tutorial/sequence_parallel/config.py
index df0c5282f032..6edf9cc2c7e5 100644
--- a/examples/tutorial/sequence_parallel/config.py
+++ b/examples/tutorial/sequence_parallel/config.py
@@ -1,11 +1,8 @@
from colossalai.amp import AMP_TYPE
-DATA_PATH = ''
-VOCAB_FILE_PATH = ''
-
# hyper-parameters
-TRAIN_ITERS = 1000000
-DECAY_ITERS = 990000
+TRAIN_ITERS = 10
+DECAY_ITERS = 4
WARMUP_FRACTION = 0.01
GLOBAL_BATCH_SIZE = 32 # dp world size * sentences per GPU
EVAL_ITERS = 10
@@ -13,12 +10,12 @@
LR = 0.0001
MIN_LR = 1e-05
WEIGHT_DECAY = 0.01
-SEQ_LENGTH = 512
+SEQ_LENGTH = 128
# BERT config
-DEPTH = 12
-NUM_ATTENTION_HEADS = 12
-HIDDEN_SIZE = 768
+DEPTH = 4
+NUM_ATTENTION_HEADS = 4
+HIDDEN_SIZE = 128
# model config
ADD_BINARY_HEAD = False
diff --git a/examples/tutorial/sequence_parallel/requirements.txt b/examples/tutorial/sequence_parallel/requirements.txt
index 137a69e80498..b49a94554afb 100644
--- a/examples/tutorial/sequence_parallel/requirements.txt
+++ b/examples/tutorial/sequence_parallel/requirements.txt
@@ -1,2 +1,2 @@
-colossalai >= 0.1.12
-torch >= 1.8.1
+colossalai
+torch
diff --git a/examples/tutorial/sequence_parallel/test_ci.sh b/examples/tutorial/sequence_parallel/test_ci.sh
new file mode 100644
index 000000000000..7bc20de3b6e4
--- /dev/null
+++ b/examples/tutorial/sequence_parallel/test_ci.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+set -euxo pipefail
+
+pip install -r requirements.txt
+
+# run test
+colossalai run --nproc_per_node 4 train.py
diff --git a/examples/tutorial/sequence_parallel/train.py b/examples/tutorial/sequence_parallel/train.py
index b92061000d10..a89747b5845e 100644
--- a/examples/tutorial/sequence_parallel/train.py
+++ b/examples/tutorial/sequence_parallel/train.py
@@ -1,9 +1,8 @@
import argparse
import torch
-from data import build_train_valid_test_data_iterators
from data.bert_helper import SequenceParallelDataIterator, get_batch_for_sequence_parallel
-from data.tokenizer import get_padded_vocab_size, initialize_tokenizer
+from data.dummy_dataloader import DummyDataloader
from loss_func.bert_loss import BertLoss
from lr_scheduler import AnnealingLR
from model.bert import BertForPretrain, build_pipeline_bert
@@ -36,7 +35,7 @@ def parse_args():
def pipeline_data_process_func(stage_output, micro_batch_data):
- tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
+ tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE):
data = (tokens, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
@@ -53,36 +52,15 @@ def main():
logger = get_dist_logger()
- # build dataloader
- if not args.synthetic:
- initialize_tokenizer(gpc.config.VOCAB_FILE_PATH, tokenizer_type='BertWordPieceLowerCase')
- VOCAB_SIZE = get_padded_vocab_size()
- trainloader, validloader, testloader = build_train_valid_test_data_iterators(
- train_iters=gpc.config.TRAIN_ITERS,
- global_batch_size=gpc.config.GLOBAL_BATCH_SIZE,
- eval_interval=gpc.config.EVAL_INTERVAL,
- eval_iters=gpc.config.EVAL_ITERS,
- data_prefix=[gpc.config.DATA_PATH],
- data_impl='mmap',
- splits_string='949,50,1',
- max_seq_length=gpc.config.SEQ_LENGTH,
- masked_lm_prob=0.15,
- short_seq_prob=0.1,
- seed=1234,
- skip_warmup=True,
- binary_head=False,
- )
- else:
- from data.dummy_dataloader import DummyDataloader
-
- BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
- VOCAB_SIZE = 30528
- trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
- vocab_size=VOCAB_SIZE,
- seq_length=gpc.config.SEQ_LENGTH)
- validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
- vocab_size=VOCAB_SIZE,
- seq_length=gpc.config.SEQ_LENGTH)
+ # build synthetic dataloader
+ BATCH_SIZE_PER_GPUS = gpc.config.GLOBAL_BATCH_SIZE // gpc.get_world_size(ParallelMode.DATA)
+ VOCAB_SIZE = 30528
+ trainloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+ vocab_size=VOCAB_SIZE,
+ seq_length=gpc.config.SEQ_LENGTH)
+ validloader = DummyDataloader(batch_size=BATCH_SIZE_PER_GPUS,
+ vocab_size=VOCAB_SIZE,
+ seq_length=gpc.config.SEQ_LENGTH)
logger.info("Dataloaders are built", ranks=[0])
diff --git a/examples/tutorial/stable_diffusion/LICENSE b/examples/tutorial/stable_diffusion/LICENSE
deleted file mode 100644
index 0e609df0d8cd..000000000000
--- a/examples/tutorial/stable_diffusion/LICENSE
+++ /dev/null
@@ -1,82 +0,0 @@
-Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
-
-CreativeML Open RAIL-M
-dated August 22, 2022
-
-Section I: PREAMBLE
-
-Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
-
-Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
-
-In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
-
-Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
-
-This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
-
-NOW THEREFORE, You and Licensor agree as follows:
-
-1. Definitions
-
-- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
-- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
-- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
-- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
-- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
-- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
-- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
-- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
-- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
-- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
-- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
-- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
-
-Section II: INTELLECTUAL PROPERTY RIGHTS
-
-Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
-
-2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
-3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
-
-Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
-
-4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
-Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
-You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
-You must cause any modified files to carry prominent notices stating that You changed the files;
-You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
-You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
-5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
-6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
-
-Section IV: OTHER PROVISIONS
-
-7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model.
-8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
-9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
-10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
-11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
-12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
-
-END OF TERMS AND CONDITIONS
-
-
-
-
-Attachment A
-
-Use Restrictions
-
-You agree not to use the Model or Derivatives of the Model:
-- In any way that violates any applicable national, federal, state, local or international law or regulation;
-- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
-- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
-- To generate or disseminate personal identifiable information that can be used to harm an individual;
-- To defame, disparage or otherwise harass others;
-- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
-- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
-- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
-- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
-- To provide medical advice and medical results interpretation;
-- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
diff --git a/examples/tutorial/stable_diffusion/README.md b/examples/tutorial/stable_diffusion/README.md
deleted file mode 100644
index a0ece4485d27..000000000000
--- a/examples/tutorial/stable_diffusion/README.md
+++ /dev/null
@@ -1,149 +0,0 @@
-# Stable Diffusion with Colossal-AI
-*[Colosssal-AI](https://github.com/hpcaitech/ColossalAI) provides a faster and lower cost solution for pretraining and
-fine-tuning for AIGC (AI-Generated Content) applications such as the model [stable-diffusion](https://github.com/CompVis/stable-diffusion) from [Stability AI](https://stability.ai/).*
-
-We take advantage of [Colosssal-AI](https://github.com/hpcaitech/ColossalAI) to exploit multiple optimization strategies
-, e.g. data parallelism, tensor parallelism, mixed precision & ZeRO, to scale the training to multiple GPUs.
-
-## 🚀Quick Start
-1. Create a new environment for diffusion
-```bash
-conda env create -f environment.yaml
-conda activate ldm
-```
-2. Install Colossal-AI from our official page
-```bash
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-3. Install PyTorch Lightning compatible commit
-```bash
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-cd ..
-```
-
-4. Comment out the `from_pretrained` field in the `train_colossalai_cifar10.yaml`.
-5. Run training with CIFAR10.
-```bash
-python main.py -logdir /tmp -t true -postfix test -b configs/train_colossalai_cifar10.yaml
-```
-
-## Stable Diffusion
-[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) is a latent text-to-image diffusion
-model.
-Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
-Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
-this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
-
-
-
-
-
-[Stable Diffusion with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion) provides **6.5x faster training and pretraining cost saving, the hardware cost of fine-tuning can be almost 7X cheaper** (from RTX3090/4090 24GB to RTX3050/2070 8GB).
-
-
-
-
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
-
-```
-conda env create -f environment.yaml
-conda activate ldm
-```
-
-You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
-
-```
-conda install pytorch torchvision -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
-pip install -e .
-```
-
-### Install [Colossal-AI v0.1.10](https://colossalai.org/download/) From Our Official Website
-```
-pip install colossalai==0.1.10+torch1.11cu11.3 -f https://release.colossalai.org
-```
-
-### Install [Lightning](https://github.com/Lightning-AI/lightning)
-We use the Sep. 2022 version with commit id as `b04a7aa`.
-```
-git clone https://github.com/Lightning-AI/lightning && cd lightning && git reset --hard b04a7aa
-pip install -r requirements.txt && pip install .
-```
-
-> The specified version is due to the interface incompatibility caused by the latest update of [Lightning](https://github.com/Lightning-AI/lightning), which will be fixed in the near future.
-
-## Dataset
-The dataSet is from [LAION-5B](https://laion.ai/blog/laion-5b/), the subset of [LAION](https://laion.ai/),
-you should the change the `data.file_path` in the `config/train_colossalai.yaml`
-
-## Training
-
-We provide the script `train.sh` to run the training task , and two Stategy in `configs`:`train_colossalai.yaml`
-
-For example, you can run the training from colossalai by
-```
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
-```
-
-- you can change the `--logdir` the save the log information and the last checkpoint
-
-### Training config
-You can change the trainging config in the yaml file
-
-- accelerator: acceleratortype, default 'gpu'
-- devices: device number used for training, default 4
-- max_epochs: max training epochs
-- precision: usefp16 for training or not, default 16, you must use fp16 if you want to apply colossalai
-
-## Example
-
-### Training on cifar10
-
-We provide the finetuning example on CIFAR10 dataset
-
-You can run by config `train_colossalai_cifar10.yaml`
-```
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai_cifar10.yaml
-```
-
-
-
-## Comments
-
-- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
-, [lucidrains](https://github.com/lucidrains/denoising-diffusion-pytorch),
-[Stable Diffusion](https://github.com/CompVis/stable-diffusion), [Lightning](https://github.com/Lightning-AI/lightning) and [Hugging Face](https://huggingface.co/CompVis/stable-diffusion).
-Thanks for open-sourcing!
-
-- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
-
-- The implementation of [flash attention](https://github.com/HazyResearch/flash-attention) is from [HazyResearch](https://github.com/HazyResearch).
-
-## BibTeX
-
-```
-@article{bian2021colossal,
- title={Colossal-AI: A Unified Deep Learning System For Large-Scale Parallel Training},
- author={Bian, Zhengda and Liu, Hongxin and Wang, Boxiang and Huang, Haichen and Li, Yongbin and Wang, Chuanrui and Cui, Fan and You, Yang},
- journal={arXiv preprint arXiv:2110.14883},
- year={2021}
-}
-@misc{rombach2021highresolution,
- title={High-Resolution Image Synthesis with Latent Diffusion Models},
- author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
- year={2021},
- eprint={2112.10752},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
-}
-@article{dao2022flashattention,
- title={FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
- author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
- journal={arXiv preprint arXiv:2205.14135},
- year={2022}
-}
-```
diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml
deleted file mode 100644
index c457787dd881..000000000000
--- a/examples/tutorial/stable_diffusion/configs/train_colossalai.yaml
+++ /dev/null
@@ -1,116 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 64
- wrap: False
- train:
- target: ldm.data.base.Txt2ImgIterableBaseDataset
- params:
- file_path: "/data/scratch/diffuser/laion_part0/"
- world_size: 1
- rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
- enable_distributed_storage: True,
- placement_policy: cuda
- force_outputs_fp32: False
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml b/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml
deleted file mode 100644
index 63b9d1c0179c..000000000000
--- a/examples/tutorial/stable_diffusion/configs/train_colossalai_cifar10.yaml
+++ /dev/null
@@ -1,123 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: txt
- image_size: 64
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 4
- num_workers: 4
- train:
- target: ldm.data.cifar10.hf_dataset
- params:
- name: cifar10
- image_transforms:
- - target: torchvision.transforms.Resize
- params:
- size: 512
- interpolation: 3
- - target: torchvision.transforms.RandomCrop
- params:
- size: 512
- - target: torchvision.transforms.RandomHorizontalFlip
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 2
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
- enable_distributed_storage: True,
- placement_policy: cuda
- force_outputs_fp32: False
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml b/examples/tutorial/stable_diffusion/configs/train_ddp.yaml
deleted file mode 100644
index 90d41258fada..000000000000
--- a/examples/tutorial/stable_diffusion/configs/train_ddp.yaml
+++ /dev/null
@@ -1,113 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 32
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 100 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 64
- wrap: False
- train:
- target: ldm.data.base.Txt2ImgIterableBaseDataset
- params:
- file_path: "/data/scratch/diffuser/laion_part0/"
- world_size: 1
- rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.DDPStrategy
- params:
- find_unused_parameters: False
- log_every_n_steps: 2
-# max_steps: 6o
- logger: True
- default_root_dir: "/tmp/diff_log/"
- # profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml b/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml
deleted file mode 100644
index 8b5d2adfaf17..000000000000
--- a/examples/tutorial/stable_diffusion/configs/train_pokemon.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model:
- base_learning_rate: 1.0e-04
- target: ldm.models.diffusion.ddpm.LatentDiffusion
- params:
- linear_start: 0.00085
- linear_end: 0.0120
- num_timesteps_cond: 1
- log_every_t: 200
- timesteps: 1000
- first_stage_key: image
- cond_stage_key: caption
- image_size: 32
- channels: 4
- cond_stage_trainable: false # Note: different from the one we trained before
- conditioning_key: crossattn
- monitor: val/loss_simple_ema
- scale_factor: 0.18215
- use_ema: False
- check_nan_inf: False
-
- scheduler_config: # 10000 warmup steps
- target: ldm.lr_scheduler.LambdaLinearScheduler
- params:
- warm_up_steps: [ 10000 ]
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
- f_start: [ 1.e-6 ]
- f_max: [ 1.e-4 ]
- f_min: [ 1.e-10 ]
-
- unet_config:
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
- params:
- image_size: 32 # unused
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin'
- in_channels: 4
- out_channels: 4
- model_channels: 320
- attention_resolutions: [ 4, 2, 1 ]
- num_res_blocks: 2
- channel_mult: [ 1, 2, 4, 4 ]
- num_heads: 8
- use_spatial_transformer: True
- transformer_depth: 1
- context_dim: 768
- use_checkpoint: False
- legacy: False
-
- first_stage_config:
- target: ldm.models.autoencoder.AutoencoderKL
- params:
- embed_dim: 4
- from_pretrained: '/data/scratch/diffuser/stable-diffusion-v1-4/vae/diffusion_pytorch_model.bin'
- monitor: val/rec_loss
- ddconfig:
- double_z: true
- z_channels: 4
- resolution: 256
- in_channels: 3
- out_ch: 3
- ch: 128
- ch_mult:
- - 1
- - 2
- - 4
- - 4
- num_res_blocks: 2
- attn_resolutions: []
- dropout: 0.0
- lossconfig:
- target: torch.nn.Identity
-
- cond_stage_config:
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
- params:
- use_fp16: True
-
-data:
- target: main.DataModuleFromConfig
- params:
- batch_size: 32
- wrap: False
- train:
- target: ldm.data.pokemon.PokemonDataset
- # params:
- # file_path: "/data/scratch/diffuser/laion_part0/"
- # world_size: 1
- # rank: 0
-
-lightning:
- trainer:
- accelerator: 'gpu'
- devices: 4
- log_gpu_memory: all
- max_epochs: 2
- precision: 16
- auto_select_gpus: False
- strategy:
- target: pytorch_lightning.strategies.ColossalAIStrategy
- params:
- use_chunk: False
- enable_distributed_storage: True,
- placement_policy: cuda
- force_outputs_fp32: False
- initial_scale: 65536
- min_scale: 1
- max_scale: 65536
- # max_scale: 4294967296
-
- log_every_n_steps: 2
- logger: True
- default_root_dir: "/tmp/diff_log/"
- profiler: pytorch
-
- logger_config:
- wandb:
- target: pytorch_lightning.loggers.WandbLogger
- params:
- name: nowname
- save_dir: "/tmp/diff_log/"
- offline: opt.debug
- id: nowname
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/environment.yaml b/examples/tutorial/stable_diffusion/environment.yaml
deleted file mode 100644
index 7d8aec86f288..000000000000
--- a/examples/tutorial/stable_diffusion/environment.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-name: ldm
-channels:
- - pytorch
- - defaults
-dependencies:
- - python=3.9.12
- - pip=20.3
- - cudatoolkit=11.3
- - pytorch=1.11.0
- - torchvision=0.12.0
- - numpy=1.19.2
- - pip:
- - albumentations==0.4.3
- - datasets
- - diffusers
- - opencv-python==4.6.0.66
- - pudb==2019.2
- - invisible-watermark
- - imageio==2.9.0
- - imageio-ffmpeg==0.4.2
- - pytorch-lightning==1.8.0
- - omegaconf==2.1.1
- - test-tube>=0.7.5
- - streamlit>=0.73.1
- - einops==0.3.0
- - torch-fidelity==0.3.0
- - transformers==4.19.2
- - torchmetrics==0.7.0
- - kornia==0.6
- - prefetch_generator
- - colossalai
- - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- - -e git+https://github.com/openai/CLIP.git@main#egg=clip
- - -e .
diff --git a/examples/tutorial/stable_diffusion/ldm/data/base.py b/examples/tutorial/stable_diffusion/ldm/data/base.py
deleted file mode 100644
index 4f3cd35714a0..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/base.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import math
-from abc import abstractmethod
-
-import torch
-from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
-import os
-import numpy as np
-import cv2
-
-class Txt2ImgIterableBaseDataset(IterableDataset):
- '''
- Define an interface to make the IterableDatasets for text2img data chainable
- '''
- def __init__(self, file_path: str, rank, world_size):
- super().__init__()
- self.file_path = file_path
- self.folder_list = []
- self.file_list = []
- self.txt_list = []
- self.info = self._get_file_info(file_path)
- self.start = self.info['start']
- self.end = self.info['end']
- self.rank = rank
-
- self.world_size = world_size
- # self.per_worker = int(math.floor((self.end - self.start) / float(self.world_size)))
- # self.iter_start = self.start + self.rank * self.per_worker
- # self.iter_end = min(self.iter_start + self.per_worker, self.end)
- # self.num_records = self.iter_end - self.iter_start
- # self.valid_ids = [i for i in range(self.iter_end)]
- self.num_records = self.end - self.start
- self.valid_ids = [i for i in range(self.end)]
-
- print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
-
- def __len__(self):
- # return self.iter_end - self.iter_start
- return self.end - self.start
-
- def __iter__(self):
- sample_iterator = self._sample_generator(self.start, self.end)
- # sample_iterator = self._sample_generator(self.iter_start, self.iter_end)
- return sample_iterator
-
- def _sample_generator(self, start, end):
- for idx in range(start, end):
- file_name = self.file_list[idx]
- txt_name = self.txt_list[idx]
- f_ = open(txt_name, 'r')
- txt_ = f_.read()
- f_.close()
- image = cv2.imdecode(np.fromfile(file_name, dtype=np.uint8), 1)
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
- image = torch.from_numpy(image) / 255
- yield {"caption": txt_, "image":image}
-
-
- def _get_file_info(self, file_path):
- info = \
- {
- "start": 1,
- "end": 0,
- }
- self.folder_list = [file_path + i for i in os.listdir(file_path) if '.' not in i]
- for folder in self.folder_list:
- files = [folder + '/' + i for i in os.listdir(folder) if 'jpg' in i]
- txts = [k.replace('jpg', 'txt') for k in files]
- self.file_list.extend(files)
- self.txt_list.extend(txts)
- info['end'] = len(self.file_list)
- # with open(file_path, 'r') as fin:
- # for _ in enumerate(fin):
- # info['end'] += 1
- # self.txt_list = [k.replace('jpg', 'txt') for k in self.file_list]
- return info
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py b/examples/tutorial/stable_diffusion/ldm/data/cifar10.py
deleted file mode 100644
index 53cd61263b47..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/cifar10.py
+++ /dev/null
@@ -1,184 +0,0 @@
-from typing import Dict
-import numpy as np
-from omegaconf import DictConfig, ListConfig
-import torch
-from torch.utils.data import Dataset
-from pathlib import Path
-import json
-from PIL import Image
-from torchvision import transforms
-from einops import rearrange
-from ldm.util import instantiate_from_config
-from datasets import load_dataset
-
-def make_multi_folder_data(paths, caption_files=None, **kwargs):
- """Make a concat dataset from multiple folders
- Don't suport captions yet
- If paths is a list, that's ok, if it's a Dict interpret it as:
- k=folder v=n_times to repeat that
- """
- list_of_paths = []
- if isinstance(paths, (Dict, DictConfig)):
- assert caption_files is None, \
- "Caption files not yet supported for repeats"
- for folder_path, repeats in paths.items():
- list_of_paths.extend([folder_path]*repeats)
- paths = list_of_paths
-
- if caption_files is not None:
- datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
- else:
- datasets = [FolderData(p, **kwargs) for p in paths]
- return torch.utils.data.ConcatDataset(datasets)
-
-class FolderData(Dataset):
- def __init__(self,
- root_dir,
- caption_file=None,
- image_transforms=[],
- ext="jpg",
- default_caption="",
- postprocess=None,
- return_paths=False,
- ) -> None:
- """Create a dataset from a folder of images.
- If you pass in a root directory it will be searched for images
- ending in ext (ext can be a list)
- """
- self.root_dir = Path(root_dir)
- self.default_caption = default_caption
- self.return_paths = return_paths
- if isinstance(postprocess, DictConfig):
- postprocess = instantiate_from_config(postprocess)
- self.postprocess = postprocess
- if caption_file is not None:
- with open(caption_file, "rt") as f:
- ext = Path(caption_file).suffix.lower()
- if ext == ".json":
- captions = json.load(f)
- elif ext == ".jsonl":
- lines = f.readlines()
- lines = [json.loads(x) for x in lines]
- captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
- else:
- raise ValueError(f"Unrecognised format: {ext}")
- self.captions = captions
- else:
- self.captions = None
-
- if not isinstance(ext, (tuple, list, ListConfig)):
- ext = [ext]
-
- # Only used if there is no caption file
- self.paths = []
- for e in ext:
- self.paths.extend(list(self.root_dir.rglob(f"*.{e}")))
- if isinstance(image_transforms, ListConfig):
- image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
- image_transforms.extend([transforms.ToTensor(),
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
- image_transforms = transforms.Compose(image_transforms)
- self.tform = image_transforms
-
-
- def __len__(self):
- if self.captions is not None:
- return len(self.captions.keys())
- else:
- return len(self.paths)
-
- def __getitem__(self, index):
- data = {}
- if self.captions is not None:
- chosen = list(self.captions.keys())[index]
- caption = self.captions.get(chosen, None)
- if caption is None:
- caption = self.default_caption
- filename = self.root_dir/chosen
- else:
- filename = self.paths[index]
-
- if self.return_paths:
- data["path"] = str(filename)
-
- im = Image.open(filename)
- im = self.process_im(im)
- data["image"] = im
-
- if self.captions is not None:
- data["txt"] = caption
- else:
- data["txt"] = self.default_caption
-
- if self.postprocess is not None:
- data = self.postprocess(data)
-
- return data
-
- def process_im(self, im):
- im = im.convert("RGB")
- return self.tform(im)
-
-def hf_dataset(
- name,
- image_transforms=[],
- image_column="img",
- label_column="label",
- text_column="txt",
- split='train',
- image_key='image',
- caption_key='txt',
- ):
- """Make huggingface dataset with appropriate list of transforms applied
- """
- ds = load_dataset(name, split=split)
- image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
- image_transforms.extend([transforms.ToTensor(),
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
- tform = transforms.Compose(image_transforms)
-
- assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
- assert label_column in ds.column_names, f"Didn't find column {label_column} in {ds.column_names}"
-
- def pre_process(examples):
- processed = {}
- processed[image_key] = [tform(im) for im in examples[image_column]]
-
- label_to_text_dict = {0: "airplane", 1: "automobile", 2: "bird", 3: "cat", 4: "deer", 5: "dog", 6: "frog", 7: "horse", 8: "ship", 9: "truck"}
-
- processed[caption_key] = [label_to_text_dict[label] for label in examples[label_column]]
-
- return processed
-
- ds.set_transform(pre_process)
- return ds
-
-class TextOnly(Dataset):
- def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
- """Returns only captions with dummy images"""
- self.output_size = output_size
- self.image_key = image_key
- self.caption_key = caption_key
- if isinstance(captions, Path):
- self.captions = self._load_caption_file(captions)
- else:
- self.captions = captions
-
- if n_gpus > 1:
- # hack to make sure that all the captions appear on each gpu
- repeated = [n_gpus*[x] for x in self.captions]
- self.captions = []
- [self.captions.extend(x) for x in repeated]
-
- def __len__(self):
- return len(self.captions)
-
- def __getitem__(self, index):
- dummy_im = torch.zeros(3, self.output_size, self.output_size)
- dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
- return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
-
- def _load_caption_file(self, filename):
- with open(filename, 'rt') as f:
- captions = f.readlines()
- return [x.strip('\n') for x in captions]
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py b/examples/tutorial/stable_diffusion/ldm/data/imagenet.py
deleted file mode 100644
index 1c473f9c6965..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/imagenet.py
+++ /dev/null
@@ -1,394 +0,0 @@
-import os, yaml, pickle, shutil, tarfile, glob
-import cv2
-import albumentations
-import PIL
-import numpy as np
-import torchvision.transforms.functional as TF
-from omegaconf import OmegaConf
-from functools import partial
-from PIL import Image
-from tqdm import tqdm
-from torch.utils.data import Dataset, Subset
-
-import taming.data.utils as tdu
-from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
-from taming.data.imagenet import ImagePaths
-
-from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
-
-
-def synset2idx(path_to_yaml="data/index_synset.yaml"):
- with open(path_to_yaml) as f:
- di2s = yaml.load(f)
- return dict((v,k) for k,v in di2s.items())
-
-
-class ImageNetBase(Dataset):
- def __init__(self, config=None):
- self.config = config or OmegaConf.create()
- if not type(self.config)==dict:
- self.config = OmegaConf.to_container(self.config)
- self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
- self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
- self._prepare()
- self._prepare_synset_to_human()
- self._prepare_idx_to_synset()
- self._prepare_human_to_integer_label()
- self._load()
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, i):
- return self.data[i]
-
- def _prepare(self):
- raise NotImplementedError()
-
- def _filter_relpaths(self, relpaths):
- ignore = set([
- "n06596364_9591.JPEG",
- ])
- relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
- if "sub_indices" in self.config:
- indices = str_to_indices(self.config["sub_indices"])
- synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
- self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
- files = []
- for rpath in relpaths:
- syn = rpath.split("/")[0]
- if syn in synsets:
- files.append(rpath)
- return files
- else:
- return relpaths
-
- def _prepare_synset_to_human(self):
- SIZE = 2655750
- URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
- self.human_dict = os.path.join(self.root, "synset_human.txt")
- if (not os.path.exists(self.human_dict) or
- not os.path.getsize(self.human_dict)==SIZE):
- download(URL, self.human_dict)
-
- def _prepare_idx_to_synset(self):
- URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
- self.idx2syn = os.path.join(self.root, "index_synset.yaml")
- if (not os.path.exists(self.idx2syn)):
- download(URL, self.idx2syn)
-
- def _prepare_human_to_integer_label(self):
- URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
- self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
- if (not os.path.exists(self.human2integer)):
- download(URL, self.human2integer)
- with open(self.human2integer, "r") as f:
- lines = f.read().splitlines()
- assert len(lines) == 1000
- self.human2integer_dict = dict()
- for line in lines:
- value, key = line.split(":")
- self.human2integer_dict[key] = int(value)
-
- def _load(self):
- with open(self.txt_filelist, "r") as f:
- self.relpaths = f.read().splitlines()
- l1 = len(self.relpaths)
- self.relpaths = self._filter_relpaths(self.relpaths)
- print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
-
- self.synsets = [p.split("/")[0] for p in self.relpaths]
- self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
-
- unique_synsets = np.unique(self.synsets)
- class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
- if not self.keep_orig_class_label:
- self.class_labels = [class_dict[s] for s in self.synsets]
- else:
- self.class_labels = [self.synset2idx[s] for s in self.synsets]
-
- with open(self.human_dict, "r") as f:
- human_dict = f.read().splitlines()
- human_dict = dict(line.split(maxsplit=1) for line in human_dict)
-
- self.human_labels = [human_dict[s] for s in self.synsets]
-
- labels = {
- "relpath": np.array(self.relpaths),
- "synsets": np.array(self.synsets),
- "class_label": np.array(self.class_labels),
- "human_label": np.array(self.human_labels),
- }
-
- if self.process_images:
- self.size = retrieve(self.config, "size", default=256)
- self.data = ImagePaths(self.abspaths,
- labels=labels,
- size=self.size,
- random_crop=self.random_crop,
- )
- else:
- self.data = self.abspaths
-
-
-class ImageNetTrain(ImageNetBase):
- NAME = "ILSVRC2012_train"
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
- AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
- FILES = [
- "ILSVRC2012_img_train.tar",
- ]
- SIZES = [
- 147897477120,
- ]
-
- def __init__(self, process_images=True, data_root=None, **kwargs):
- self.process_images = process_images
- self.data_root = data_root
- super().__init__(**kwargs)
-
- def _prepare(self):
- if self.data_root:
- self.root = os.path.join(self.data_root, self.NAME)
- else:
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
-
- self.datadir = os.path.join(self.root, "data")
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
- self.expected_length = 1281167
- self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
- default=True)
- if not tdu.is_prepared(self.root):
- # prep
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
-
- datadir = self.datadir
- if not os.path.exists(datadir):
- path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
- import academictorrents as at
- atpath = at.get(self.AT_HASH, datastore=self.root)
- assert atpath == path
-
- print("Extracting {} to {}".format(path, datadir))
- os.makedirs(datadir, exist_ok=True)
- with tarfile.open(path, "r:") as tar:
- tar.extractall(path=datadir)
-
- print("Extracting sub-tars.")
- subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
- for subpath in tqdm(subpaths):
- subdir = subpath[:-len(".tar")]
- os.makedirs(subdir, exist_ok=True)
- with tarfile.open(subpath, "r:") as tar:
- tar.extractall(path=subdir)
-
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
- filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
- with open(self.txt_filelist, "w") as f:
- f.write(filelist)
-
- tdu.mark_prepared(self.root)
-
-
-class ImageNetValidation(ImageNetBase):
- NAME = "ILSVRC2012_validation"
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
- AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
- VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
- FILES = [
- "ILSVRC2012_img_val.tar",
- "validation_synset.txt",
- ]
- SIZES = [
- 6744924160,
- 1950000,
- ]
-
- def __init__(self, process_images=True, data_root=None, **kwargs):
- self.data_root = data_root
- self.process_images = process_images
- super().__init__(**kwargs)
-
- def _prepare(self):
- if self.data_root:
- self.root = os.path.join(self.data_root, self.NAME)
- else:
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
- self.datadir = os.path.join(self.root, "data")
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
- self.expected_length = 50000
- self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
- default=False)
- if not tdu.is_prepared(self.root):
- # prep
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
-
- datadir = self.datadir
- if not os.path.exists(datadir):
- path = os.path.join(self.root, self.FILES[0])
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
- import academictorrents as at
- atpath = at.get(self.AT_HASH, datastore=self.root)
- assert atpath == path
-
- print("Extracting {} to {}".format(path, datadir))
- os.makedirs(datadir, exist_ok=True)
- with tarfile.open(path, "r:") as tar:
- tar.extractall(path=datadir)
-
- vspath = os.path.join(self.root, self.FILES[1])
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
- download(self.VS_URL, vspath)
-
- with open(vspath, "r") as f:
- synset_dict = f.read().splitlines()
- synset_dict = dict(line.split() for line in synset_dict)
-
- print("Reorganizing into synset folders")
- synsets = np.unique(list(synset_dict.values()))
- for s in synsets:
- os.makedirs(os.path.join(datadir, s), exist_ok=True)
- for k, v in synset_dict.items():
- src = os.path.join(datadir, k)
- dst = os.path.join(datadir, v)
- shutil.move(src, dst)
-
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
- filelist = sorted(filelist)
- filelist = "\n".join(filelist)+"\n"
- with open(self.txt_filelist, "w") as f:
- f.write(filelist)
-
- tdu.mark_prepared(self.root)
-
-
-
-class ImageNetSR(Dataset):
- def __init__(self, size=None,
- degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
- random_crop=True):
- """
- Imagenet Superresolution Dataloader
- Performs following ops in order:
- 1. crops a crop of size s from image either as random or center crop
- 2. resizes crop to size with cv2.area_interpolation
- 3. degrades resized crop with degradation_fn
-
- :param size: resizing to size after cropping
- :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
- :param downscale_f: Low Resolution Downsample factor
- :param min_crop_f: determines crop size s,
- where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
- :param max_crop_f: ""
- :param data_root:
- :param random_crop:
- """
- self.base = self.get_base()
- assert size
- assert (size / downscale_f).is_integer()
- self.size = size
- self.LR_size = int(size / downscale_f)
- self.min_crop_f = min_crop_f
- self.max_crop_f = max_crop_f
- assert(max_crop_f <= 1.)
- self.center_crop = not random_crop
-
- self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
-
- self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
-
- if degradation == "bsrgan":
- self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
-
- elif degradation == "bsrgan_light":
- self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
-
- else:
- interpolation_fn = {
- "cv_nearest": cv2.INTER_NEAREST,
- "cv_bilinear": cv2.INTER_LINEAR,
- "cv_bicubic": cv2.INTER_CUBIC,
- "cv_area": cv2.INTER_AREA,
- "cv_lanczos": cv2.INTER_LANCZOS4,
- "pil_nearest": PIL.Image.NEAREST,
- "pil_bilinear": PIL.Image.BILINEAR,
- "pil_bicubic": PIL.Image.BICUBIC,
- "pil_box": PIL.Image.BOX,
- "pil_hamming": PIL.Image.HAMMING,
- "pil_lanczos": PIL.Image.LANCZOS,
- }[degradation]
-
- self.pil_interpolation = degradation.startswith("pil_")
-
- if self.pil_interpolation:
- self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
-
- else:
- self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
- interpolation=interpolation_fn)
-
- def __len__(self):
- return len(self.base)
-
- def __getitem__(self, i):
- example = self.base[i]
- image = Image.open(example["file_path_"])
-
- if not image.mode == "RGB":
- image = image.convert("RGB")
-
- image = np.array(image).astype(np.uint8)
-
- min_side_len = min(image.shape[:2])
- crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
- crop_side_len = int(crop_side_len)
-
- if self.center_crop:
- self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
-
- else:
- self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
-
- image = self.cropper(image=image)["image"]
- image = self.image_rescaler(image=image)["image"]
-
- if self.pil_interpolation:
- image_pil = PIL.Image.fromarray(image)
- LR_image = self.degradation_process(image_pil)
- LR_image = np.array(LR_image).astype(np.uint8)
-
- else:
- LR_image = self.degradation_process(image=image)["image"]
-
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
-
- return example
-
-
-class ImageNetSRTrain(ImageNetSR):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def get_base(self):
- with open("data/imagenet_train_hr_indices.p", "rb") as f:
- indices = pickle.load(f)
- dset = ImageNetTrain(process_images=False,)
- return Subset(dset, indices)
-
-
-class ImageNetSRValidation(ImageNetSR):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
-
- def get_base(self):
- with open("data/imagenet_val_hr_indices.p", "rb") as f:
- indices = pickle.load(f)
- dset = ImageNetValidation(process_images=False,)
- return Subset(dset, indices)
diff --git a/examples/tutorial/stable_diffusion/ldm/data/lsun.py b/examples/tutorial/stable_diffusion/ldm/data/lsun.py
deleted file mode 100644
index 6256e45715ff..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/data/lsun.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import numpy as np
-import PIL
-from PIL import Image
-from torch.utils.data import Dataset
-from torchvision import transforms
-
-
-class LSUNBase(Dataset):
- def __init__(self,
- txt_file,
- data_root,
- size=None,
- interpolation="bicubic",
- flip_p=0.5
- ):
- self.data_paths = txt_file
- self.data_root = data_root
- with open(self.data_paths, "r") as f:
- self.image_paths = f.read().splitlines()
- self._length = len(self.image_paths)
- self.labels = {
- "relative_file_path_": [l for l in self.image_paths],
- "file_path_": [os.path.join(self.data_root, l)
- for l in self.image_paths],
- }
-
- self.size = size
- self.interpolation = {"linear": PIL.Image.LINEAR,
- "bilinear": PIL.Image.BILINEAR,
- "bicubic": PIL.Image.BICUBIC,
- "lanczos": PIL.Image.LANCZOS,
- }[interpolation]
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, i):
- example = dict((k, self.labels[k][i]) for k in self.labels)
- image = Image.open(example["file_path_"])
- if not image.mode == "RGB":
- image = image.convert("RGB")
-
- # default to score-sde preprocessing
- img = np.array(image).astype(np.uint8)
- crop = min(img.shape[0], img.shape[1])
- h, w, = img.shape[0], img.shape[1]
- img = img[(h - crop) // 2:(h + crop) // 2,
- (w - crop) // 2:(w + crop) // 2]
-
- image = Image.fromarray(img)
- if self.size is not None:
- image = image.resize((self.size, self.size), resample=self.interpolation)
-
- image = self.flip(image)
- image = np.array(image).astype(np.uint8)
- example["image"] = (image / 127.5 - 1.0).astype(np.float32)
- return example
-
-
-class LSUNChurchesTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
-
-
-class LSUNChurchesValidation(LSUNBase):
- def __init__(self, flip_p=0., **kwargs):
- super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
- flip_p=flip_p, **kwargs)
-
-
-class LSUNBedroomsTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
-
-
-class LSUNBedroomsValidation(LSUNBase):
- def __init__(self, flip_p=0.0, **kwargs):
- super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
- flip_p=flip_p, **kwargs)
-
-
-class LSUNCatsTrain(LSUNBase):
- def __init__(self, **kwargs):
- super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
-
-
-class LSUNCatsValidation(LSUNBase):
- def __init__(self, flip_p=0., **kwargs):
- super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
- flip_p=flip_p, **kwargs)
diff --git a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py b/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py
deleted file mode 100644
index be39da9ca6da..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/lr_scheduler.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import numpy as np
-
-
-class LambdaWarmUpCosineScheduler:
- """
- note: use with a base_lr of 1.0
- """
- def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
- self.lr_warm_up_steps = warm_up_steps
- self.lr_start = lr_start
- self.lr_min = lr_min
- self.lr_max = lr_max
- self.lr_max_decay_steps = max_decay_steps
- self.last_lr = 0.
- self.verbosity_interval = verbosity_interval
-
- def schedule(self, n, **kwargs):
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
- if n < self.lr_warm_up_steps:
- lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
- self.last_lr = lr
- return lr
- else:
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
- t = min(t, 1.0)
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
- 1 + np.cos(t * np.pi))
- self.last_lr = lr
- return lr
-
- def __call__(self, n, **kwargs):
- return self.schedule(n,**kwargs)
-
-
-class LambdaWarmUpCosineScheduler2:
- """
- supports repeated iterations, configurable via lists
- note: use with a base_lr of 1.0.
- """
- def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
- assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
- self.lr_warm_up_steps = warm_up_steps
- self.f_start = f_start
- self.f_min = f_min
- self.f_max = f_max
- self.cycle_lengths = cycle_lengths
- self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
- self.last_f = 0.
- self.verbosity_interval = verbosity_interval
-
- def find_in_interval(self, n):
- interval = 0
- for cl in self.cum_cycles[1:]:
- if n <= cl:
- return interval
- interval += 1
-
- def schedule(self, n, **kwargs):
- cycle = self.find_in_interval(n)
- n = n - self.cum_cycles[cycle]
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
- f"current cycle {cycle}")
- if n < self.lr_warm_up_steps[cycle]:
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
- self.last_f = f
- return f
- else:
- t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
- t = min(t, 1.0)
- f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
- 1 + np.cos(t * np.pi))
- self.last_f = f
- return f
-
- def __call__(self, n, **kwargs):
- return self.schedule(n, **kwargs)
-
-
-class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
-
- def schedule(self, n, **kwargs):
- cycle = self.find_in_interval(n)
- n = n - self.cum_cycles[cycle]
- if self.verbosity_interval > 0:
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
- f"current cycle {cycle}")
-
- if n < self.lr_warm_up_steps[cycle]:
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
- self.last_f = f
- return f
- else:
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
- self.last_f = f
- return f
-
diff --git a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py b/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py
deleted file mode 100644
index 873d8b69bd22..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/autoencoder.py
+++ /dev/null
@@ -1,544 +0,0 @@
-import torch
-import pytorch_lightning as pl
-import torch.nn.functional as F
-from contextlib import contextmanager
-
-from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
-
-from ldm.modules.diffusionmodules.model import Encoder, Decoder
-from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-
-from ldm.util import instantiate_from_config
-
-
-class VQModel(pl.LightningModule):
- def __init__(self,
- ddconfig,
- lossconfig,
- n_embed,
- embed_dim,
- ckpt_path=None,
- ignore_keys=[],
- image_key="image",
- colorize_nlabels=None,
- monitor=None,
- batch_resize_range=None,
- scheduler_config=None,
- lr_g_factor=1.0,
- remap=None,
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
- use_ema=False
- ):
- super().__init__()
- self.embed_dim = embed_dim
- self.n_embed = n_embed
- self.image_key = image_key
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
- self.loss = instantiate_from_config(lossconfig)
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
- remap=remap,
- sane_index_shape=sane_index_shape)
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
- if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
- if monitor is not None:
- self.monitor = monitor
- self.batch_resize_range = batch_resize_range
- if self.batch_resize_range is not None:
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
-
- self.use_ema = use_ema
- if self.use_ema:
- self.model_ema = LitEma(self)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
- self.scheduler_config = scheduler_config
- self.lr_g_factor = lr_g_factor
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.parameters())
- self.model_ema.copy_to(self)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- print(f"Unexpected Keys: {unexpected}")
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self)
-
- def encode(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- quant, emb_loss, info = self.quantize(h)
- return quant, emb_loss, info
-
- def encode_to_prequant(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- return h
-
- def decode(self, quant):
- quant = self.post_quant_conv(quant)
- dec = self.decoder(quant)
- return dec
-
- def decode_code(self, code_b):
- quant_b = self.quantize.embed_code(code_b)
- dec = self.decode(quant_b)
- return dec
-
- def forward(self, input, return_pred_indices=False):
- quant, diff, (_,_,ind) = self.encode(input)
- dec = self.decode(quant)
- if return_pred_indices:
- return dec, diff, ind
- return dec, diff
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
- if self.batch_resize_range is not None:
- lower_size = self.batch_resize_range[0]
- upper_size = self.batch_resize_range[1]
- if self.global_step <= 4:
- # do the first few batches with max size to avoid later oom
- new_resize = upper_size
- else:
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
- if new_resize != x.shape[2]:
- x = F.interpolate(x, size=new_resize, mode="bicubic")
- x = x.detach()
- return x
-
- def training_step(self, batch, batch_idx, optimizer_idx):
- # https://github.com/pytorch/pytorch/issues/37142
- # try not to fool the heuristics
- x = self.get_input(batch, self.image_key)
- xrec, qloss, ind = self(x, return_pred_indices=True)
-
- if optimizer_idx == 0:
- # autoencode
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train",
- predicted_indices=ind)
-
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
- return aeloss
-
- if optimizer_idx == 1:
- # discriminator
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
- return discloss
-
- def validation_step(self, batch, batch_idx):
- log_dict = self._validation_step(batch, batch_idx)
- with self.ema_scope():
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
- return log_dict
-
- def _validation_step(self, batch, batch_idx, suffix=""):
- x = self.get_input(batch, self.image_key)
- xrec, qloss, ind = self(x, return_pred_indices=True)
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val"+suffix,
- predicted_indices=ind
- )
-
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
- self.global_step,
- last_layer=self.get_last_layer(),
- split="val"+suffix,
- predicted_indices=ind
- )
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
- self.log(f"val{suffix}/rec_loss", rec_loss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
- self.log(f"val{suffix}/aeloss", aeloss,
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
- del log_dict_ae[f"val{suffix}/rec_loss"]
- self.log_dict(log_dict_ae)
- self.log_dict(log_dict_disc)
- return self.log_dict
-
- def configure_optimizers(self):
- lr_d = self.learning_rate
- lr_g = self.lr_g_factor*self.learning_rate
- print("lr_d", lr_d)
- print("lr_g", lr_g)
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quantize.parameters())+
- list(self.quant_conv.parameters())+
- list(self.post_quant_conv.parameters()),
- lr=lr_g, betas=(0.5, 0.9))
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
- lr=lr_d, betas=(0.5, 0.9))
-
- if self.scheduler_config is not None:
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- },
- {
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- },
- ]
- return [opt_ae, opt_disc], scheduler
- return [opt_ae, opt_disc], []
-
- def get_last_layer(self):
- return self.decoder.conv_out.weight
-
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
- log = dict()
- x = self.get_input(batch, self.image_key)
- x = x.to(self.device)
- if only_inputs:
- log["inputs"] = x
- return log
- xrec, _ = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec.shape[1] > 3
- x = self.to_rgb(x)
- xrec = self.to_rgb(xrec)
- log["inputs"] = x
- log["reconstructions"] = xrec
- if plot_ema:
- with self.ema_scope():
- xrec_ema, _ = self(x)
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
- log["reconstructions_ema"] = xrec_ema
- return log
-
- def to_rgb(self, x):
- assert self.image_key == "segmentation"
- if not hasattr(self, "colorize"):
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
- x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
- return x
-
-
-class VQModelInterface(VQModel):
- def __init__(self, embed_dim, *args, **kwargs):
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
- self.embed_dim = embed_dim
-
- def encode(self, x):
- h = self.encoder(x)
- h = self.quant_conv(h)
- return h
-
- def decode(self, h, force_not_quantize=False):
- # also go through quantization layer
- if not force_not_quantize:
- quant, emb_loss, info = self.quantize(h)
- else:
- quant = h
- quant = self.post_quant_conv(quant)
- dec = self.decoder(quant)
- return dec
-
-
-class AutoencoderKL(pl.LightningModule):
- def __init__(self,
- ddconfig,
- lossconfig,
- embed_dim,
- ckpt_path=None,
- ignore_keys=[],
- image_key="image",
- colorize_nlabels=None,
- monitor=None,
- from_pretrained: str=None
- ):
- super().__init__()
- self.image_key = image_key
- self.encoder = Encoder(**ddconfig)
- self.decoder = Decoder(**ddconfig)
- self.loss = instantiate_from_config(lossconfig)
- assert ddconfig["double_z"]
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
- self.embed_dim = embed_dim
- if colorize_nlabels is not None:
- assert type(colorize_nlabels)==int
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
- if monitor is not None:
- self.monitor = monitor
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
- from diffusers.modeling_utils import load_state_dict
- if from_pretrained is not None:
- state_dict = load_state_dict(from_pretrained)
- self._load_pretrained_model(state_dict)
-
- def _state_key_mapping(self, state_dict: dict):
- import re
- res_dict = {}
- key_list = state_dict.keys()
- key_str = " ".join(key_list)
- up_block_pattern = re.compile('upsamplers')
- p1 = re.compile('mid.block_[0-9]')
- p2 = re.compile('decoder.up.[0-9]')
- up_blocks_count = int(len(re.findall(up_block_pattern, key_str)) / 2 + 1)
- for key_, val_ in state_dict.items():
- key_ = key_.replace("up_blocks", "up").replace("down_blocks", "down").replace('resnets', 'block')\
- .replace('mid_block', 'mid').replace("mid.block.", "mid.block_")\
- .replace('mid.attentions.0.key', 'mid.attn_1.k')\
- .replace('mid.attentions.0.query', 'mid.attn_1.q') \
- .replace('mid.attentions.0.value', 'mid.attn_1.v') \
- .replace('mid.attentions.0.group_norm', 'mid.attn_1.norm') \
- .replace('mid.attentions.0.proj_attn', 'mid.attn_1.proj_out')\
- .replace('upsamplers.0', 'upsample')\
- .replace('downsamplers.0', 'downsample')\
- .replace('conv_shortcut', 'nin_shortcut')\
- .replace('conv_norm_out', 'norm_out')
-
- mid_list = re.findall(p1, key_)
- if len(mid_list) != 0:
- mid_str = mid_list[0]
- mid_id = int(mid_str[-1]) + 1
- key_ = key_.replace(mid_str, mid_str[:-1] + str(mid_id))
-
- up_list = re.findall(p2, key_)
- if len(up_list) != 0:
- up_str = up_list[0]
- up_id = up_blocks_count - 1 -int(up_str[-1])
- key_ = key_.replace(up_str, up_str[:-1] + str(up_id))
- res_dict[key_] = val_
- return res_dict
-
- def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
- state_dict = self._state_key_mapping(state_dict)
- model_state_dict = self.state_dict()
- loaded_keys = [k for k in state_dict.keys()]
- expected_keys = list(model_state_dict.keys())
- original_loaded_keys = loaded_keys
- missing_keys = list(set(expected_keys) - set(loaded_keys))
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
- if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
- state_dict,
- model_state_dict,
- original_loaded_keys,
- ignore_mismatched_sizes,
- )
- error_msgs = self._load_state_dict_into_model(state_dict)
- return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
- def _load_state_dict_into_model(self, state_dict):
- # Convert old format to new format if needed from a PyTorch state_dict
- # copy state_dict so _load_from_state_dict can modify it
- state_dict = state_dict.copy()
- error_msgs = []
-
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
- # so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix=""):
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
- module._load_from_state_dict(*args)
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(self)
-
- return error_msgs
-
- def init_from_ckpt(self, path, ignore_keys=list()):
- sd = torch.load(path, map_location="cpu")["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- self.load_state_dict(sd, strict=False)
- print(f"Restored from {path}")
-
- def encode(self, x):
- h = self.encoder(x)
- moments = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(moments)
- return posterior
-
- def decode(self, z):
- z = self.post_quant_conv(z)
- dec = self.decoder(z)
- return dec
-
- def forward(self, input, sample_posterior=True):
- posterior = self.encode(input)
- if sample_posterior:
- z = posterior.sample()
- else:
- z = posterior.mode()
- dec = self.decode(z)
- return dec, posterior
-
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
- return x
-
- def training_step(self, batch, batch_idx, optimizer_idx):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
-
- if optimizer_idx == 0:
- # train encoder+decoder+logvar
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
- return aeloss
-
- if optimizer_idx == 1:
- # train the discriminator
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
- last_layer=self.get_last_layer(), split="train")
-
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
- return discloss
-
- def validation_step(self, batch, batch_idx):
- inputs = self.get_input(batch, self.image_key)
- reconstructions, posterior = self(inputs)
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
- last_layer=self.get_last_layer(), split="val")
-
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
- last_layer=self.get_last_layer(), split="val")
-
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
- self.log_dict(log_dict_ae)
- self.log_dict(log_dict_disc)
- return self.log_dict
-
- def configure_optimizers(self):
- lr = self.learning_rate
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
- list(self.decoder.parameters())+
- list(self.quant_conv.parameters())+
- list(self.post_quant_conv.parameters()),
- lr=lr, betas=(0.5, 0.9))
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
- lr=lr, betas=(0.5, 0.9))
- return [opt_ae, opt_disc], []
-
- def get_last_layer(self):
- return self.decoder.conv_out.weight
-
- @torch.no_grad()
- def log_images(self, batch, only_inputs=False, **kwargs):
- log = dict()
- x = self.get_input(batch, self.image_key)
- x = x.to(self.device)
- if not only_inputs:
- xrec, posterior = self(x)
- if x.shape[1] > 3:
- # colorize with random projection
- assert xrec.shape[1] > 3
- x = self.to_rgb(x)
- xrec = self.to_rgb(xrec)
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
- log["reconstructions"] = xrec
- log["inputs"] = x
- return log
-
- def to_rgb(self, x):
- assert self.image_key == "segmentation"
- if not hasattr(self, "colorize"):
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
- x = F.conv2d(x, weight=self.colorize)
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
- return x
-
-
-class IdentityFirstStage(torch.nn.Module):
- def __init__(self, *args, vq_interface=False, **kwargs):
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
- super().__init__()
-
- def encode(self, x, *args, **kwargs):
- return x
-
- def decode(self, x, *args, **kwargs):
- return x
-
- def quantize(self, x, *args, **kwargs):
- if self.vq_interface:
- return x, None, [None, None, None]
- return x
-
- def forward(self, x, *args, **kwargs):
- return x
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py
deleted file mode 100644
index 67e98b9d8ffb..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/classifier.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import os
-import torch
-import pytorch_lightning as pl
-from omegaconf import OmegaConf
-from torch.nn import functional as F
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import LambdaLR
-from copy import deepcopy
-from einops import rearrange
-from glob import glob
-from natsort import natsorted
-
-from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
-from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
-
-__models__ = {
- 'class_label': EncoderUNetModel,
- 'segmentation': UNetModel
-}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-class NoisyLatentImageClassifier(pl.LightningModule):
-
- def __init__(self,
- diffusion_path,
- num_classes,
- ckpt_path=None,
- pool='attention',
- label_key=None,
- diffusion_ckpt_path=None,
- scheduler_config=None,
- weight_decay=1.e-2,
- log_steps=10,
- monitor='val/loss',
- *args,
- **kwargs):
- super().__init__(*args, **kwargs)
- self.num_classes = num_classes
- # get latest config of diffusion model
- diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
- self.diffusion_config = OmegaConf.load(diffusion_config).model
- self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
- self.load_diffusion()
-
- self.monitor = monitor
- self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
- self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
- self.log_steps = log_steps
-
- self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
- else self.diffusion_model.cond_stage_key
-
- assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
-
- if self.label_key not in __models__:
- raise NotImplementedError()
-
- self.load_classifier(ckpt_path, pool)
-
- self.scheduler_config = scheduler_config
- self.use_scheduler = self.scheduler_config is not None
- self.weight_decay = weight_decay
-
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
- def load_diffusion(self):
- model = instantiate_from_config(self.diffusion_config)
- self.diffusion_model = model.eval()
- self.diffusion_model.train = disabled_train
- for param in self.diffusion_model.parameters():
- param.requires_grad = False
-
- def load_classifier(self, ckpt_path, pool):
- model_config = deepcopy(self.diffusion_config.params.unet_config.params)
- model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
- model_config.out_channels = self.num_classes
- if self.label_key == 'class_label':
- model_config.pool = pool
-
- self.model = __models__[self.label_key](**model_config)
- if ckpt_path is not None:
- print('#####################################################################')
- print(f'load from ckpt "{ckpt_path}"')
- print('#####################################################################')
- self.init_from_ckpt(ckpt_path)
-
- @torch.no_grad()
- def get_x_noisy(self, x, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x))
- continuous_sqrt_alpha_cumprod = None
- if self.diffusion_model.use_continuous_noise:
- continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
- # todo: make sure t+1 is correct here
-
- return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
- continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
-
- def forward(self, x_noisy, t, *args, **kwargs):
- return self.model(x_noisy, t)
-
- @torch.no_grad()
- def get_input(self, batch, k):
- x = batch[k]
- if len(x.shape) == 3:
- x = x[..., None]
- x = rearrange(x, 'b h w c -> b c h w')
- x = x.to(memory_format=torch.contiguous_format).float()
- return x
-
- @torch.no_grad()
- def get_conditioning(self, batch, k=None):
- if k is None:
- k = self.label_key
- assert k is not None, 'Needs to provide label key'
-
- targets = batch[k].to(self.device)
-
- if self.label_key == 'segmentation':
- targets = rearrange(targets, 'b h w c -> b c h w')
- for down in range(self.numd):
- h, w = targets.shape[-2:]
- targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
-
- # targets = rearrange(targets,'b c h w -> b h w c')
-
- return targets
-
- def compute_top_k(self, logits, labels, k, reduction="mean"):
- _, top_ks = torch.topk(logits, k, dim=1)
- if reduction == "mean":
- return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
- elif reduction == "none":
- return (top_ks == labels[:, None]).float().sum(dim=-1)
-
- def on_train_epoch_start(self):
- # save some memory
- self.diffusion_model.model.to('cpu')
-
- @torch.no_grad()
- def write_logs(self, loss, logits, targets):
- log_prefix = 'train' if self.training else 'val'
- log = {}
- log[f"{log_prefix}/loss"] = loss.mean()
- log[f"{log_prefix}/acc@1"] = self.compute_top_k(
- logits, targets, k=1, reduction="mean"
- )
- log[f"{log_prefix}/acc@5"] = self.compute_top_k(
- logits, targets, k=5, reduction="mean"
- )
-
- self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
- self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
- self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
-
- def shared_step(self, batch, t=None):
- x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
- targets = self.get_conditioning(batch)
- if targets.dim() == 4:
- targets = targets.argmax(dim=1)
- if t is None:
- t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
- else:
- t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
- x_noisy = self.get_x_noisy(x, t)
- logits = self(x_noisy, t)
-
- loss = F.cross_entropy(logits, targets, reduction='none')
-
- self.write_logs(loss.detach(), logits.detach(), targets.detach())
-
- loss = loss.mean()
- return loss, logits, x_noisy, targets
-
- def training_step(self, batch, batch_idx):
- loss, *_ = self.shared_step(batch)
- return loss
-
- def reset_noise_accs(self):
- self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
- range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
-
- def on_validation_start(self):
- self.reset_noise_accs()
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- loss, *_ = self.shared_step(batch)
-
- for t in self.noisy_acc:
- _, logits, _, targets = self.shared_step(batch, t)
- self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
- self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
-
- return loss
-
- def configure_optimizers(self):
- optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
-
- if self.use_scheduler:
- scheduler = instantiate_from_config(self.scheduler_config)
-
- print("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [optimizer], scheduler
-
- return optimizer
-
- @torch.no_grad()
- def log_images(self, batch, N=8, *args, **kwargs):
- log = dict()
- x = self.get_input(batch, self.diffusion_model.first_stage_key)
- log['inputs'] = x
-
- y = self.get_conditioning(batch)
-
- if self.label_key == 'class_label':
- y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
- log['labels'] = y
-
- if ismap(y):
- log['labels'] = self.diffusion_model.to_rgb(y)
-
- for step in range(self.log_steps):
- current_time = step * self.log_time_interval
-
- _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
-
- log[f'inputs@t{current_time}'] = x_noisy
-
- pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
- pred = rearrange(pred, 'b h w c -> b c h w')
-
- log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
-
- for key in log:
- log[key] = log[key][:N]
-
- return log
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py
deleted file mode 100644
index 91335d6372df..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddim.py
+++ /dev/null
@@ -1,240 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
- extract_into_tensor
-
-
-class DDIMSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
- alphas_cumprod = self.model.alphas_cumprod
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer('betas', to_torch(self.model.betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
- self.register_buffer('ddim_sigmas', ddim_sigmas)
- self.register_buffer('ddim_alphas', ddim_alphas)
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
- @torch.no_grad()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
- samples, intermediates = self.ddim_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def ddim_sampling(self, cond, shape,
- x_T=None, ddim_use_original_steps=False,
- callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, log_every_t=100,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- if timesteps is None:
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
- elif timesteps is not None and not ddim_use_original_steps:
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
-
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
- img = img_orig * mask + (1. - mask) * img
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised, temperature=temperature,
- noise_dropout=noise_dropout, score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning)
- img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates['x_inter'].append(img)
- intermediates['pred_x0'].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None):
- b, *_, device = *x.shape, x.device
-
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- @torch.no_grad()
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
- # fast, but does not allow for exact reconstruction
- # t serves as an index to gather the correct alphas
- if use_original_steps:
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
- else:
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
-
- if noise is None:
- noise = torch.randn_like(x0)
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
-
- @torch.no_grad()
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- use_original_steps=False):
-
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
- timesteps = timesteps[:t_start]
-
- time_range = np.flip(timesteps)
- total_steps = timesteps.shape[0]
- print(f"Running DDIM Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
- x_dec = x_latent
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning)
- return x_dec
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py
deleted file mode 100644
index 9633ec3d843a..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/ddpm.py
+++ /dev/null
@@ -1,1554 +0,0 @@
-import torch
-import torch.nn as nn
-import numpy as np
-import pytorch_lightning as pl
-from torch.optim.lr_scheduler import LambdaLR
-from einops import rearrange, repeat
-from contextlib import contextmanager
-from functools import partial
-from tqdm import tqdm
-from torchvision.utils import make_grid
-
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities import rank_zero_info
-
-from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
-from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.modules.diffusionmodules.openaimodel import AttentionPool2d
-from ldm.modules.x_transformer import *
-from ldm.modules.encoders.modules import *
-
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-
-
-from ldm.modules.diffusionmodules.model import Model, Encoder, Decoder
-
-from ldm.util import instantiate_from_config
-
-from einops import rearrange, repeat
-
-
-
-
-__conditioning_keys__ = {'concat': 'c_concat',
- 'crossattn': 'c_crossattn',
- 'adm': 'y'}
-
-
-def disabled_train(self, mode=True):
- """Overwrite model.train with this function to make sure train/eval mode
- does not change anymore."""
- return self
-
-
-def uniform_on_device(r1, r2, shape, device):
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
-
-
-class DDPM(pl.LightningModule):
- # classic DDPM with Gaussian diffusion, in image space
- def __init__(self,
- unet_config,
- timesteps=1000,
- beta_schedule="linear",
- loss_type="l2",
- ckpt_path=None,
- ignore_keys=[],
- load_only_unet=False,
- monitor="val/loss",
- use_ema=True,
- first_stage_key="image",
- image_size=256,
- channels=3,
- log_every_t=100,
- clip_denoised=True,
- linear_start=1e-4,
- linear_end=2e-2,
- cosine_s=8e-3,
- given_betas=None,
- original_elbo_weight=0.,
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
- l_simple_weight=1.,
- conditioning_key=None,
- parameterization="eps", # all assuming fixed variance schedules
- scheduler_config=None,
- use_positional_encodings=False,
- learn_logvar=False,
- logvar_init=0.,
- use_fp16 = True,
- ):
- super().__init__()
- assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
- self.parameterization = parameterization
- rank_zero_info(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
- self.cond_stage_model = None
- self.clip_denoised = clip_denoised
- self.log_every_t = log_every_t
- self.first_stage_key = first_stage_key
- self.image_size = image_size # try conv?
- self.channels = channels
- self.use_positional_encodings = use_positional_encodings
- self.unet_config = unet_config
- self.conditioning_key = conditioning_key
- # self.model = DiffusionWrapper(unet_config, conditioning_key)
- # count_params(self.model, verbose=True)
- self.use_ema = use_ema
- # if self.use_ema:
- # self.model_ema = LitEma(self.model)
- # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
- self.use_scheduler = scheduler_config is not None
- if self.use_scheduler:
- self.scheduler_config = scheduler_config
-
- self.v_posterior = v_posterior
- self.original_elbo_weight = original_elbo_weight
- self.l_simple_weight = l_simple_weight
-
- if monitor is not None:
- self.monitor = monitor
- self.ckpt_path = ckpt_path
- self.ignore_keys = ignore_keys
- self.load_only_unet = load_only_unet
- self.given_betas = given_betas
- self.beta_schedule = beta_schedule
- self.timesteps = timesteps
- self.linear_start = linear_start
- self.linear_end = linear_end
- self.cosine_s = cosine_s
- # if ckpt_path is not None:
- # self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
- #
- # self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
- # linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-
- self.loss_type = loss_type
-
- self.learn_logvar = learn_logvar
- self.logvar_init = logvar_init
- # self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
- # if self.learn_logvar:
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
-
- self.use_fp16 = use_fp16
- if use_fp16:
- self.unet_config["params"].update({"use_fp16": True})
- rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"]))
- else:
- self.unet_config["params"].update({"use_fp16": False})
- rank_zero_info("Using FP16 for UNet = {}".format(self.unet_config["params"]["use_fp16"]))
-
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if exists(given_betas):
- betas = given_betas
- else:
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
- cosine_s=cosine_s)
- alphas = 1. - betas
- alphas_cumprod = np.cumprod(alphas, axis=0)
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
-
- timesteps, = betas.shape
- self.num_timesteps = int(timesteps)
- self.linear_start = linear_start
- self.linear_end = linear_end
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
-
- to_torch = partial(torch.tensor, dtype=torch.float32)
-
- self.register_buffer('betas', to_torch(betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
- 1. - alphas_cumprod) + self.v_posterior * betas
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
- self.register_buffer('posterior_mean_coef1', to_torch(
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
- self.register_buffer('posterior_mean_coef2', to_torch(
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
-
- if self.parameterization == "eps":
- lvlb_weights = self.betas ** 2 / (
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
- elif self.parameterization == "x0":
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
- else:
- raise NotImplementedError("mu not supported")
- # TODO how to choose this term
- lvlb_weights[0] = lvlb_weights[1]
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
- assert not torch.isnan(self.lvlb_weights).all()
-
- @contextmanager
- def ema_scope(self, context=None):
- if self.use_ema:
- self.model_ema.store(self.model.parameters())
- self.model_ema.copy_to(self.model)
- if context is not None:
- print(f"{context}: Switched to EMA weights")
- try:
- yield None
- finally:
- if self.use_ema:
- self.model_ema.restore(self.model.parameters())
- if context is not None:
- print(f"{context}: Restored training weights")
-
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
- sd = torch.load(path, map_location="cpu")
- if "state_dict" in list(sd.keys()):
- sd = sd["state_dict"]
- keys = list(sd.keys())
- for k in keys:
- for ik in ignore_keys:
- if k.startswith(ik):
- print("Deleting key {} from state_dict.".format(k))
- del sd[k]
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
- sd, strict=False)
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
- if len(missing) > 0:
- print(f"Missing Keys: {missing}")
- if len(unexpected) > 0:
- print(f"Unexpected Keys: {unexpected}")
-
- def q_mean_variance(self, x_start, t):
- """
- Get the distribution q(x_t | x_0).
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
- """
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
- return mean, variance, log_variance
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_out = self.model(x, t)
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def p_sample_loop(self, shape, return_intermediates=False):
- device = self.betas.device
- b = shape[0]
- img = torch.randn(shape, device=device)
- intermediates = [img]
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
- clip_denoised=self.clip_denoised)
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
- intermediates.append(img)
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16, return_intermediates=False):
- image_size = self.image_size
- channels = self.channels
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
- return_intermediates=return_intermediates)
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
-
- def get_loss(self, pred, target, mean=True):
-
- if pred.isnan().any():
- print("Warning: Prediction has nan values")
- lr = self.optimizers().param_groups[0]['lr']
- # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
- print(f"lr: {lr}")
- if pred.isinf().any():
- print("Warning: Prediction has inf values")
-
- if self.use_fp16:
- target = target.half()
-
- if self.loss_type == 'l1':
- loss = (target - pred).abs()
- if mean:
- loss = loss.mean()
- elif self.loss_type == 'l2':
- if mean:
- loss = torch.nn.functional.mse_loss(target, pred)
- else:
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
- else:
- raise NotImplementedError("unknown loss type '{loss_type}'")
-
- if loss.isnan().any():
- print("Warning: loss has nan values")
- print("loss: ", loss[0][0][0])
- raise ValueError("loss has nan values")
- if loss.isinf().any():
- print("Warning: loss has inf values")
- print("loss: ", loss)
- raise ValueError("loss has inf values")
-
- return loss
-
- def p_losses(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_out = self.model(x_noisy, t)
-
- loss_dict = {}
- if self.parameterization == "eps":
- target = noise
- elif self.parameterization == "x0":
- target = x_start
- else:
- raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
-
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
-
- log_prefix = 'train' if self.training else 'val'
-
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
- loss_simple = loss.mean() * self.l_simple_weight
-
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
-
- loss = loss_simple + self.original_elbo_weight * loss_vlb
-
- loss_dict.update({f'{log_prefix}/loss': loss})
-
- return loss, loss_dict
-
- def forward(self, x, *args, **kwargs):
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- return self.p_losses(x, t, *args, **kwargs)
-
- def get_input(self, batch, k):
- # print("+" * 30)
- # print(batch['jpg'].shape)
- # print(len(batch['txt']))
- # print(k)
- # print("=" * 30)
- if not isinstance(batch, torch.Tensor):
- x = batch[k]
- else:
- x = batch
- if len(x.shape) == 3:
- x = x[..., None]
- x = rearrange(x, 'b h w c -> b c h w')
-
- if self.use_fp16:
- x = x.to(memory_format=torch.contiguous_format).float().half()
- else:
- x = x.to(memory_format=torch.contiguous_format).float()
-
- return x
-
- def shared_step(self, batch):
- x = self.get_input(batch, self.first_stage_key)
- loss, loss_dict = self(x)
- return loss, loss_dict
-
- def training_step(self, batch, batch_idx):
- loss, loss_dict = self.shared_step(batch)
-
- self.log_dict(loss_dict, prog_bar=True,
- logger=True, on_step=True, on_epoch=True)
-
- self.log("global_step", self.global_step,
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- if self.use_scheduler:
- lr = self.optimizers().param_groups[0]['lr']
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
-
- return loss
-
- @torch.no_grad()
- def validation_step(self, batch, batch_idx):
- _, loss_dict_no_ema = self.shared_step(batch)
- with self.ema_scope():
- _, loss_dict_ema = self.shared_step(batch)
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
-
- def on_train_batch_end(self, *args, **kwargs):
- if self.use_ema:
- self.model_ema(self.model)
-
- def _get_rows_from_list(self, samples):
- n_imgs_per_row = len(samples)
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
- log = dict()
- x = self.get_input(batch, self.first_stage_key)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- x = x.to(self.device)[:N]
- log["inputs"] = x
-
- # get diffusion row
- diffusion_row = list()
- x_start = x[:n_row]
-
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(x_start)
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- diffusion_row.append(x_noisy)
-
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
-
- log["samples"] = samples
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.learn_logvar:
- params = params + [self.logvar]
- opt = torch.optim.AdamW(params, lr=lr)
- return opt
-
-
-class LatentDiffusion(DDPM):
- """main class"""
- def __init__(self,
- first_stage_config,
- cond_stage_config,
- num_timesteps_cond=None,
- cond_stage_key="image",
- cond_stage_trainable=False,
- concat_mode=True,
- cond_stage_forward=None,
- conditioning_key=None,
- scale_factor=1.0,
- scale_by_std=False,
- use_fp16=True,
- *args, **kwargs):
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
- self.scale_by_std = scale_by_std
- assert self.num_timesteps_cond <= kwargs['timesteps']
- # for backwards compatibility after implementation of DiffusionWrapper
- if conditioning_key is None:
- conditioning_key = 'concat' if concat_mode else 'crossattn'
- if cond_stage_config == '__is_unconditional__':
- conditioning_key = None
- ckpt_path = kwargs.pop("ckpt_path", None)
- ignore_keys = kwargs.pop("ignore_keys", [])
- super().__init__(conditioning_key=conditioning_key, use_fp16=use_fp16, *args, **kwargs)
- self.concat_mode = concat_mode
- self.cond_stage_trainable = cond_stage_trainable
- self.cond_stage_key = cond_stage_key
- try:
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
- except:
- self.num_downs = 0
- if not scale_by_std:
- self.scale_factor = scale_factor
- else:
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
- self.first_stage_config = first_stage_config
- self.cond_stage_config = cond_stage_config
- if self.use_fp16:
- self.cond_stage_config["params"].update({"use_fp16": True})
- rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
- else:
- self.cond_stage_config["params"].update({"use_fp16": False})
- rank_zero_info("Using fp16 for conditioning stage = {}".format(self.cond_stage_config["params"]["use_fp16"]))
- # self.instantiate_first_stage(first_stage_config)
- # self.instantiate_cond_stage(cond_stage_config)
- self.cond_stage_forward = cond_stage_forward
- self.clip_denoised = False
- self.bbox_tokenizer = None
-
- self.restarted_from_ckpt = False
- if ckpt_path is not None:
- self.init_from_ckpt(ckpt_path, ignore_keys)
- self.restarted_from_ckpt = True
-
-
-
- def configure_sharded_model(self) -> None:
- self.model = DiffusionWrapper(self.unet_config, self.conditioning_key)
- count_params(self.model, verbose=True)
- if self.use_ema:
- self.model_ema = LitEma(self.model)
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
-
-
- self.register_schedule(given_betas=self.given_betas, beta_schedule=self.beta_schedule, timesteps=self.timesteps,
- linear_start=self.linear_start, linear_end=self.linear_end, cosine_s=self.cosine_s)
-
- self.logvar = torch.full(fill_value=self.logvar_init, size=(self.num_timesteps,))
- if self.learn_logvar:
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- # self.logvar = nn.Parameter(self.logvar, requires_grad=True)
- if self.ckpt_path is not None:
- self.init_from_ckpt(self.ckpt_path, self.ignore_keys)
- self.restarted_from_ckpt = True
-
- # TODO()
- # for p in self.model.modules():
- # if not p.parameters().data.is_contiguous:
- # p.data = p.data.contiguous()
-
- self.instantiate_first_stage(self.first_stage_config)
- self.instantiate_cond_stage(self.cond_stage_config)
-
- def make_cond_schedule(self, ):
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
- self.cond_ids[:self.num_timesteps_cond] = ids
-
-
-
- @rank_zero_only
- @torch.no_grad()
- # def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
- def on_train_batch_start(self, batch, batch_idx):
- # only for very first batch
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
- # set rescale weight to 1./std of encodings
- print("### USING STD-RESCALING ###")
- x = super().get_input(batch, self.first_stage_key)
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
- del self.scale_factor
- self.register_buffer('scale_factor', 1. / z.flatten().std())
- print(f"setting self.scale_factor to {self.scale_factor}")
- print("### USING STD-RESCALING ###")
-
- def register_schedule(self,
- given_betas=None, beta_schedule="linear", timesteps=1000,
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
-
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
- if self.shorten_cond_schedule:
- self.make_cond_schedule()
-
- def instantiate_first_stage(self, config):
- model = instantiate_from_config(config)
- self.first_stage_model = model.eval()
- self.first_stage_model.train = disabled_train
- for param in self.first_stage_model.parameters():
- param.requires_grad = False
-
- def instantiate_cond_stage(self, config):
- if not self.cond_stage_trainable:
- if config == "__is_first_stage__":
- print("Using first stage also as cond stage.")
- self.cond_stage_model = self.first_stage_model
- elif config == "__is_unconditional__":
- print(f"Training {self.__class__.__name__} as an unconditional model.")
- self.cond_stage_model = None
- # self.be_unconditional = True
- else:
- model = instantiate_from_config(config)
- self.cond_stage_model = model.eval()
- self.cond_stage_model.train = disabled_train
- for param in self.cond_stage_model.parameters():
- param.requires_grad = False
- else:
- assert config != '__is_first_stage__'
- assert config != '__is_unconditional__'
- model = instantiate_from_config(config)
- self.cond_stage_model = model
-
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
- denoise_row = []
- for zd in tqdm(samples, desc=desc):
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
- force_not_quantize=force_no_decoder_quantization))
- n_imgs_per_row = len(denoise_row)
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
- return denoise_grid
-
- def get_first_stage_encoding(self, encoder_posterior):
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
- z = encoder_posterior.sample()
- elif isinstance(encoder_posterior, torch.Tensor):
- z = encoder_posterior
- else:
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
- return self.scale_factor * z
-
- def get_learned_conditioning(self, c):
- if self.cond_stage_forward is None:
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
- c = self.cond_stage_model.encode(c)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- else:
- c = self.cond_stage_model(c)
- else:
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
- return c
-
- def meshgrid(self, h, w):
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
-
- arr = torch.cat([y, x], dim=-1)
- return arr
-
- def delta_border(self, h, w):
- """
- :param h: height
- :param w: width
- :return: normalized distance to image border,
- wtith min distance = 0 at border and max dist = 0.5 at image center
- """
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
- arr = self.meshgrid(h, w) / lower_right_corner
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
- return edge_dist
-
- def get_weighting(self, h, w, Ly, Lx, device):
- weighting = self.delta_border(h, w)
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
- self.split_input_params["clip_max_weight"], )
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
-
- if self.split_input_params["tie_braker"]:
- L_weighting = self.delta_border(Ly, Lx)
- L_weighting = torch.clip(L_weighting,
- self.split_input_params["clip_min_tie_weight"],
- self.split_input_params["clip_max_tie_weight"])
-
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
- weighting = weighting * L_weighting
- return weighting
-
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
- """
- :param x: img of size (bs, c, h, w)
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
- """
- bs, nc, h, w = x.shape
-
- # number of crops in image
- Ly = (h - kernel_size[0]) // stride[0] + 1
- Lx = (w - kernel_size[1]) // stride[1] + 1
-
- if uf == 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
-
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
-
- elif uf > 1 and df == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
- dilation=1, padding=0,
- stride=(stride[0] * uf, stride[1] * uf))
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
-
- elif df > 1 and uf == 1:
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
- unfold = torch.nn.Unfold(**fold_params)
-
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
- dilation=1, padding=0,
- stride=(stride[0] // df, stride[1] // df))
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
-
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
-
- else:
- raise NotImplementedError
-
- return fold, unfold, normalization, weighting
-
- @torch.no_grad()
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
- cond_key=None, return_original_cond=False, bs=None):
- x = super().get_input(batch, k)
- if bs is not None:
- x = x[:bs]
- x = x.to(self.device)
- encoder_posterior = self.encode_first_stage(x)
- z = self.get_first_stage_encoding(encoder_posterior).detach()
-
- if self.model.conditioning_key is not None:
- if cond_key is None:
- cond_key = self.cond_stage_key
- if cond_key != self.first_stage_key:
- if cond_key in ['caption', 'coordinates_bbox', 'txt']:
- xc = batch[cond_key]
- elif cond_key == 'class_label':
- xc = batch
- else:
- xc = super().get_input(batch, cond_key).to(self.device)
- else:
- xc = x
- if not self.cond_stage_trainable or force_c_encode:
- if isinstance(xc, dict) or isinstance(xc, list):
- # import pudb; pudb.set_trace()
- c = self.get_learned_conditioning(xc)
- else:
- c = self.get_learned_conditioning(xc.to(self.device))
- else:
- c = xc
- if bs is not None:
- c = c[:bs]
-
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- ckey = __conditioning_keys__[self.model.conditioning_key]
- c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
-
- else:
- c = None
- xc = None
- if self.use_positional_encodings:
- pos_x, pos_y = self.compute_latent_shifts(batch)
- c = {'pos_x': pos_x, 'pos_y': pos_y}
- out = [z, c]
- if return_first_stage_outputs:
- xrec = self.decode_first_stage(z)
- out.extend([x, xrec])
- if return_original_cond:
- out.append(xc)
- return out
-
- @torch.no_grad()
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- # same as above but without decorator
- def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
- if predict_cids:
- if z.dim() == 4:
- z = torch.argmax(z.exp(), dim=1).long()
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
-
- z = 1. / self.scale_factor * z
-
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- uf = self.split_input_params["vqf"]
- bs, nc, h, w = z.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
-
- z = unfold(z) # (bn, nc * prod(**ks), L)
- # 1. Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- # 2. apply model loop over last dim
- if isinstance(self.first_stage_model, VQModelInterface):
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
- force_not_quantize=predict_cids or force_not_quantize)
- for i in range(z.shape[-1])]
- else:
-
- output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
- o = o * weighting
- # Reverse 1. reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization # norm is shape (1, 1, h, w)
- return decoded
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- else:
- if isinstance(self.first_stage_model, VQModelInterface):
- return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
- else:
- return self.first_stage_model.decode(z)
-
- @torch.no_grad()
- def encode_first_stage(self, x):
- if hasattr(self, "split_input_params"):
- if self.split_input_params["patch_distributed_vq"]:
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
- df = self.split_input_params["vqf"]
- self.split_input_params['original_image_size'] = x.shape[-2:]
- bs, nc, h, w = x.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
- z = unfold(x) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
-
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization
- return decoded
-
- else:
- return self.first_stage_model.encode(x)
- else:
- return self.first_stage_model.encode(x)
-
- def shared_step(self, batch, **kwargs):
- x, c = self.get_input(batch, self.first_stage_key)
- loss = self(x, c)
- return loss
-
- def forward(self, x, c, *args, **kwargs):
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- if self.model.conditioning_key is not None:
- assert c is not None
- if self.cond_stage_trainable:
- c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
- tc = self.cond_ids[t].to(self.device)
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
- return self.p_losses(x, c, t, *args, **kwargs)
-
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
- if isinstance(cond, dict):
- # hybrid case, cond is exptected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- if hasattr(self, "split_input_params"):
- assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
-
- h, w = x_noisy.shape[-2:]
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
-
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
- c_key = next(iter(cond.keys())) # get key
- c = next(iter(cond.values())) # get value
- assert (len(c) == 1) # todo extend to list with more than one elem
- c = c[0] # get element
-
- c = unfold(c)
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
-
- elif self.cond_stage_key == 'coordinates_bbox':
- assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
-
- # assuming padding of unfold is always 0 and its dilation is always 1
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
- full_img_h, full_img_w = self.split_input_params['original_image_size']
- # as we are operating on latents, we need the factor from the original image size to the
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
- rescale_latent = 2 ** (num_downs)
-
- # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
- # need to rescale the tl patch coordinates to be in between (0,1)
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
- for patch_nr in range(z.shape[-1])]
-
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
- patch_limits = [(x_tl, y_tl,
- rescale_latent * ks[0] / full_img_w,
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
-
- # tokenize crop coordinates for the bounding boxes of the respective patches
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
- print(patch_limits_tknzd[0].shape)
- # cut tknzd crop position from conditioning
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
- print(cut_cond.shape)
-
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
- print(adapted_cond.shape)
- adapted_cond = self.get_learned_conditioning(adapted_cond)
- print(adapted_cond.shape)
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
- print(adapted_cond.shape)
-
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
-
- else:
- cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
-
- # apply model by loop over crops
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
- assert not isinstance(output_list[0],
- tuple) # todo cant deal with multiple model outputs check this never happens
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- x_recon = fold(o) / normalization
-
- else:
- x_recon = self.model(x_noisy, t, **cond)
-
- if isinstance(x_recon, tuple) and not return_ids:
- return x_recon[0]
- else:
- return x_recon
-
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
- def _prior_bpd(self, x_start):
- """
- Get the prior KL term for the variational lower-bound, measured in
- bits-per-dim.
- This term can't be optimized, as it only depends on the encoder.
- :param x_start: the [N x C x ...] tensor of inputs.
- :return: a batch of [N] KL values (in bits), one per batch element.
- """
- batch_size = x_start.shape[0]
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
- return mean_flat(kl_prior) / np.log(2.0)
-
- def p_losses(self, x_start, cond, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_output = self.apply_model(x_noisy, t, cond)
-
- loss_dict = {}
- prefix = 'train' if self.training else 'val'
-
- if self.parameterization == "x0":
- target = x_start
- elif self.parameterization == "eps":
- target = noise
- else:
- raise NotImplementedError()
-
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
-
- logvar_t = self.logvar[t].to(self.device)
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
- if self.learn_logvar:
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
- loss_dict.update({'logvar': self.logvar.data.mean()})
-
- loss = self.l_simple_weight * loss.mean()
-
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
- loss += (self.original_elbo_weight * loss_vlb)
- loss_dict.update({f'{prefix}/loss': loss})
-
- return loss, loss_dict
-
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
- return_x0=False, score_corrector=None, corrector_kwargs=None):
- t_in = t
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
-
- if score_corrector is not None:
- assert self.parameterization == "eps"
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
-
- if return_codebook_ids:
- model_out, logits = model_out
-
- if self.parameterization == "eps":
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
- elif self.parameterization == "x0":
- x_recon = model_out
- else:
- raise NotImplementedError()
-
- if clip_denoised:
- x_recon.clamp_(-1., 1.)
- if quantize_denoised:
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
- if return_codebook_ids:
- return model_mean, posterior_variance, posterior_log_variance, logits
- elif return_x0:
- return model_mean, posterior_variance, posterior_log_variance, x_recon
- else:
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
- b, *_, device = *x.shape, x.device
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
- return_codebook_ids=return_codebook_ids,
- quantize_denoised=quantize_denoised,
- return_x0=return_x0,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if return_codebook_ids:
- raise DeprecationWarning("Support dropped.")
- model_mean, _, model_log_variance, logits = outputs
- elif return_x0:
- model_mean, _, model_log_variance, x0 = outputs
- else:
- model_mean, _, model_log_variance = outputs
-
- noise = noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-
- if return_codebook_ids:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
- if return_x0:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
- else:
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
-
- @torch.no_grad()
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
- log_every_t=None):
- if not log_every_t:
- log_every_t = self.log_every_t
- timesteps = self.num_timesteps
- if batch_size is not None:
- b = batch_size if batch_size is not None else shape[0]
- shape = [batch_size] + list(shape)
- else:
- b = batch_size = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=self.device)
- else:
- img = x_T
- intermediates = []
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
- total=timesteps) if verbose else reversed(
- range(0, timesteps))
- if type(temperature) == float:
- temperature = [temperature] * timesteps
-
- for i in iterator:
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img, x0_partial = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised, return_x0=True,
- temperature=temperature[i], noise_dropout=noise_dropout,
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
- if mask is not None:
- assert x0 is not None
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(x0_partial)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_loop(self, cond, shape, return_intermediates=False,
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, start_T=None,
- log_every_t=None):
-
- if not log_every_t:
- log_every_t = self.log_every_t
- device = self.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- intermediates = [img]
- if timesteps is None:
- timesteps = self.num_timesteps
-
- if start_T is not None:
- timesteps = min(timesteps, start_T)
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
- range(0, timesteps))
-
- if mask is not None:
- assert x0 is not None
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
-
- for i in iterator:
- ts = torch.full((b,), i, device=device, dtype=torch.long)
- if self.shorten_cond_schedule:
- assert self.model.conditioning_key != 'hybrid'
- tc = self.cond_ids[ts].to(cond.device)
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
-
- img = self.p_sample(img, cond, ts,
- clip_denoised=self.clip_denoised,
- quantize_denoised=quantize_denoised)
- if mask is not None:
- img_orig = self.q_sample(x0, ts)
- img = img_orig * mask + (1. - mask) * img
-
- if i % log_every_t == 0 or i == timesteps - 1:
- intermediates.append(img)
- if callback: callback(i)
- if img_callback: img_callback(img, i)
-
- if return_intermediates:
- return img, intermediates
- return img
-
- @torch.no_grad()
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
- verbose=True, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, shape=None,**kwargs):
- if shape is None:
- shape = (batch_size, self.channels, self.image_size, self.image_size)
- if cond is not None:
- if isinstance(cond, dict):
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
- else:
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
- return self.p_sample_loop(cond,
- shape,
- return_intermediates=return_intermediates, x_T=x_T,
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
- mask=mask, x0=x0)
-
- @torch.no_grad()
- def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
-
- if ddim:
- ddim_sampler = DDIMSampler(self)
- shape = (self.channels, self.image_size, self.image_size)
- samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
- shape,cond,verbose=False,**kwargs)
-
- else:
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
- return_intermediates=True,**kwargs)
-
- return samples, intermediates
-
-
- @torch.no_grad()
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
- plot_diffusion_rows=True, **kwargs):
-
- use_ddim = ddim_steps is not None
-
- log = dict()
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
- return_first_stage_outputs=True,
- force_c_encode=True,
- return_original_cond=True,
- bs=N)
- N = min(x.shape[0], N)
- n_row = min(x.shape[0], n_row)
- log["inputs"] = x
- log["reconstruction"] = xrec
- if self.model.conditioning_key is not None:
- if hasattr(self.cond_stage_model, "decode"):
- xc = self.cond_stage_model.decode(c)
- log["conditioning"] = xc
- elif self.cond_stage_key in ["caption"]:
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
- log["conditioning"] = xc
- elif self.cond_stage_key == 'class_label':
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
- log['conditioning'] = xc
- elif isimage(xc):
- log["conditioning"] = xc
- if ismap(xc):
- log["original_conditioning"] = self.to_rgb(xc)
-
- if plot_diffusion_rows:
- # get diffusion row
- diffusion_row = list()
- z_start = z[:n_row]
- for t in range(self.num_timesteps):
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
- t = t.to(self.device).long()
- noise = torch.randn_like(z_start)
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
- diffusion_row.append(self.decode_first_stage(z_noisy))
-
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
- log["diffusion_row"] = diffusion_grid
-
- if sample:
- # get denoise row
- with self.ema_scope("Plotting"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
- x_samples = self.decode_first_stage(samples)
- log["samples"] = x_samples
- if plot_denoise_rows:
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
- log["denoise_row"] = denoise_grid
-
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
- self.first_stage_model, IdentityFirstStage):
- # also display when quantizing x0 while sampling
- with self.ema_scope("Plotting Quantized Denoised"):
- samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
- ddim_steps=ddim_steps,eta=ddim_eta,
- quantize_denoised=True)
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
- # quantize_denoised=True)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_x0_quantized"] = x_samples
-
- if inpaint:
- # make a simple center square
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
- mask = torch.ones(N, h, w).to(self.device)
- # zeros will be filled in
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
- mask = mask[:, None, ...]
- with self.ema_scope("Plotting Inpaint"):
-
- samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_inpainting"] = x_samples
- log["mask"] = mask
-
- # outpaint
- with self.ema_scope("Plotting Outpaint"):
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
- x_samples = self.decode_first_stage(samples.to(self.device))
- log["samples_outpainting"] = x_samples
-
- if plot_progressive_rows:
- with self.ema_scope("Plotting Progressives"):
- img, progressives = self.progressive_denoising(c,
- shape=(self.channels, self.image_size, self.image_size),
- batch_size=N)
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
- log["progressive_row"] = prog_row
-
- if return_keys:
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
- return log
- else:
- return {key: log[key] for key in return_keys}
- return log
-
- def configure_optimizers(self):
- lr = self.learning_rate
- params = list(self.model.parameters())
- if self.cond_stage_trainable:
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
- params = params + list(self.cond_stage_model.parameters())
- if self.learn_logvar:
- print('Diffusion model optimizing logvar')
- params.append(self.logvar)
- from colossalai.nn.optimizer import HybridAdam
- opt = HybridAdam(params, lr=lr)
- # opt = torch.optim.AdamW(params, lr=lr)
- if self.use_scheduler:
- assert 'target' in self.scheduler_config
- scheduler = instantiate_from_config(self.scheduler_config)
-
- rank_zero_info("Setting up LambdaLR scheduler...")
- scheduler = [
- {
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
- 'interval': 'step',
- 'frequency': 1
- }]
- return [opt], scheduler
- return opt
-
- @torch.no_grad()
- def to_rgb(self, x):
- x = x.float()
- if not hasattr(self, "colorize"):
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
- x = nn.functional.conv2d(x, weight=self.colorize)
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
- return x
-
-
-class DiffusionWrapper(pl.LightningModule):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
-
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t)
- elif self.conditioning_key == 'concat':
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t)
- elif self.conditioning_key == 'crossattn':
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc)
- elif self.conditioning_key == 'hybrid':
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc)
- elif self.conditioning_key == 'adm':
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc)
- else:
- raise NotImplementedError()
-
- return out
-
-
-class Layout2ImgDiffusion(LatentDiffusion):
- # TODO: move all layout-specific hacks to this class
- def __init__(self, cond_stage_key, *args, **kwargs):
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
-
- def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
-
- key = 'train' if self.training else 'validation'
- dset = self.trainer.datamodule.datasets[key]
- mapper = dset.conditional_builders[self.cond_stage_key]
-
- bbox_imgs = []
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
- bbox_imgs.append(bboximg)
-
- cond_img = torch.stack(bbox_imgs, dim=0)
- logs['bbox_image'] = cond_img
- return logs
diff --git a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py b/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py
deleted file mode 100644
index 78eeb1003aa4..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/models/diffusion/plms.py
+++ /dev/null
@@ -1,236 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-
-
-class PLMSSampler(object):
- def __init__(self, model, schedule="linear", **kwargs):
- super().__init__()
- self.model = model
- self.ddpm_num_timesteps = model.num_timesteps
- self.schedule = schedule
-
- def register_buffer(self, name, attr):
- if type(attr) == torch.Tensor:
- if attr.device != torch.device("cuda"):
- attr = attr.to(torch.device("cuda"))
- setattr(self, name, attr)
-
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
- if ddim_eta != 0:
- raise ValueError('ddim_eta must be 0 for PLMS')
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
- alphas_cumprod = self.model.alphas_cumprod
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
- self.register_buffer('betas', to_torch(self.model.betas))
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
- # ddim sampling parameters
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
- ddim_timesteps=self.ddim_timesteps,
- eta=ddim_eta,verbose=verbose)
- self.register_buffer('ddim_sigmas', ddim_sigmas)
- self.register_buffer('ddim_alphas', ddim_alphas)
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
- @torch.no_grad()
- def sample(self,
- S,
- batch_size,
- shape,
- conditioning=None,
- callback=None,
- normals_sequence=None,
- img_callback=None,
- quantize_x0=False,
- eta=0.,
- mask=None,
- x0=None,
- temperature=1.,
- noise_dropout=0.,
- score_corrector=None,
- corrector_kwargs=None,
- verbose=True,
- x_T=None,
- log_every_t=100,
- unconditional_guidance_scale=1.,
- unconditional_conditioning=None,
- # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
- **kwargs
- ):
- if conditioning is not None:
- if isinstance(conditioning, dict):
- cbs = conditioning[list(conditioning.keys())[0]].shape[0]
- if cbs != batch_size:
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
- else:
- if conditioning.shape[0] != batch_size:
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
- # sampling
- C, H, W = shape
- size = (batch_size, C, H, W)
- print(f'Data shape for PLMS sampling is {size}')
-
- samples, intermediates = self.plms_sampling(conditioning, size,
- callback=callback,
- img_callback=img_callback,
- quantize_denoised=quantize_x0,
- mask=mask, x0=x0,
- ddim_use_original_steps=False,
- noise_dropout=noise_dropout,
- temperature=temperature,
- score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- x_T=x_T,
- log_every_t=log_every_t,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- )
- return samples, intermediates
-
- @torch.no_grad()
- def plms_sampling(self, cond, shape,
- x_T=None, ddim_use_original_steps=False,
- callback=None, timesteps=None, quantize_denoised=False,
- mask=None, x0=None, img_callback=None, log_every_t=100,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None,):
- device = self.model.betas.device
- b = shape[0]
- if x_T is None:
- img = torch.randn(shape, device=device)
- else:
- img = x_T
-
- if timesteps is None:
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
- elif timesteps is not None and not ddim_use_original_steps:
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
- timesteps = self.ddim_timesteps[:subset_end]
-
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
- time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
- print(f"Running PLMS Sampling with {total_steps} timesteps")
-
- iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
- old_eps = []
-
- for i, step in enumerate(iterator):
- index = total_steps - i - 1
- ts = torch.full((b,), step, device=device, dtype=torch.long)
- ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
-
- if mask is not None:
- assert x0 is not None
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
- img = img_orig * mask + (1. - mask) * img
-
- outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
- quantize_denoised=quantize_denoised, temperature=temperature,
- noise_dropout=noise_dropout, score_corrector=score_corrector,
- corrector_kwargs=corrector_kwargs,
- unconditional_guidance_scale=unconditional_guidance_scale,
- unconditional_conditioning=unconditional_conditioning,
- old_eps=old_eps, t_next=ts_next)
- img, pred_x0, e_t = outs
- old_eps.append(e_t)
- if len(old_eps) >= 4:
- old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
-
- if index % log_every_t == 0 or index == total_steps - 1:
- intermediates['x_inter'].append(img)
- intermediates['pred_x0'].append(pred_x0)
-
- return img, intermediates
-
- @torch.no_grad()
- def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
- unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
- b, *_, device = *x.shape, x.device
-
- def get_model_output(x, t):
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
- e_t = self.model.apply_model(x, t, c)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t] * 2)
- c_in = torch.cat([unconditional_conditioning, c])
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
- if score_corrector is not None:
- assert self.model.parameterization == "eps"
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
- return e_t
-
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
- def get_x_prev_and_pred_x0(e_t, index):
- # select parameters corresponding to the currently considered timestep
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
- # current prediction for x_0
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
- if quantize_denoised:
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
- # direction pointing to x_t
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
- if noise_dropout > 0.:
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
- return x_prev, pred_x0
-
- e_t = get_model_output(x, t)
- if len(old_eps) == 0:
- # Pseudo Improved Euler (2nd order)
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
- e_t_next = get_model_output(x_prev, t_next)
- e_t_prime = (e_t + e_t_next) / 2
- elif len(old_eps) == 1:
- # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (3 * e_t - old_eps[-1]) / 2
- elif len(old_eps) == 2:
- # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
- elif len(old_eps) >= 3:
- # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
- e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
- x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
- return x_prev, pred_x0, e_t
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/attention.py b/examples/tutorial/stable_diffusion/ldm/modules/attention.py
deleted file mode 100644
index 3401ceafddb4..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/attention.py
+++ /dev/null
@@ -1,314 +0,0 @@
-from inspect import isfunction
-import math
-import torch
-import torch.nn.functional as F
-from torch import nn, einsum
-from einops import rearrange, repeat
-
-from torch.utils import checkpoint
-
-try:
- from ldm.modules.flash_attention import flash_attention_qkv, flash_attention_q_kv
- FlASH_AVAILABLE = True
-except:
- FlASH_AVAILABLE = False
-
-USE_FLASH = False
-
-
-def enable_flash_attention():
- global USE_FLASH
- USE_FLASH = True
- if FlASH_AVAILABLE is False:
- print("Please install flash attention to activate new attention kernel.\n" +
- "Use \'pip install git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn\'")
-
-
-def exists(val):
- return val is not None
-
-
-def uniq(arr):
- return{el: True for el in arr}.keys()
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def max_neg_value(t):
- return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
- dim = tensor.shape[-1]
- std = 1 / math.sqrt(dim)
- tensor.uniform_(-std, std)
- return tensor
-
-
-# feedforward
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
- nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim)
-
- self.net = nn.Sequential(
- project_in,
- nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def Normalize(in_channels):
- return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-class LinearAttention(nn.Module):
- def __init__(self, dim, heads=4, dim_head=32):
- super().__init__()
- self.heads = heads
- hidden_dim = dim_head * heads
- self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
- self.to_out = nn.Conv2d(hidden_dim, dim, 1)
-
- def forward(self, x):
- b, c, h, w = x.shape
- qkv = self.to_qkv(x)
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
- k = k.softmax(dim=-1)
- context = torch.einsum('bhdn,bhen->bhde', k, v)
- out = torch.einsum('bhde,bhdn->bhen', context, q)
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
- return self.to_out(out)
-
-
-class SpatialSelfAttention(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b,c,h,w = q.shape
- q = rearrange(q, 'b c h w -> b (h w) c')
- k = rearrange(k, 'b c h w -> b c (h w)')
- w_ = torch.einsum('bij,bjk->bik', q, k)
-
- w_ = w_ * (int(c)**(-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = rearrange(v, 'b c h w -> b c (h w)')
- w_ = rearrange(w_, 'b i j -> b j i')
- h_ = torch.einsum('bij,bjk->bik', v, w_)
- h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-class CrossAttention(nn.Module):
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
- super().__init__()
- inner_dim = dim_head * heads
- context_dim = default(context_dim, query_dim)
-
- self.scale = dim_head ** -0.5
- self.heads = heads
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
-
- self.to_out = nn.Sequential(
- nn.Linear(inner_dim, query_dim),
- nn.Dropout(dropout)
- )
-
- def forward(self, x, context=None, mask=None):
- q = self.to_q(x)
- context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
- dim_head = q.shape[-1] / self.heads
-
- if USE_FLASH and FlASH_AVAILABLE and q.dtype in (torch.float16, torch.bfloat16) and \
- dim_head <= 128 and (dim_head % 8) == 0:
- # print("in flash")
- if q.shape[1] == k.shape[1]:
- out = self._flash_attention_qkv(q, k, v)
- else:
- out = self._flash_attention_q_kv(q, k, v)
- else:
- out = self._native_attention(q, k, v, self.heads, mask)
-
- return self.to_out(out)
-
- def _native_attention(self, q, k, v, h, mask):
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
- sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
- if exists(mask):
- mask = rearrange(mask, 'b ... -> b (...)')
- max_neg_value = -torch.finfo(sim.dtype).max
- mask = repeat(mask, 'b j -> (b h) () j', h=h)
- sim.masked_fill_(~mask, max_neg_value)
- # attention, what we cannot get enough of
- out = sim.softmax(dim=-1)
- out = einsum('b i j, b j d -> b i d', out, v)
- out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
- return out
-
- def _flash_attention_qkv(self, q, k, v):
- qkv = torch.stack([q, k, v], dim=2)
- b = qkv.shape[0]
- n = qkv.shape[1]
- qkv = rearrange(qkv, 'b n t (h d) -> (b n) t h d', h=self.heads)
- out = flash_attention_qkv(qkv, self.scale, b, n)
- out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
- return out
-
- def _flash_attention_q_kv(self, q, k, v):
- kv = torch.stack([k, v], dim=2)
- b = q.shape[0]
- q_seqlen = q.shape[1]
- kv_seqlen = kv.shape[1]
- q = rearrange(q, 'b n (h d) -> (b n) h d', h=self.heads)
- kv = rearrange(kv, 'b n t (h d) -> (b n) t h d', h=self.heads)
- out = flash_attention_q_kv(q, kv, self.scale, b, q_seqlen, kv_seqlen)
- out = rearrange(out, '(b n) h d -> b n (h d)', b=b, h=self.heads)
- return out
-
-
-class BasicTransformerBlock(nn.Module):
- def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, use_checkpoint=False):
- super().__init__()
- self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
- self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
- heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
- self.norm1 = nn.LayerNorm(dim)
- self.norm2 = nn.LayerNorm(dim)
- self.norm3 = nn.LayerNorm(dim)
- self.use_checkpoint = use_checkpoint
-
- def forward(self, x, context=None):
-
-
- if self.use_checkpoint:
- return checkpoint(self._forward, x, context)
- else:
- return self._forward(x, context)
-
- def _forward(self, x, context=None):
- x = self.attn1(self.norm1(x)) + x
- x = self.attn2(self.norm2(x), context=context) + x
- x = self.ff(self.norm3(x)) + x
- return x
-
-
-
-class SpatialTransformer(nn.Module):
- """
- Transformer block for image-like data.
- First, project the input (aka embedding)
- and reshape to b, t, d.
- Then apply standard transformer action.
- Finally, reshape to image
- """
- def __init__(self, in_channels, n_heads, d_head,
- depth=1, dropout=0., context_dim=None, use_checkpoint=False):
- super().__init__()
- self.in_channels = in_channels
- inner_dim = n_heads * d_head
- self.norm = Normalize(in_channels)
-
- self.proj_in = nn.Conv2d(in_channels,
- inner_dim,
- kernel_size=1,
- stride=1,
- padding=0)
-
- self.transformer_blocks = nn.ModuleList(
- [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, use_checkpoint=use_checkpoint)
- for d in range(depth)]
- )
-
- self.proj_out = zero_module(nn.Conv2d(inner_dim,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0))
-
-
- def forward(self, x, context=None):
- # note: if no context is given, cross-attention defaults to self-attention
- b, c, h, w = x.shape
- x_in = x
- x = self.norm(x)
- x = self.proj_in(x)
- x = rearrange(x, 'b c h w -> b (h w) c')
- x = x.contiguous()
- for block in self.transformer_blocks:
- x = block(x, context=context)
- x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
- x = x.contiguous()
- x = self.proj_out(x)
- return x + x_in
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py
deleted file mode 100644
index 3c28492c5502..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/model.py
+++ /dev/null
@@ -1,862 +0,0 @@
-# pytorch_diffusion + derived encoder decoder
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import rearrange
-
-from ldm.util import instantiate_from_config
-from ldm.modules.attention import LinearAttention
-
-
-def get_timestep_embedding(timesteps, embedding_dim):
- """
- This matches the implementation in Denoising Diffusion Probabilistic Models:
- From Fairseq.
- Build sinusoidal embeddings.
- This matches the implementation in tensor2tensor, but differs slightly
- from the description in Section 3.5 of "Attention Is All You Need".
- """
- assert len(timesteps.shape) == 1
-
- half_dim = embedding_dim // 2
- emb = math.log(10000) / (half_dim - 1)
- emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
- emb = emb.to(device=timesteps.device)
- emb = timesteps.float()[:, None] * emb[None, :]
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
- if embedding_dim % 2 == 1: # zero pad
- emb = torch.nn.functional.pad(emb, (0,1,0,0))
- return emb
-
-
-def nonlinearity(x):
- # swish
- return x*torch.sigmoid(x)
-
-
-def Normalize(in_channels, num_groups=32):
- return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-
-
-class Upsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
- if self.with_conv:
- x = self.conv(x)
- return x
-
-
-class Downsample(nn.Module):
- def __init__(self, in_channels, with_conv):
- super().__init__()
- self.with_conv = with_conv
- if self.with_conv:
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=3,
- stride=2,
- padding=0)
-
- def forward(self, x):
- if self.with_conv:
- pad = (0,1,0,1)
- x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
- x = self.conv(x)
- else:
- x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
- return x
-
-
-class ResnetBlock(nn.Module):
- def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
- dropout, temb_channels=512):
- super().__init__()
- self.in_channels = in_channels
- out_channels = in_channels if out_channels is None else out_channels
- self.out_channels = out_channels
- self.use_conv_shortcut = conv_shortcut
-
- self.norm1 = Normalize(in_channels)
- self.conv1 = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- if temb_channels > 0:
- self.temb_proj = torch.nn.Linear(temb_channels,
- out_channels)
- self.norm2 = Normalize(out_channels)
- self.dropout = torch.nn.Dropout(dropout)
- self.conv2 = torch.nn.Conv2d(out_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- self.conv_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- else:
- self.nin_shortcut = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
- def forward(self, x, temb):
- h = x
- h = self.norm1(h)
- h = nonlinearity(h)
- h = self.conv1(h)
-
- if temb is not None:
- h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
-
- h = self.norm2(h)
- h = nonlinearity(h)
- h = self.dropout(h)
- h = self.conv2(h)
-
- if self.in_channels != self.out_channels:
- if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
- else:
- x = self.nin_shortcut(x)
-
- return x+h
-
-
-class LinAttnBlock(LinearAttention):
- """to match AttnBlock usage"""
- def __init__(self, in_channels):
- super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
-
-
-class AttnBlock(nn.Module):
- def __init__(self, in_channels):
- super().__init__()
- self.in_channels = in_channels
-
- self.norm = Normalize(in_channels)
- self.q = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.k = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.v = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.proj_out = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=1,
- stride=1,
- padding=0)
-
-
- def forward(self, x):
- h_ = x
- h_ = self.norm(h_)
- q = self.q(h_)
- k = self.k(h_)
- v = self.v(h_)
-
- # compute attention
- b,c,h,w = q.shape
- q = q.reshape(b,c,h*w)
- q = q.permute(0,2,1) # b,hw,c
- k = k.reshape(b,c,h*w) # b,c,hw
- w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
- w_ = w_ * (int(c)**(-0.5))
- w_ = torch.nn.functional.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b,c,h*w)
- w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
- h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
- h_ = h_.reshape(b,c,h,w)
-
- h_ = self.proj_out(h_)
-
- return x+h_
-
-
-def make_attn(in_channels, attn_type="vanilla"):
- assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
- print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
- if attn_type == "vanilla":
- return AttnBlock(in_channels)
- elif attn_type == "none":
- return nn.Identity(in_channels)
- else:
- return LinAttnBlock(in_channels)
-
-class temb_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Model(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = self.ch*4
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- self.use_timestep = use_timestep
- if self.use_timestep:
- # timestep embedding
- # self.temb = nn.Module()
- self.temb = temb_module()
- self.temb.dense = nn.ModuleList([
- torch.nn.Linear(self.ch,
- self.temb_ch),
- torch.nn.Linear(self.temb_ch,
- self.temb_ch),
- ])
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # down = nn.Module()
- down = Down_module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions-1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- skip_in = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
- if i_block == self.num_res_blocks:
- skip_in = ch*in_ch_mult[i_level]
- block.append(ResnetBlock(in_channels=block_in+skip_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # up = nn.Module()
- up = Up_module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x, t=None, context=None):
- #assert x.shape[2] == x.shape[3] == self.resolution
- if context is not None:
- # assume aligned context, cat along channel axis
- x = torch.cat((x, context), dim=1)
- if self.use_timestep:
- # timestep embedding
- assert t is not None
- temb = get_timestep_embedding(t, self.ch)
- temb = self.temb.dense[0](temb)
- temb = nonlinearity(temb)
- temb = self.temb.dense[1](temb)
- else:
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions-1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](
- torch.cat([h, hs.pop()], dim=1), temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
- def get_last_layer(self):
- return self.conv_out.weight
-
-class Down_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Up_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-class Mid_module(nn.Module):
- def __init__(self):
- super().__init__()
- pass
-
-
-class Encoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
- **ignore_kwargs):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
-
- # downsampling
- self.conv_in = torch.nn.Conv2d(in_channels,
- self.ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- curr_res = resolution
- in_ch_mult = (1,)+tuple(ch_mult)
- self.in_ch_mult = in_ch_mult
- self.down = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_in = ch*in_ch_mult[i_level]
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # down = nn.Module()
- down = Down_module()
- down.block = block
- down.attn = attn
- if i_level != self.num_resolutions-1:
- down.downsample = Downsample(block_in, resamp_with_conv)
- curr_res = curr_res // 2
- self.down.append(down)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- 2*z_channels if double_z else z_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- # timestep embedding
- temb = None
-
- # downsampling
- hs = [self.conv_in(x)]
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1], temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- hs.append(h)
- if i_level != self.num_resolutions-1:
- hs.append(self.down[i_level].downsample(hs[-1]))
-
- # middle
- h = hs[-1]
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # end
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class Decoder(nn.Module):
- def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
- attn_type="vanilla", **ignorekwargs):
- super().__init__()
- if use_linear_attn: attn_type = "linear"
- self.ch = ch
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- self.resolution = resolution
- self.in_channels = in_channels
- self.give_pre_end = give_pre_end
- self.tanh_out = tanh_out
-
- # compute in_ch_mult, block_in and curr_res at lowest res
- in_ch_mult = (1,)+tuple(ch_mult)
- block_in = ch*ch_mult[self.num_resolutions-1]
- curr_res = resolution // 2**(self.num_resolutions-1)
- self.z_shape = (1,z_channels,curr_res,curr_res)
- print("Working with z of shape {} = {} dimensions.".format(
- self.z_shape, np.prod(self.z_shape)))
-
- # z to block_in
- self.conv_in = torch.nn.Conv2d(z_channels,
- block_in,
- kernel_size=3,
- stride=1,
- padding=1)
-
- # middle
- # self.mid = nn.Module()
- self.mid = Mid_module()
- self.mid.block_1 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
- self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
- self.mid.block_2 = ResnetBlock(in_channels=block_in,
- out_channels=block_in,
- temb_channels=self.temb_ch,
- dropout=dropout)
-
- # upsampling
- self.up = nn.ModuleList()
- for i_level in reversed(range(self.num_resolutions)):
- block = nn.ModuleList()
- attn = nn.ModuleList()
- block_out = ch*ch_mult[i_level]
- for i_block in range(self.num_res_blocks+1):
- block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- if curr_res in attn_resolutions:
- attn.append(make_attn(block_in, attn_type=attn_type))
- # up = nn.Module()
- up = Up_module()
- up.block = block
- up.attn = attn
- if i_level != 0:
- up.upsample = Upsample(block_in, resamp_with_conv)
- curr_res = curr_res * 2
- self.up.insert(0, up) # prepend to get consistent order
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_ch,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, z):
- #assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
- # timestep embedding
- temb = None
-
- # z to block_in
- h = self.conv_in(z)
-
- # middle
- h = self.mid.block_1(h, temb)
- h = self.mid.attn_1(h)
- h = self.mid.block_2(h, temb)
-
- # upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
-
- # end
- if self.give_pre_end:
- return h
-
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
-
-
-class SimpleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, *args, **kwargs):
- super().__init__()
- self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
- ResnetBlock(in_channels=in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=2 * in_channels,
- out_channels=4 * in_channels,
- temb_channels=0, dropout=0.0),
- ResnetBlock(in_channels=4 * in_channels,
- out_channels=2 * in_channels,
- temb_channels=0, dropout=0.0),
- nn.Conv2d(2*in_channels, in_channels, 1),
- Upsample(in_channels, with_conv=True)])
- # end
- self.norm_out = Normalize(in_channels)
- self.conv_out = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- for i, layer in enumerate(self.model):
- if i in [1,2,3]:
- x = layer(x, None)
- else:
- x = layer(x)
-
- h = self.norm_out(x)
- h = nonlinearity(h)
- x = self.conv_out(h)
- return x
-
-
-class UpsampleDecoder(nn.Module):
- def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
- ch_mult=(2,2), dropout=0.0):
- super().__init__()
- # upsampling
- self.temb_ch = 0
- self.num_resolutions = len(ch_mult)
- self.num_res_blocks = num_res_blocks
- block_in = in_channels
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
- self.res_blocks = nn.ModuleList()
- self.upsample_blocks = nn.ModuleList()
- for i_level in range(self.num_resolutions):
- res_block = []
- block_out = ch * ch_mult[i_level]
- for i_block in range(self.num_res_blocks + 1):
- res_block.append(ResnetBlock(in_channels=block_in,
- out_channels=block_out,
- temb_channels=self.temb_ch,
- dropout=dropout))
- block_in = block_out
- self.res_blocks.append(nn.ModuleList(res_block))
- if i_level != self.num_resolutions - 1:
- self.upsample_blocks.append(Upsample(block_in, True))
- curr_res = curr_res * 2
-
- # end
- self.norm_out = Normalize(block_in)
- self.conv_out = torch.nn.Conv2d(block_in,
- out_channels,
- kernel_size=3,
- stride=1,
- padding=1)
-
- def forward(self, x):
- # upsampling
- h = x
- for k, i_level in enumerate(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks + 1):
- h = self.res_blocks[i_level][i_block](h, None)
- if i_level != self.num_resolutions - 1:
- h = self.upsample_blocks[k](h)
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
- return h
-
-
-class LatentRescaler(nn.Module):
- def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
- super().__init__()
- # residual block, interpolate, residual block
- self.factor = factor
- self.conv_in = nn.Conv2d(in_channels,
- mid_channels,
- kernel_size=3,
- stride=1,
- padding=1)
- self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
- self.attn = AttnBlock(mid_channels)
- self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
- out_channels=mid_channels,
- temb_channels=0,
- dropout=0.0) for _ in range(depth)])
-
- self.conv_out = nn.Conv2d(mid_channels,
- out_channels,
- kernel_size=1,
- )
-
- def forward(self, x):
- x = self.conv_in(x)
- for block in self.res_block1:
- x = block(x, None)
- x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
- x = self.attn(x)
- for block in self.res_block2:
- x = block(x, None)
- x = self.conv_out(x)
- return x
-
-
-class MergedRescaleEncoder(nn.Module):
- def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
- attn_resolutions, dropout=0.0, resamp_with_conv=True,
- ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
- super().__init__()
- intermediate_chn = ch * ch_mult[-1]
- self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
- z_channels=intermediate_chn, double_z=False, resolution=resolution,
- attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
- out_ch=None)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
- mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
-
- def forward(self, x):
- x = self.encoder(x)
- x = self.rescaler(x)
- return x
-
-
-class MergedRescaleDecoder(nn.Module):
- def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
- dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
- super().__init__()
- tmp_chn = z_channels*ch_mult[-1]
- self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
- resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
- ch_mult=ch_mult, resolution=resolution, ch=ch)
- self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
- out_channels=tmp_chn, depth=rescale_module_depth)
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Upsampler(nn.Module):
- def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
- super().__init__()
- assert out_size >= in_size
- num_blocks = int(np.log2(out_size//in_size))+1
- factor_up = 1.+ (out_size % in_size)
- print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
- self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
- out_channels=in_channels)
- self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
- attn_resolutions=[], in_channels=None, ch=in_channels,
- ch_mult=[ch_mult for _ in range(num_blocks)])
-
- def forward(self, x):
- x = self.rescaler(x)
- x = self.decoder(x)
- return x
-
-
-class Resize(nn.Module):
- def __init__(self, in_channels=None, learned=False, mode="bilinear"):
- super().__init__()
- self.with_conv = learned
- self.mode = mode
- if self.with_conv:
- print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
- raise NotImplementedError()
- assert in_channels is not None
- # no asymmetric padding in torch conv, must do it ourselves
- self.conv = torch.nn.Conv2d(in_channels,
- in_channels,
- kernel_size=4,
- stride=2,
- padding=1)
-
- def forward(self, x, scale_factor=1.0):
- if scale_factor==1.0:
- return x
- else:
- x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
- return x
-
-class FirstStagePostProcessor(nn.Module):
-
- def __init__(self, ch_mult:list, in_channels,
- pretrained_model:nn.Module=None,
- reshape=False,
- n_channels=None,
- dropout=0.,
- pretrained_config=None):
- super().__init__()
- if pretrained_config is None:
- assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.pretrained_model = pretrained_model
- else:
- assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
- self.instantiate_pretrained(pretrained_config)
-
- self.do_reshape = reshape
-
- if n_channels is None:
- n_channels = self.pretrained_model.encoder.ch
-
- self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
- self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
- stride=1,padding=1)
-
- blocks = []
- downs = []
- ch_in = n_channels
- for m in ch_mult:
- blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
- ch_in = m * n_channels
- downs.append(Downsample(ch_in, with_conv=False))
-
- self.model = nn.ModuleList(blocks)
- self.downsampler = nn.ModuleList(downs)
-
-
- def instantiate_pretrained(self, config):
- model = instantiate_from_config(config)
- self.pretrained_model = model.eval()
- # self.pretrained_model.train = False
- for param in self.pretrained_model.parameters():
- param.requires_grad = False
-
-
- @torch.no_grad()
- def encode_with_pretrained(self,x):
- c = self.pretrained_model.encode(x)
- if isinstance(c, DiagonalGaussianDistribution):
- c = c.mode()
- return c
-
- def forward(self,x):
- z_fs = self.encode_with_pretrained(x)
- z = self.proj_norm(z_fs)
- z = self.proj(z)
- z = nonlinearity(z)
-
- for submodel, downmodel in zip(self.model,self.downsampler):
- z = submodel(z,temb=None)
- z = downmodel(z)
-
- if self.do_reshape:
- z = rearrange(z,'b c h w -> b (h w) c')
- return z
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
deleted file mode 100644
index 3aedc2205e13..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/openaimodel.py
+++ /dev/null
@@ -1,1152 +0,0 @@
-from abc import abstractmethod
-from functools import partial
-import math
-from typing import Iterable
-
-import numpy as np
-import torch
-import torch as th
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils import checkpoint
-
-from ldm.modules.diffusionmodules.util import (
- conv_nd,
- linear,
- avg_pool_nd,
- zero_module,
- normalization,
- timestep_embedding,
-)
-from ldm.modules.attention import SpatialTransformer
-
-
-# dummy replace
-def convert_module_to_f16(x):
- # for n,p in x.named_parameter():
- # print(f"convert module {n} to_f16")
- # p.data = p.data.half()
- pass
-
-def convert_module_to_f32(x):
- pass
-
-
-## go
-class AttentionPool2d(nn.Module):
- """
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
- """
-
- def __init__(
- self,
- spacial_dim: int,
- embed_dim: int,
- num_heads_channels: int,
- output_dim: int = None,
- ):
- super().__init__()
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
- self.num_heads = embed_dim // num_heads_channels
- self.attention = QKVAttention(self.num_heads)
-
- def forward(self, x):
- b, c, *_spatial = x.shape
- x = x.reshape(b, c, -1) # NC(HW)
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
- x = self.qkv_proj(x)
- x = self.attention(x)
- x = self.c_proj(x)
- return x[:, :, 0]
-
-
-class TimestepBlock(nn.Module):
- """
- Any module where forward() takes timestep embeddings as a second argument.
- """
-
- @abstractmethod
- def forward(self, x, emb):
- """
- Apply the module to `x` given `emb` timestep embeddings.
- """
-
-
-class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
- """
- A sequential module that passes timestep embeddings to the children that
- support it as an extra input.
- """
-
- def forward(self, x, emb, context=None):
- for layer in self:
- if isinstance(layer, TimestepBlock):
- x = layer(x, emb)
- elif isinstance(layer, SpatialTransformer):
- x = layer(x, context)
- else:
- x = layer(x)
- return x
-
-
-class Upsample(nn.Module):
- """
- An upsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- upsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- if use_conv:
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- if self.dims == 3:
- x = F.interpolate(
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
- )
- else:
- x = F.interpolate(x, scale_factor=2, mode="nearest")
- if self.use_conv:
- x = self.conv(x)
- return x
-
-class TransposedUpsample(nn.Module):
- 'Learned 2x upsampling without padding'
- def __init__(self, channels, out_channels=None, ks=5):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
-
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
-
- def forward(self,x):
- return self.up(x)
-
-
-class Downsample(nn.Module):
- """
- A downsampling layer with an optional convolution.
- :param channels: channels in the inputs and outputs.
- :param use_conv: a bool determining if a convolution is applied.
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
- downsampling occurs in the inner-two dimensions.
- """
-
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
- super().__init__()
- self.channels = channels
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.dims = dims
- stride = 2 if dims != 3 else (1, 2, 2)
- if use_conv:
- self.op = conv_nd(
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
- )
- else:
- assert self.channels == self.out_channels
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
-
- def forward(self, x):
- assert x.shape[1] == self.channels
- return self.op(x)
-
-
-class ResBlock(TimestepBlock):
- """
- A residual block that can optionally change the number of channels.
- :param channels: the number of input channels.
- :param emb_channels: the number of timestep embedding channels.
- :param dropout: the rate of dropout.
- :param out_channels: if specified, the number of out channels.
- :param use_conv: if True and out_channels is specified, use a spatial
- convolution instead of a smaller 1x1 convolution to change the
- channels in the skip connection.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param use_checkpoint: if True, use gradient checkpointing on this module.
- :param up: if True, use this block for upsampling.
- :param down: if True, use this block for downsampling.
- """
-
- def __init__(
- self,
- channels,
- emb_channels,
- dropout,
- out_channels=None,
- use_conv=False,
- use_scale_shift_norm=False,
- dims=2,
- use_checkpoint=False,
- up=False,
- down=False,
- ):
- super().__init__()
- self.channels = channels
- self.emb_channels = emb_channels
- self.dropout = dropout
- self.out_channels = out_channels or channels
- self.use_conv = use_conv
- self.use_checkpoint = use_checkpoint
- self.use_scale_shift_norm = use_scale_shift_norm
-
- self.in_layers = nn.Sequential(
- normalization(channels),
- nn.SiLU(),
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
- )
-
- self.updown = up or down
-
- if up:
- self.h_upd = Upsample(channels, False, dims)
- self.x_upd = Upsample(channels, False, dims)
- elif down:
- self.h_upd = Downsample(channels, False, dims)
- self.x_upd = Downsample(channels, False, dims)
- else:
- self.h_upd = self.x_upd = nn.Identity()
-
- self.emb_layers = nn.Sequential(
- nn.SiLU(),
- linear(
- emb_channels,
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
- ),
- )
- self.out_layers = nn.Sequential(
- normalization(self.out_channels),
- nn.SiLU(),
- nn.Dropout(p=dropout),
- zero_module(
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
- ),
- )
-
- if self.out_channels == channels:
- self.skip_connection = nn.Identity()
- elif use_conv:
- self.skip_connection = conv_nd(
- dims, channels, self.out_channels, 3, padding=1
- )
- else:
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
-
- def forward(self, x, emb):
- """
- Apply the block to a Tensor, conditioned on a timestep embedding.
- :param x: an [N x C x ...] Tensor of features.
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
- :return: an [N x C x ...] Tensor of outputs.
- """
- if self.use_checkpoint:
- return checkpoint(self._forward, x, emb)
- else:
- return self._forward(x, emb)
-
-
- def _forward(self, x, emb):
- if self.updown:
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
- h = in_rest(x)
- h = self.h_upd(h)
- x = self.x_upd(x)
- h = in_conv(h)
- else:
- h = self.in_layers(x)
- emb_out = self.emb_layers(emb).type(h.dtype)
- while len(emb_out.shape) < len(h.shape):
- emb_out = emb_out[..., None]
- if self.use_scale_shift_norm:
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
- scale, shift = th.chunk(emb_out, 2, dim=1)
- h = out_norm(h) * (1 + scale) + shift
- h = out_rest(h)
- else:
- h = h + emb_out
- h = self.out_layers(h)
- return self.skip_connection(x) + h
-
-
-class AttentionBlock(nn.Module):
- """
- An attention block that allows spatial positions to attend to each other.
- Originally ported from here, but adapted to the N-d case.
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
- """
-
- def __init__(
- self,
- channels,
- num_heads=1,
- num_head_channels=-1,
- use_checkpoint=False,
- use_new_attention_order=False,
- ):
- super().__init__()
- self.channels = channels
- if num_head_channels == -1:
- self.num_heads = num_heads
- else:
- assert (
- channels % num_head_channels == 0
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
- self.num_heads = channels // num_head_channels
- self.use_checkpoint = use_checkpoint
- self.norm = normalization(channels)
- self.qkv = conv_nd(1, channels, channels * 3, 1)
- if use_new_attention_order:
- # split qkv before split heads
- self.attention = QKVAttention(self.num_heads)
- else:
- # split heads before split qkv
- self.attention = QKVAttentionLegacy(self.num_heads)
-
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
-
- def forward(self, x):
- if self.use_checkpoint:
- return checkpoint(self._forward, x) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
- #return pt_checkpoint(self._forward, x) # pytorch
- else:
- return self._forward(x)
-
- def _forward(self, x):
- b, c, *spatial = x.shape
- x = x.reshape(b, c, -1)
- qkv = self.qkv(self.norm(x))
- h = self.attention(qkv)
- h = self.proj_out(h)
- return (x + h).reshape(b, c, *spatial)
-
-
-def count_flops_attn(model, _x, y):
- """
- A counter for the `thop` package to count the operations in an
- attention operation.
- Meant to be used like:
- macs, params = thop.profile(
- model,
- inputs=(inputs, timestamps),
- custom_ops={QKVAttention: QKVAttention.count_flops},
- )
- """
- b, c, *spatial = y[0].shape
- num_spatial = int(np.prod(spatial))
- # We perform two matmuls with the same number of ops.
- # The first computes the weight matrix, the second computes
- # the combination of the value vectors.
- matmul_ops = 2 * b * (num_spatial ** 2) * c
- model.total_ops += th.DoubleTensor([matmul_ops])
-
-
-class QKVAttentionLegacy(nn.Module):
- """
- A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts", q * scale, k * scale
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v)
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
-
-
-class QKVAttention(nn.Module):
- """
- A module which performs QKV attention and splits in a different order.
- """
-
- def __init__(self, n_heads):
- super().__init__()
- self.n_heads = n_heads
-
- def forward(self, qkv):
- """
- Apply QKV attention.
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
- :return: an [N x (H * C) x T] tensor after attention.
- """
- bs, width, length = qkv.shape
- assert width % (3 * self.n_heads) == 0
- ch = width // (3 * self.n_heads)
- q, k, v = qkv.chunk(3, dim=1)
- scale = 1 / math.sqrt(math.sqrt(ch))
- weight = th.einsum(
- "bct,bcs->bts",
- (q * scale).view(bs * self.n_heads, ch, length),
- (k * scale).view(bs * self.n_heads, ch, length),
- ) # More stable with f16 than dividing afterwards
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
- return a.reshape(bs, -1, length)
-
- @staticmethod
- def count_flops(model, _x, y):
- return count_flops_attn(model, _x, y)
-
-
-class UNetModel(nn.Module):
- """
- The full UNet model with attention and timestep embedding.
- :param in_channels: channels in the input Tensor.
- :param model_channels: base channel count for the model.
- :param out_channels: channels in the output Tensor.
- :param num_res_blocks: number of residual blocks per downsample.
- :param attention_resolutions: a collection of downsample rates at which
- attention will take place. May be a set, list, or tuple.
- For example, if this contains 4, then at 4x downsampling, attention
- will be used.
- :param dropout: the dropout probability.
- :param channel_mult: channel multiplier for each level of the UNet.
- :param conv_resample: if True, use learned convolutions for upsampling and
- downsampling.
- :param dims: determines if the signal is 1D, 2D, or 3D.
- :param num_classes: if specified (as an int), then this model will be
- class-conditional with `num_classes` classes.
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
- :param num_heads: the number of attention heads in each attention layer.
- :param num_heads_channels: if specified, ignore num_heads and instead use
- a fixed channel width per attention head.
- :param num_heads_upsample: works with num_heads to set a different number
- of heads for upsampling. Deprecated.
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
- :param resblock_updown: use residual blocks for up/downsampling.
- :param use_new_attention_order: use a different attention pattern for potentially
- increased efficiency.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- num_classes=None,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=-1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- use_spatial_transformer=False, # custom transformer support
- transformer_depth=1, # custom transformer support
- context_dim=None, # custom transformer support
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
- legacy=True,
- from_pretrained: str=None
- ):
- super().__init__()
- if use_spatial_transformer:
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
-
- if context_dim is not None:
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
- from omegaconf.listconfig import ListConfig
- if type(context_dim) == ListConfig:
- context_dim = list(context_dim)
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- if num_heads == -1:
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
-
- if num_head_channels == -1:
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
-
- self.image_size = image_size
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.num_classes = num_classes
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
- self.predict_codebook_ids = n_embed is not None
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- if self.num_classes is not None:
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, use_checkpoint=use_checkpoint,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
-
- self.output_blocks = nn.ModuleList([])
- for level, mult in list(enumerate(channel_mult))[::-1]:
- for i in range(num_res_blocks + 1):
- ich = input_block_chans.pop()
- layers = [
- ResBlock(
- ch + ich,
- time_embed_dim,
- dropout,
- out_channels=model_channels * mult,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = model_channels * mult
- if ds in attention_resolutions:
- if num_head_channels == -1:
- dim_head = ch // num_heads
- else:
- num_heads = ch // num_head_channels
- dim_head = num_head_channels
- if legacy:
- #num_heads = 1
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads_upsample,
- num_head_channels=dim_head,
- use_new_attention_order=use_new_attention_order,
- ) if not use_spatial_transformer else SpatialTransformer(
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
- )
- )
- if level and i == num_res_blocks:
- out_ch = ch
- layers.append(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- up=True,
- )
- if resblock_updown
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
- )
- ds //= 2
- self.output_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
-
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
- )
- if self.predict_codebook_ids:
- self.id_predictor = nn.Sequential(
- normalization(ch),
- conv_nd(dims, model_channels, n_embed, 1),
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
- )
- # if use_fp16:
- # self.convert_to_fp16()
- from diffusers.modeling_utils import load_state_dict
- if from_pretrained is not None:
- state_dict = load_state_dict(from_pretrained)
- self._load_pretrained_model(state_dict)
-
- def _input_blocks_mapping(self, input_dict):
- res_dict = {}
- for key_, value_ in input_dict.items():
- id_0 = int(key_[13])
- if "resnets" in key_:
- id_1 = int(key_[23])
- target_id = 3 * id_0 + 1 + id_1
- post_fix = key_[25:].replace('time_emb_proj', 'emb_layers.1')\
- .replace('norm1', 'in_layers.0')\
- .replace('norm2', 'out_layers.0')\
- .replace('conv1', 'in_layers.2')\
- .replace('conv2', 'out_layers.3')\
- .replace('conv_shortcut', 'skip_connection')
- res_dict["input_blocks." + str(target_id) + '.0.' + post_fix] = value_
- elif "attentions" in key_:
- id_1 = int(key_[26])
- target_id = 3 * id_0 + 1 + id_1
- post_fix = key_[28:]
- res_dict["input_blocks." + str(target_id) + '.1.' + post_fix] = value_
- elif "downsamplers" in key_:
- post_fix = key_[35:]
- target_id = 3 * (id_0 + 1)
- res_dict["input_blocks." + str(target_id) + '.0.op.' + post_fix] = value_
- return res_dict
-
-
- def _mid_blocks_mapping(self, mid_dict):
- res_dict = {}
- for key_, value_ in mid_dict.items():
- if "resnets" in key_:
- temp_key_ =key_.replace('time_emb_proj', 'emb_layers.1') \
- .replace('norm1', 'in_layers.0') \
- .replace('norm2', 'out_layers.0') \
- .replace('conv1', 'in_layers.2') \
- .replace('conv2', 'out_layers.3') \
- .replace('conv_shortcut', 'skip_connection')\
- .replace('middle_block.resnets.0', 'middle_block.0')\
- .replace('middle_block.resnets.1', 'middle_block.2')
- res_dict[temp_key_] = value_
- elif "attentions" in key_:
- res_dict[key_.replace('attentions.0', '1')] = value_
- return res_dict
-
- def _other_blocks_mapping(self, other_dict):
- res_dict = {}
- for key_, value_ in other_dict.items():
- tmp_key = key_.replace('conv_in', 'input_blocks.0.0')\
- .replace('time_embedding.linear_1', 'time_embed.0')\
- .replace('time_embedding.linear_2', 'time_embed.2')\
- .replace('conv_norm_out', 'out.0')\
- .replace('conv_out', 'out.2')
- res_dict[tmp_key] = value_
- return res_dict
-
-
- def _output_blocks_mapping(self, output_dict):
- res_dict = {}
- for key_, value_ in output_dict.items():
- id_0 = int(key_[14])
- if "resnets" in key_:
- id_1 = int(key_[24])
- target_id = 3 * id_0 + id_1
- post_fix = key_[26:].replace('time_emb_proj', 'emb_layers.1') \
- .replace('norm1', 'in_layers.0') \
- .replace('norm2', 'out_layers.0') \
- .replace('conv1', 'in_layers.2') \
- .replace('conv2', 'out_layers.3') \
- .replace('conv_shortcut', 'skip_connection')
- res_dict["output_blocks." + str(target_id) + '.0.' + post_fix] = value_
- elif "attentions" in key_:
- id_1 = int(key_[27])
- target_id = 3 * id_0 + id_1
- post_fix = key_[29:]
- res_dict["output_blocks." + str(target_id) + '.1.' + post_fix] = value_
- elif "upsamplers" in key_:
- post_fix = key_[34:]
- target_id = 3 * (id_0 + 1) - 1
- mid_str = '.2.conv.' if target_id != 2 else '.1.conv.'
- res_dict["output_blocks." + str(target_id) + mid_str + post_fix] = value_
- return res_dict
-
- def _state_key_mapping(self, state_dict: dict):
- import re
- res_dict = {}
- input_dict = {}
- mid_dict = {}
- output_dict = {}
- other_dict = {}
- for key_, value_ in state_dict.items():
- if "down_blocks" in key_:
- input_dict[key_.replace('down_blocks', 'input_blocks')] = value_
- elif "up_blocks" in key_:
- output_dict[key_.replace('up_blocks', 'output_blocks')] = value_
- elif "mid_block" in key_:
- mid_dict[key_.replace('mid_block', 'middle_block')] = value_
- else:
- other_dict[key_] = value_
-
- input_dict = self._input_blocks_mapping(input_dict)
- output_dict = self._output_blocks_mapping(output_dict)
- mid_dict = self._mid_blocks_mapping(mid_dict)
- other_dict = self._other_blocks_mapping(other_dict)
- # key_list = state_dict.keys()
- # key_str = " ".join(key_list)
-
- # for key_, val_ in state_dict.items():
- # key_ = key_.replace("down_blocks", "input_blocks")\
- # .replace("up_blocks", 'output_blocks')
- # res_dict[key_] = val_
- res_dict.update(input_dict)
- res_dict.update(output_dict)
- res_dict.update(mid_dict)
- res_dict.update(other_dict)
-
- return res_dict
-
- def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
- state_dict = self._state_key_mapping(state_dict)
- model_state_dict = self.state_dict()
- loaded_keys = [k for k in state_dict.keys()]
- expected_keys = list(model_state_dict.keys())
- original_loaded_keys = loaded_keys
- missing_keys = list(set(expected_keys) - set(loaded_keys))
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
- if state_dict is not None:
- # Whole checkpoint
- mismatched_keys = _find_mismatched_keys(
- state_dict,
- model_state_dict,
- original_loaded_keys,
- ignore_mismatched_sizes,
- )
- error_msgs = self._load_state_dict_into_model(state_dict)
- return missing_keys, unexpected_keys, mismatched_keys, error_msgs
-
- def _load_state_dict_into_model(self, state_dict):
- # Convert old format to new format if needed from a PyTorch state_dict
- # copy state_dict so _load_from_state_dict can modify it
- state_dict = state_dict.copy()
- error_msgs = []
-
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
- # so we need to apply the function recursively.
- def load(module: torch.nn.Module, prefix=""):
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
- module._load_from_state_dict(*args)
-
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + ".")
-
- load(self)
-
- return error_msgs
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
- self.output_blocks.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
- self.output_blocks.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :param context: conditioning plugged in via crossattn
- :param y: an [N] Tensor of labels, if class-conditional.
- :return: an [N x C x ...] Tensor of outputs.
- """
- assert (y is not None) == (
- self.num_classes is not None
- ), "must specify y if and only if the model is class-conditional"
- hs = []
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
- emb = self.time_embed(t_emb)
-
- if self.num_classes is not None:
- assert y.shape == (x.shape[0],)
- emb = emb + self.label_emb(y)
-
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb, context)
- hs.append(h)
- h = self.middle_block(h, emb, context)
- for module in self.output_blocks:
- h = th.cat([h, hs.pop()], dim=1)
- h = module(h, emb, context)
- h = h.type(self.dtype)
- if self.predict_codebook_ids:
- return self.id_predictor(h)
- else:
- return self.out(h)
-
-
-class EncoderUNetModel(nn.Module):
- """
- The half UNet model with attention and timestep embedding.
- For usage, see UNet.
- """
-
- def __init__(
- self,
- image_size,
- in_channels,
- model_channels,
- out_channels,
- num_res_blocks,
- attention_resolutions,
- dropout=0,
- channel_mult=(1, 2, 4, 8),
- conv_resample=True,
- dims=2,
- use_checkpoint=False,
- use_fp16=False,
- num_heads=1,
- num_head_channels=-1,
- num_heads_upsample=-1,
- use_scale_shift_norm=False,
- resblock_updown=False,
- use_new_attention_order=False,
- pool="adaptive",
- *args,
- **kwargs
- ):
- super().__init__()
-
- if num_heads_upsample == -1:
- num_heads_upsample = num_heads
-
- self.in_channels = in_channels
- self.model_channels = model_channels
- self.out_channels = out_channels
- self.num_res_blocks = num_res_blocks
- self.attention_resolutions = attention_resolutions
- self.dropout = dropout
- self.channel_mult = channel_mult
- self.conv_resample = conv_resample
- self.use_checkpoint = use_checkpoint
- self.dtype = th.float16 if use_fp16 else th.float32
- self.num_heads = num_heads
- self.num_head_channels = num_head_channels
- self.num_heads_upsample = num_heads_upsample
-
- time_embed_dim = model_channels * 4
- self.time_embed = nn.Sequential(
- linear(model_channels, time_embed_dim),
- nn.SiLU(),
- linear(time_embed_dim, time_embed_dim),
- )
-
- self.input_blocks = nn.ModuleList(
- [
- TimestepEmbedSequential(
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
- )
- ]
- )
- self._feature_size = model_channels
- input_block_chans = [model_channels]
- ch = model_channels
- ds = 1
- for level, mult in enumerate(channel_mult):
- for _ in range(num_res_blocks):
- layers = [
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=mult * model_channels,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- )
- ]
- ch = mult * model_channels
- if ds in attention_resolutions:
- layers.append(
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- )
- )
- self.input_blocks.append(TimestepEmbedSequential(*layers))
- self._feature_size += ch
- input_block_chans.append(ch)
- if level != len(channel_mult) - 1:
- out_ch = ch
- self.input_blocks.append(
- TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- out_channels=out_ch,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- down=True,
- )
- if resblock_updown
- else Downsample(
- ch, conv_resample, dims=dims, out_channels=out_ch
- )
- )
- )
- ch = out_ch
- input_block_chans.append(ch)
- ds *= 2
- self._feature_size += ch
-
- self.middle_block = TimestepEmbedSequential(
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- AttentionBlock(
- ch,
- use_checkpoint=use_checkpoint,
- num_heads=num_heads,
- num_head_channels=num_head_channels,
- use_new_attention_order=use_new_attention_order,
- ),
- ResBlock(
- ch,
- time_embed_dim,
- dropout,
- dims=dims,
- use_checkpoint=use_checkpoint,
- use_scale_shift_norm=use_scale_shift_norm,
- ),
- )
- self._feature_size += ch
- self.pool = pool
- if pool == "adaptive":
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- nn.AdaptiveAvgPool2d((1, 1)),
- zero_module(conv_nd(dims, ch, out_channels, 1)),
- nn.Flatten(),
- )
- elif pool == "attention":
- assert num_head_channels != -1
- self.out = nn.Sequential(
- normalization(ch),
- nn.SiLU(),
- AttentionPool2d(
- (image_size // ds), ch, num_head_channels, out_channels
- ),
- )
- elif pool == "spatial":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- nn.ReLU(),
- nn.Linear(2048, self.out_channels),
- )
- elif pool == "spatial_v2":
- self.out = nn.Sequential(
- nn.Linear(self._feature_size, 2048),
- normalization(2048),
- nn.SiLU(),
- nn.Linear(2048, self.out_channels),
- )
- else:
- raise NotImplementedError(f"Unexpected {pool} pooling")
-
- def convert_to_fp16(self):
- """
- Convert the torso of the model to float16.
- """
- self.input_blocks.apply(convert_module_to_f16)
- self.middle_block.apply(convert_module_to_f16)
-
- def convert_to_fp32(self):
- """
- Convert the torso of the model to float32.
- """
- self.input_blocks.apply(convert_module_to_f32)
- self.middle_block.apply(convert_module_to_f32)
-
- def forward(self, x, timesteps):
- """
- Apply the model to an input batch.
- :param x: an [N x C x ...] Tensor of inputs.
- :param timesteps: a 1-D batch of timesteps.
- :return: an [N x K] Tensor of outputs.
- """
- emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-
- results = []
- h = x.type(self.dtype)
- for module in self.input_blocks:
- h = module(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = self.middle_block(h, emb)
- if self.pool.startswith("spatial"):
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
- h = th.cat(results, axis=-1)
- return self.out(h)
- else:
- h = h.type(self.dtype)
- return self.out(h)
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py b/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py
deleted file mode 100644
index a7db9369c58a..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/diffusionmodules/util.py
+++ /dev/null
@@ -1,276 +0,0 @@
-# adopted from
-# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
-# and
-# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-# and
-# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
-#
-# thanks!
-
-
-import os
-import math
-import torch
-import torch.nn as nn
-import numpy as np
-from einops import repeat
-
-from ldm.util import instantiate_from_config
-
-
-def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
- if schedule == "linear":
- betas = (
- torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
- )
-
- elif schedule == "cosine":
- timesteps = (
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
- )
- alphas = timesteps / (1 + cosine_s) * np.pi / 2
- alphas = torch.cos(alphas).pow(2)
- alphas = alphas / alphas[0]
- betas = 1 - alphas[1:] / alphas[:-1]
- betas = np.clip(betas, a_min=0, a_max=0.999)
-
- elif schedule == "sqrt_linear":
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
- elif schedule == "sqrt":
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
- else:
- raise ValueError(f"schedule '{schedule}' unknown.")
- return betas.numpy()
-
-
-def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
- if ddim_discr_method == 'uniform':
- c = num_ddpm_timesteps // num_ddim_timesteps
- ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
- elif ddim_discr_method == 'quad':
- ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
- else:
- raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
-
- # assert ddim_timesteps.shape[0] == num_ddim_timesteps
- # add one to get the final alpha values right (the ones from first scale to data during sampling)
- steps_out = ddim_timesteps + 1
- if verbose:
- print(f'Selected timesteps for ddim sampler: {steps_out}')
- return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
- # select alphas for computing the variance schedule
- alphas = alphacums[ddim_timesteps]
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
-
- # according the the formula provided in https://arxiv.org/abs/2010.02502
- sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
- if verbose:
- print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
- print(f'For the chosen value of eta, which is {eta}, '
- f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
- return sigmas, alphas, alphas_prev
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
- """
- Create a beta schedule that discretizes the given alpha_t_bar function,
- which defines the cumulative product of (1-beta) over time from t = [0,1].
- :param num_diffusion_timesteps: the number of betas to produce.
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
- produces the cumulative product of (1-beta) up to that
- part of the diffusion process.
- :param max_beta: the maximum beta to use; use values lower than 1 to
- prevent singularities.
- """
- betas = []
- for i in range(num_diffusion_timesteps):
- t1 = i / num_diffusion_timesteps
- t2 = (i + 1) / num_diffusion_timesteps
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
- return np.array(betas)
-
-
-def extract_into_tensor(a, t, x_shape):
- b, *_ = t.shape
- out = a.gather(-1, t)
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
-def checkpoint(func, inputs, params, flag):
- """
- Evaluate a function without caching intermediate activations, allowing for
- reduced memory at the expense of extra compute in the backward pass.
- :param func: the function to evaluate.
- :param inputs: the argument sequence to pass to `func`.
- :param params: a sequence of parameters `func` depends on but does not
- explicitly take as arguments.
- :param flag: if False, disable gradient checkpointing.
- """
- if flag:
- args = tuple(inputs) + tuple(params)
- return CheckpointFunction.apply(func, len(inputs), *args)
- else:
- return func(*inputs)
-
-
-class CheckpointFunction(torch.autograd.Function):
- @staticmethod
- def forward(ctx, run_function, length, *args):
- ctx.run_function = run_function
- ctx.input_tensors = list(args[:length])
- ctx.input_params = list(args[length:])
-
- with torch.no_grad():
- output_tensors = ctx.run_function(*ctx.input_tensors)
- return output_tensors
-
- @staticmethod
- def backward(ctx, *output_grads):
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
- with torch.enable_grad():
- # Fixes a bug where the first op in run_function modifies the
- # Tensor storage in place, which is not allowed for detach()'d
- # Tensors.
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
- output_tensors = ctx.run_function(*shallow_copies)
- input_grads = torch.autograd.grad(
- output_tensors,
- ctx.input_tensors + ctx.input_params,
- output_grads,
- allow_unused=True,
- )
- del ctx.input_tensors
- del ctx.input_params
- del output_tensors
- return (None, None) + input_grads
-
-
-def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False, use_fp16=True):
- """
- Create sinusoidal timestep embeddings.
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an [N x dim] Tensor of positional embeddings.
- """
- if not repeat_only:
- half = dim // 2
- freqs = torch.exp(
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
- ).to(device=timesteps.device)
- args = timesteps[:, None].float() * freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if dim % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- else:
- embedding = repeat(timesteps, 'b -> b d', d=dim)
- if use_fp16:
- return embedding.half()
- else:
- return embedding
-
-
-def zero_module(module):
- """
- Zero out the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().zero_()
- return module
-
-
-def scale_module(module, scale):
- """
- Scale the parameters of a module and return it.
- """
- for p in module.parameters():
- p.detach().mul_(scale)
- return module
-
-
-def mean_flat(tensor):
- """
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels, precision=16):
- """
- Make a standard normalization layer.
- :param channels: number of input channels.
- :return: an nn.Module for normalization.
- """
- if precision == 16:
- return GroupNorm16(16, channels)
- else:
- return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
- def forward(self, x):
- return x * torch.sigmoid(x)
-
-class GroupNorm16(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.half()).type(x.dtype)
-
-class GroupNorm32(nn.GroupNorm):
- def forward(self, x):
- return super().forward(x.float()).type(x.dtype)
-
-def conv_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D convolution module.
- """
- if dims == 1:
- return nn.Conv1d(*args, **kwargs)
- elif dims == 2:
- return nn.Conv2d(*args, **kwargs)
- elif dims == 3:
- return nn.Conv3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-def linear(*args, **kwargs):
- """
- Create a linear module.
- """
- return nn.Linear(*args, **kwargs)
-
-
-def avg_pool_nd(dims, *args, **kwargs):
- """
- Create a 1D, 2D, or 3D average pooling module.
- """
- if dims == 1:
- return nn.AvgPool1d(*args, **kwargs)
- elif dims == 2:
- return nn.AvgPool2d(*args, **kwargs)
- elif dims == 3:
- return nn.AvgPool3d(*args, **kwargs)
- raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
-
- def __init__(self, c_concat_config, c_crossattn_config):
- super().__init__()
- self.concat_conditioner = instantiate_from_config(c_concat_config)
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
-
- def forward(self, c_concat, c_crossattn):
- c_concat = self.concat_conditioner(c_concat)
- c_crossattn = self.crossattn_conditioner(c_crossattn)
- return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
- noise = lambda: torch.randn(shape, device=device)
- return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py b/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py
deleted file mode 100644
index f2b8ef901130..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/distributions/distributions.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import torch
-import numpy as np
-
-
-class AbstractDistribution:
- def sample(self):
- raise NotImplementedError()
-
- def mode(self):
- raise NotImplementedError()
-
-
-class DiracDistribution(AbstractDistribution):
- def __init__(self, value):
- self.value = value
-
- def sample(self):
- return self.value
-
- def mode(self):
- return self.value
-
-
-class DiagonalGaussianDistribution(object):
- def __init__(self, parameters, deterministic=False):
- self.parameters = parameters
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
- self.deterministic = deterministic
- self.std = torch.exp(0.5 * self.logvar)
- self.var = torch.exp(self.logvar)
- if self.deterministic:
- self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
-
- def sample(self):
- x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
- return x
-
- def kl(self, other=None):
- if self.deterministic:
- return torch.Tensor([0.])
- else:
- if other is None:
- return 0.5 * torch.sum(torch.pow(self.mean, 2)
- + self.var - 1.0 - self.logvar,
- dim=[1, 2, 3])
- else:
- return 0.5 * torch.sum(
- torch.pow(self.mean - other.mean, 2) / other.var
- + self.var / other.var - 1.0 - self.logvar + other.logvar,
- dim=[1, 2, 3])
-
- def nll(self, sample, dims=[1,2,3]):
- if self.deterministic:
- return torch.Tensor([0.])
- logtwopi = np.log(2.0 * np.pi)
- return 0.5 * torch.sum(
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
- dim=dims)
-
- def mode(self):
- return self.mean
-
-
-def normal_kl(mean1, logvar1, mean2, logvar2):
- """
- source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
- Compute the KL divergence between two gaussians.
- Shapes are automatically broadcasted, so batches can be compared to
- scalars, among other use cases.
- """
- tensor = None
- for obj in (mean1, logvar1, mean2, logvar2):
- if isinstance(obj, torch.Tensor):
- tensor = obj
- break
- assert tensor is not None, "at least one argument must be a Tensor"
-
- # Force variances to be Tensors. Broadcasting helps convert scalars to
- # Tensors, but it does not work for torch.exp().
- logvar1, logvar2 = [
- x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
- for x in (logvar1, logvar2)
- ]
-
- return 0.5 * (
- -1.0
- + logvar2
- - logvar1
- + torch.exp(logvar1 - logvar2)
- + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
- )
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/ema.py b/examples/tutorial/stable_diffusion/ldm/modules/ema.py
deleted file mode 100644
index c8c75af43565..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/ema.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import torch
-from torch import nn
-
-
-class LitEma(nn.Module):
- def __init__(self, model, decay=0.9999, use_num_upates=True):
- super().__init__()
- if decay < 0.0 or decay > 1.0:
- raise ValueError('Decay must be between 0 and 1')
-
- self.m_name2s_name = {}
- self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
- self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
- else torch.tensor(-1,dtype=torch.int))
-
- for name, p in model.named_parameters():
- if p.requires_grad:
- #remove as '.'-character is not allowed in buffers
- s_name = name.replace('.','')
- self.m_name2s_name.update({name:s_name})
- self.register_buffer(s_name,p.clone().detach().data)
-
- self.collected_params = []
-
- def forward(self,model):
- decay = self.decay
-
- if self.num_updates >= 0:
- self.num_updates += 1
- decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
-
- one_minus_decay = 1.0 - decay
-
- with torch.no_grad():
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
-
- for key in m_param:
- if m_param[key].requires_grad:
- sname = self.m_name2s_name[key]
- shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
- shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
- else:
- assert not key in self.m_name2s_name
-
- def copy_to(self, model):
- m_param = dict(model.named_parameters())
- shadow_params = dict(self.named_buffers())
- for key in m_param:
- if m_param[key].requires_grad:
- m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
- else:
- assert not key in self.m_name2s_name
-
- def store(self, parameters):
- """
- Save the current parameters for restoring later.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored.
- """
- self.collected_params = [param.clone() for param in parameters]
-
- def restore(self, parameters):
- """
- Restore the parameters stored with the `store` method.
- Useful to validate the model with EMA parameters without affecting the
- original optimization process. Store the parameters before the
- `copy_to` method. After validation (or model saving), use this to
- restore the former parameters.
- Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters.
- """
- for c_param, param in zip(self.collected_params, parameters):
- param.data.copy_(c_param.data)
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py b/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py
deleted file mode 100644
index 8cfc01e5ded4..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/encoders/modules.py
+++ /dev/null
@@ -1,264 +0,0 @@
-import types
-
-import torch
-import torch.nn as nn
-from functools import partial
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextConfig
-import kornia
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
-
-from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
-
-
-class AbstractEncoder(nn.Module):
- def __init__(self):
- super().__init__()
-
- def encode(self, *args, **kwargs):
- raise NotImplementedError
-
-
-
-class ClassEmbedder(nn.Module):
- def __init__(self, embed_dim, n_classes=1000, key='class'):
- super().__init__()
- self.key = key
- self.embedding = nn.Embedding(n_classes, embed_dim)
-
- def forward(self, batch, key=None):
- if key is None:
- key = self.key
- # this is for use in crossattn
- c = batch[key][:, None]
- c = self.embedding(c)
- return c
-
-
-class TransformerEmbedder(AbstractEncoder):
- """Some transformer encoder layers"""
- def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
- super().__init__()
- self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer))
-
- def forward(self, tokens):
- tokens = tokens.to(self.device) # meh
- z = self.transformer(tokens, return_embeddings=True)
- return z
-
- def encode(self, x):
- return self(x)
-
-
-class BERTTokenizer(AbstractEncoder):
- """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
- def __init__(self, device="cuda", vq_interface=True, max_length=77):
- super().__init__()
- from transformers import BertTokenizerFast # TODO: add to reuquirements
- self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
- self.device = device
- self.vq_interface = vq_interface
- self.max_length = max_length
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- tokens = batch_encoding["input_ids"].to(self.device)
- return tokens
-
- @torch.no_grad()
- def encode(self, text):
- tokens = self(text)
- if not self.vq_interface:
- return tokens
- return None, None, [None, None, tokens]
-
- def decode(self, text):
- return text
-
-
-class BERTEmbedder(AbstractEncoder):
- """Uses the BERT tokenizr model and add some transformer encoder layers"""
- def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
- super().__init__()
- self.use_tknz_fn = use_tokenizer
- if self.use_tknz_fn:
- self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
- self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer),
- emb_dropout=embedding_dropout)
-
- def forward(self, text):
- if self.use_tknz_fn:
- tokens = self.tknz_fn(text)#.to(self.device)
- else:
- tokens = text
- z = self.transformer(tokens, return_embeddings=True)
- return z
-
- def encode(self, text):
- # output of length 77
- return self(text)
-
-
-class SpatialRescaler(nn.Module):
- def __init__(self,
- n_stages=1,
- method='bilinear',
- multiplier=0.5,
- in_channels=3,
- out_channels=None,
- bias=False):
- super().__init__()
- self.n_stages = n_stages
- assert self.n_stages >= 0
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
- self.multiplier = multiplier
- self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
- self.remap_output = out_channels is not None
- if self.remap_output:
- print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
-
- def forward(self,x):
- for stage in range(self.n_stages):
- x = self.interpolator(x, scale_factor=self.multiplier)
-
-
- if self.remap_output:
- x = self.channel_mapper(x)
- return x
-
- def encode(self, x):
- return self(x)
-
-
-class CLIPTextModelZero(CLIPTextModel):
- config_class = CLIPTextConfig
-
- def __init__(self, config: CLIPTextConfig):
- super().__init__(config)
- self.text_model = CLIPTextTransformerZero(config)
-
-class CLIPTextTransformerZero(CLIPTextTransformer):
- def _build_causal_attention_mask(self, bsz, seq_len):
- # lazily create causal attention mask, with full attention between the vision tokens
- # pytorch uses additive attention mask; fill with -inf
- mask = torch.empty(bsz, seq_len, seq_len)
- mask.fill_(float("-inf"))
- mask.triu_(1) # zero out the lower diagonal
- mask = mask.unsqueeze(1) # expand mask
- return mask.half()
-
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from Hugging Face)"""
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, use_fp16=True):
- super().__init__()
- self.tokenizer = CLIPTokenizer.from_pretrained(version)
-
- if use_fp16:
- self.transformer = CLIPTextModelZero.from_pretrained(version)
- else:
- self.transformer = CLIPTextModel.from_pretrained(version)
-
- # print(self.transformer.modules())
- # print("check model dtyoe: {}, {}".format(self.tokenizer.dtype, self.transformer.dtype))
- self.device = device
- self.max_length = max_length
- self.freeze()
-
- def freeze(self):
- self.transformer = self.transformer.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- # tokens = batch_encoding["input_ids"].to(self.device)
- tokens = batch_encoding["input_ids"].to(self.device)
- # print("token type: {}".format(tokens.dtype))
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- return z
-
- def encode(self, text):
- return self(text)
-
-
-class FrozenCLIPTextEmbedder(nn.Module):
- """
- Uses the CLIP transformer encoder for text.
- """
- def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
- super().__init__()
- self.model, _ = clip.load(version, jit=False, device="cpu")
- self.device = device
- self.max_length = max_length
- self.n_repeat = n_repeat
- self.normalize = normalize
-
- def freeze(self):
- self.model = self.model.eval()
- for param in self.parameters():
- param.requires_grad = False
-
- def forward(self, text):
- tokens = clip.tokenize(text).to(self.device)
- z = self.model.encode_text(tokens)
- if self.normalize:
- z = z / torch.linalg.norm(z, dim=1, keepdim=True)
- return z
-
- def encode(self, text):
- z = self(text)
- if z.ndim==2:
- z = z[:, None, :]
- z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
- return z
-
-
-class FrozenClipImageEmbedder(nn.Module):
- """
- Uses the CLIP image encoder.
- """
- def __init__(
- self,
- model,
- jit=False,
- device='cuda' if torch.cuda.is_available() else 'cpu',
- antialias=False,
- ):
- super().__init__()
- self.model, _ = clip.load(name=model, device=device, jit=jit)
-
- self.antialias = antialias
-
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
-
- def preprocess(self, x):
- # normalize to [0,1]
- x = kornia.geometry.resize(x, (224, 224),
- interpolation='bicubic',align_corners=True,
- antialias=self.antialias)
- x = (x + 1.) / 2.
- # renormalize according to clip
- x = kornia.enhance.normalize(x, self.mean, self.std)
- return x
-
- def forward(self, x):
- # x is assumed to be in range [-1,1]
- return self.model.encode_image(self.preprocess(x))
-
-
-if __name__ == "__main__":
- from ldm.util import count_params
- model = FrozenCLIPEmbedder()
- count_params(model, verbose=True)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py b/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py
deleted file mode 100644
index 2a7a73879857..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/flash_attention.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""
-Fused Attention
-===============
-This is a Triton implementation of the Flash Attention algorithm
-(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf; Triton https://github.com/openai/triton)
-"""
-
-import torch
-try:
- from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func, flash_attn_unpadded_kvpacked_func
-except ImportError:
- raise ImportError('please install flash_attn from https://github.com/HazyResearch/flash-attention')
-
-
-
-def flash_attention_qkv(qkv, sm_scale, batch_size, seq_len):
- """
- Arguments:
- qkv: (batch*seq, 3, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (total, nheads, headdim).
- """
- max_s = seq_len
- cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32,
- device=qkv.device)
- out = flash_attn_unpadded_qkvpacked_func(
- qkv, cu_seqlens, max_s, 0.0,
- softmax_scale=sm_scale, causal=False
- )
- return out
-
-
-def flash_attention_q_kv(q, kv, sm_scale, batch_size, q_seqlen, kv_seqlen):
- """
- Arguments:
- q: (batch*seq, nheads, headdim)
- kv: (batch*seq, 2, nheads, headdim)
- batch_size: int.
- seq_len: int.
- sm_scale: float. The scaling of QK^T before applying softmax.
- Return:
- out: (total, nheads, headdim).
- """
- cu_seqlens_q = torch.arange(0, (batch_size + 1) * q_seqlen, step=q_seqlen, dtype=torch.int32, device=q.device)
- cu_seqlens_k = torch.arange(0, (batch_size + 1) * kv_seqlen, step=kv_seqlen, dtype=torch.int32, device=kv.device)
- out = flash_attn_unpadded_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, q_seqlen, kv_seqlen, 0.0, sm_scale)
- return out
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py
deleted file mode 100644
index 7836cada81f9..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
-from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
deleted file mode 100644
index 32ef56169978..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan.py
+++ /dev/null
@@ -1,730 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
- weight (float): Sharp weight. Default: 1.
- radius (float): Kernel size of Gaussian blur. Default: 50.
- threshold (int):
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(30, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- elif i == 1:
- image = add_blur(image, sf=sf)
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
-
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- example = {"image":image}
- return example
-
-
-# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...
-def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
- """
- This is an extended degradation model by combining
- the degradation models of BSRGAN and Real-ESRGAN
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- use_shuffle: the degradation shuffle
- use_sharp: sharpening the img
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- if use_sharp:
- img = add_sharpening(img)
- hq = img.copy()
-
- if random.random() < shuffle_prob:
- shuffle_order = random.sample(range(13), 13)
- else:
- shuffle_order = list(range(13))
- # local shuffle for noise, JPEG is always the last one
- shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
- shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
-
- poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
-
- for i in shuffle_order:
- if i == 0:
- img = add_blur(img, sf=sf)
- elif i == 1:
- img = add_resize(img, sf=sf)
- elif i == 2:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 3:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 4:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 5:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- elif i == 6:
- img = add_JPEG_noise(img)
- elif i == 7:
- img = add_blur(img, sf=sf)
- elif i == 8:
- img = add_resize(img, sf=sf)
- elif i == 9:
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
- elif i == 10:
- if random.random() < poisson_prob:
- img = add_Poisson_noise(img)
- elif i == 11:
- if random.random() < speckle_prob:
- img = add_speckle_noise(img)
- elif i == 12:
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
- else:
- print('check the shuffle!')
-
- # resize to desired size
- img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
- interpolation=random.choice([1, 2, 3]))
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf, lq_patchsize)
-
- return img, hq
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- print(img)
- img = util.uint2single(img)
- print(img)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_lq = deg_fn(img)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
-
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
deleted file mode 100644
index 9e1f823996bf..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/bsrgan_light.py
+++ /dev/null
@@ -1,650 +0,0 @@
-# -*- coding: utf-8 -*-
-import numpy as np
-import cv2
-import torch
-
-from functools import partial
-import random
-from scipy import ndimage
-import scipy
-import scipy.stats as ss
-from scipy.interpolate import interp2d
-from scipy.linalg import orth
-import albumentations
-
-import ldm.modules.image_degradation.utils_image as util
-
-"""
-# --------------------------------------------
-# Super-Resolution
-# --------------------------------------------
-#
-# Kai Zhang (cskaizhang@gmail.com)
-# https://github.com/cszn
-# From 2019/03--2021/08
-# --------------------------------------------
-"""
-
-
-def modcrop_np(img, sf):
- '''
- Args:
- img: numpy image, WxH or WxHxC
- sf: scale factor
- Return:
- cropped image
- '''
- w, h = img.shape[:2]
- im = np.copy(img)
- return im[:w - w % sf, :h - h % sf, ...]
-
-
-"""
-# --------------------------------------------
-# anisotropic Gaussian kernels
-# --------------------------------------------
-"""
-
-
-def analytic_kernel(k):
- """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
- k_size = k.shape[0]
- # Calculate the big kernels size
- big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
- # Loop over the small kernel to fill the big one
- for r in range(k_size):
- for c in range(k_size):
- big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
- # Crop the edges of the big kernel to ignore very small values and increase run time of SR
- crop = k_size // 2
- cropped_big_k = big_k[crop:-crop, crop:-crop]
- # Normalize to 1
- return cropped_big_k / cropped_big_k.sum()
-
-
-def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
- """ generate an anisotropic Gaussian kernel
- Args:
- ksize : e.g., 15, kernel size
- theta : [0, pi], rotation angle range
- l1 : [0.1,50], scaling of eigenvalues
- l2 : [0.1,l1], scaling of eigenvalues
- If l1 = l2, will get an isotropic Gaussian kernel.
- Returns:
- k : kernel
- """
-
- v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
- V = np.array([[v[0], v[1]], [v[1], -v[0]]])
- D = np.array([[l1, 0], [0, l2]])
- Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
- k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
-
- return k
-
-
-def gm_blur_kernel(mean, cov, size=15):
- center = size / 2.0 + 0.5
- k = np.zeros([size, size])
- for y in range(size):
- for x in range(size):
- cy = y - center + 1
- cx = x - center + 1
- k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
-
- k = k / np.sum(k)
- return k
-
-
-def shift_pixel(x, sf, upper_left=True):
- """shift pixel for super-resolution with different scale factors
- Args:
- x: WxHxC or WxH
- sf: scale factor
- upper_left: shift direction
- """
- h, w = x.shape[:2]
- shift = (sf - 1) * 0.5
- xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
- if upper_left:
- x1 = xv + shift
- y1 = yv + shift
- else:
- x1 = xv - shift
- y1 = yv - shift
-
- x1 = np.clip(x1, 0, w - 1)
- y1 = np.clip(y1, 0, h - 1)
-
- if x.ndim == 2:
- x = interp2d(xv, yv, x)(x1, y1)
- if x.ndim == 3:
- for i in range(x.shape[-1]):
- x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
-
- return x
-
-
-def blur(x, k):
- '''
- x: image, NxcxHxW
- k: kernel, Nx1xhxw
- '''
- n, c = x.shape[:2]
- p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
- x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
- k = k.repeat(1, c, 1, 1)
- k = k.view(-1, 1, k.shape[2], k.shape[3])
- x = x.view(1, -1, x.shape[2], x.shape[3])
- x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
- x = x.view(n, c, x.shape[2], x.shape[3])
-
- return x
-
-
-def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
- """"
- # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
- # Kai Zhang
- # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
- # max_var = 2.5 * sf
- """
- # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
- lambda_1 = min_var + np.random.rand() * (max_var - min_var)
- lambda_2 = min_var + np.random.rand() * (max_var - min_var)
- theta = np.random.rand() * np.pi # random theta
- noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
-
- # Set COV matrix using Lambdas and Theta
- LAMBDA = np.diag([lambda_1, lambda_2])
- Q = np.array([[np.cos(theta), -np.sin(theta)],
- [np.sin(theta), np.cos(theta)]])
- SIGMA = Q @ LAMBDA @ Q.T
- INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
-
- # Set expectation position (shifting kernel for aligned image)
- MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
- MU = MU[None, None, :, None]
-
- # Create meshgrid for Gaussian
- [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
- Z = np.stack([X, Y], 2)[:, :, :, None]
-
- # Calcualte Gaussian for every pixel of the kernel
- ZZ = Z - MU
- ZZ_t = ZZ.transpose(0, 1, 3, 2)
- raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
-
- # shift the kernel so it will be centered
- # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
-
- # Normalize the kernel and return
- # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
- kernel = raw_kernel / np.sum(raw_kernel)
- return kernel
-
-
-def fspecial_gaussian(hsize, sigma):
- hsize = [hsize, hsize]
- siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
- std = sigma
- [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
- arg = -(x * x + y * y) / (2 * std * std)
- h = np.exp(arg)
- h[h < scipy.finfo(float).eps * h.max()] = 0
- sumh = h.sum()
- if sumh != 0:
- h = h / sumh
- return h
-
-
-def fspecial_laplacian(alpha):
- alpha = max([0, min([alpha, 1])])
- h1 = alpha / (alpha + 1)
- h2 = (1 - alpha) / (alpha + 1)
- h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
- h = np.array(h)
- return h
-
-
-def fspecial(filter_type, *args, **kwargs):
- '''
- python code from:
- https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
- '''
- if filter_type == 'gaussian':
- return fspecial_gaussian(*args, **kwargs)
- if filter_type == 'laplacian':
- return fspecial_laplacian(*args, **kwargs)
-
-
-"""
-# --------------------------------------------
-# degradation models
-# --------------------------------------------
-"""
-
-
-def bicubic_degradation(x, sf=3):
- '''
- Args:
- x: HxWxC image, [0, 1]
- sf: down-scale factor
- Return:
- bicubicly downsampled LR image
- '''
- x = util.imresize_np(x, scale=1 / sf)
- return x
-
-
-def srmd_degradation(x, k, sf=3):
- ''' blur + bicubic downsampling
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2018learning,
- title={Learning a single convolutional super-resolution network for multiple degradations},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={3262--3271},
- year={2018}
- }
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
- x = bicubic_degradation(x, sf=sf)
- return x
-
-
-def dpsr_degradation(x, k, sf=3):
- ''' bicubic downsampling + blur
- Args:
- x: HxWxC image, [0, 1]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- Reference:
- @inproceedings{zhang2019deep,
- title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
- author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
- booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
- pages={1671--1681},
- year={2019}
- }
- '''
- x = bicubic_degradation(x, sf=sf)
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- return x
-
-
-def classical_degradation(x, k, sf=3):
- ''' blur + downsampling
- Args:
- x: HxWxC image, [0, 1]/[0, 255]
- k: hxw, double
- sf: down-scale factor
- Return:
- downsampled LR image
- '''
- x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
- # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
- st = 0
- return x[st::sf, st::sf, ...]
-
-
-def add_sharpening(img, weight=0.5, radius=50, threshold=10):
- """USM sharpening. borrowed from real-ESRGAN
- Input image: I; Blurry image: B.
- 1. K = I + weight * (I - B)
- 2. Mask = 1 if abs(I - B) > threshold, else: 0
- 3. Blur mask:
- 4. Out = Mask * K + (1 - Mask) * I
- Args:
- img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
- weight (float): Sharp weight. Default: 1.
- radius (float): Kernel size of Gaussian blur. Default: 50.
- threshold (int):
- """
- if radius % 2 == 0:
- radius += 1
- blur = cv2.GaussianBlur(img, (radius, radius), 0)
- residual = img - blur
- mask = np.abs(residual) * 255 > threshold
- mask = mask.astype('float32')
- soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
-
- K = img + weight * residual
- K = np.clip(K, 0, 1)
- return soft_mask * K + (1 - soft_mask) * img
-
-
-def add_blur(img, sf=4):
- wd2 = 4.0 + sf
- wd = 2.0 + 0.2 * sf
-
- wd2 = wd2/4
- wd = wd/4
-
- if random.random() < 0.5:
- l1 = wd2 * random.random()
- l2 = wd2 * random.random()
- k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
- else:
- k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
- img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
-
- return img
-
-
-def add_resize(img, sf=4):
- rnum = np.random.rand()
- if rnum > 0.8: # up
- sf1 = random.uniform(1, 2)
- elif rnum < 0.7: # down
- sf1 = random.uniform(0.5 / sf, 1)
- else:
- sf1 = 1.0
- img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- return img
-
-
-# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
-# noise_level = random.randint(noise_level1, noise_level2)
-# rnum = np.random.rand()
-# if rnum > 0.6: # add color Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
-# elif rnum < 0.4: # add grayscale Gaussian noise
-# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
-# else: # add noise
-# L = noise_level2 / 255.
-# D = np.diag(np.random.rand(3))
-# U = orth(np.random.rand(3, 3))
-# conv = np.dot(np.dot(np.transpose(U), D), U)
-# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
-# img = np.clip(img, 0.0, 1.0)
-# return img
-
-def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- rnum = np.random.rand()
- if rnum > 0.6: # add color Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4: # add grayscale Gaussian noise
- img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else: # add noise
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_speckle_noise(img, noise_level1=2, noise_level2=25):
- noise_level = random.randint(noise_level1, noise_level2)
- img = np.clip(img, 0.0, 1.0)
- rnum = random.random()
- if rnum > 0.6:
- img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
- elif rnum < 0.4:
- img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
- else:
- L = noise_level2 / 255.
- D = np.diag(np.random.rand(3))
- U = orth(np.random.rand(3, 3))
- conv = np.dot(np.dot(np.transpose(U), D), U)
- img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_Poisson_noise(img):
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
- if random.random() < 0.5:
- img = np.random.poisson(img * vals).astype(np.float32) / vals
- else:
- img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
- img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
- noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
- img += noise_gray[:, :, np.newaxis]
- img = np.clip(img, 0.0, 1.0)
- return img
-
-
-def add_JPEG_noise(img):
- quality_factor = random.randint(80, 95)
- img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
- result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
- img = cv2.imdecode(encimg, 1)
- img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
- return img
-
-
-def random_crop(lq, hq, sf=4, lq_patchsize=64):
- h, w = lq.shape[:2]
- rnd_h = random.randint(0, h - lq_patchsize)
- rnd_w = random.randint(0, w - lq_patchsize)
- lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
-
- rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
- hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
- return lq, hq
-
-
-def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = img.shape[:2]
- img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = img.shape[:2]
-
- if h < lq_patchsize * sf or w < lq_patchsize * sf:
- raise ValueError(f'img size ({h1}X{w1}) is too small!')
-
- hq = img.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- img = util.imresize_np(img, 1 / 2, True)
- img = np.clip(img, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- img = add_blur(img, sf=sf)
-
- elif i == 1:
- img = add_blur(img, sf=sf)
-
- elif i == 2:
- a, b = img.shape[1], img.shape[0]
- # downsample2
- if random.random() < 0.75:
- sf1 = random.uniform(1, 2 * sf)
- img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
- img = img[0::sf, 0::sf, ...] # nearest downsampling
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- img = np.clip(img, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- img = add_JPEG_noise(img)
-
- elif i == 6:
- # add processed camera sensor noise
- if random.random() < isp_prob and isp_model is not None:
- with torch.no_grad():
- img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- img = add_JPEG_noise(img)
-
- # random crop
- img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
-
- return img, hq
-
-
-# todo no isp_model?
-def degradation_bsrgan_variant(image, sf=4, isp_model=None):
- """
- This is the degradation model of BSRGAN from the paper
- "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
- ----------
- sf: scale factor
- isp_model: camera ISP model
- Returns
- -------
- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
- hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
- """
- image = util.uint2single(image)
- isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
- sf_ori = sf
-
- h1, w1 = image.shape[:2]
- image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
- h, w = image.shape[:2]
-
- hq = image.copy()
-
- if sf == 4 and random.random() < scale2_prob: # downsample1
- if np.random.rand() < 0.5:
- image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- image = util.imresize_np(image, 1 / 2, True)
- image = np.clip(image, 0.0, 1.0)
- sf = 2
-
- shuffle_order = random.sample(range(7), 7)
- idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
- if idx1 > idx2: # keep downsample3 last
- shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
-
- for i in shuffle_order:
-
- if i == 0:
- image = add_blur(image, sf=sf)
-
- # elif i == 1:
- # image = add_blur(image, sf=sf)
-
- if i == 0:
- pass
-
- elif i == 2:
- a, b = image.shape[1], image.shape[0]
- # downsample2
- if random.random() < 0.8:
- sf1 = random.uniform(1, 2 * sf)
- image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
- interpolation=random.choice([1, 2, 3]))
- else:
- k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
- k_shifted = shift_pixel(k, sf)
- k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
- image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
- image = image[0::sf, 0::sf, ...] # nearest downsampling
-
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 3:
- # downsample3
- image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
- image = np.clip(image, 0.0, 1.0)
-
- elif i == 4:
- # add Gaussian noise
- image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
-
- elif i == 5:
- # add JPEG noise
- if random.random() < jpeg_prob:
- image = add_JPEG_noise(image)
- #
- # elif i == 6:
- # # add processed camera sensor noise
- # if random.random() < isp_prob and isp_model is not None:
- # with torch.no_grad():
- # img, hq = isp_model.forward(img.copy(), hq)
-
- # add final JPEG compression noise
- image = add_JPEG_noise(image)
- image = util.single2uint(image)
- example = {"image": image}
- return example
-
-
-
-
-if __name__ == '__main__':
- print("hey")
- img = util.imread_uint('utils/test.png', 3)
- img = img[:448, :448]
- h = img.shape[0] // 4
- print("resizing to", h)
- sf = 4
- deg_fn = partial(degradation_bsrgan_variant, sf=sf)
- for i in range(20):
- print(i)
- img_hq = img
- img_lq = deg_fn(img)["image"]
- img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
- print(img_lq)
- img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
- print(img_lq.shape)
- print("bicubic", img_lq_bicubic.shape)
- print(img_hq.shape)
- lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
- (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
- interpolation=0)
- img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
- util.imsave(img_concat, str(i) + '.png')
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png
deleted file mode 100644
index 4249b43de0f2..000000000000
Binary files a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils/test.png and /dev/null differ
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py b/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py
deleted file mode 100644
index 0175f155ad90..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/image_degradation/utils_image.py
+++ /dev/null
@@ -1,916 +0,0 @@
-import os
-import math
-import random
-import numpy as np
-import torch
-import cv2
-from torchvision.utils import make_grid
-from datetime import datetime
-#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
-
-
-os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
-
-
-'''
-# --------------------------------------------
-# Kai Zhang (github: https://github.com/cszn)
-# 03/Mar/2019
-# --------------------------------------------
-# https://github.com/twhui/SRGAN-pyTorch
-# https://github.com/xinntao/BasicSR
-# --------------------------------------------
-'''
-
-
-IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
-
-
-def is_image_file(filename):
- return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
-
-
-def get_timestamp():
- return datetime.now().strftime('%y%m%d-%H%M%S')
-
-
-def imshow(x, title=None, cbar=False, figsize=None):
- plt.figure(figsize=figsize)
- plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
- if title:
- plt.title(title)
- if cbar:
- plt.colorbar()
- plt.show()
-
-
-def surf(Z, cmap='rainbow', figsize=None):
- plt.figure(figsize=figsize)
- ax3 = plt.axes(projection='3d')
-
- w, h = Z.shape[:2]
- xx = np.arange(0,w,1)
- yy = np.arange(0,h,1)
- X, Y = np.meshgrid(xx, yy)
- ax3.plot_surface(X,Y,Z,cmap=cmap)
- #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
- plt.show()
-
-
-'''
-# --------------------------------------------
-# get image pathes
-# --------------------------------------------
-'''
-
-
-def get_image_paths(dataroot):
- paths = None # return None if dataroot is None
- if dataroot is not None:
- paths = sorted(_get_paths_from_images(dataroot))
- return paths
-
-
-def _get_paths_from_images(path):
- assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
- images = []
- for dirpath, _, fnames in sorted(os.walk(path)):
- for fname in sorted(fnames):
- if is_image_file(fname):
- img_path = os.path.join(dirpath, fname)
- images.append(img_path)
- assert images, '{:s} has no valid image file'.format(path)
- return images
-
-
-'''
-# --------------------------------------------
-# split large images into small images
-# --------------------------------------------
-'''
-
-
-def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
- w, h = img.shape[:2]
- patches = []
- if w > p_max and h > p_max:
- w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))
- h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))
- w1.append(w-p_size)
- h1.append(h-p_size)
-# print(w1)
-# print(h1)
- for i in w1:
- for j in h1:
- patches.append(img[i:i+p_size, j:j+p_size,:])
- else:
- patches.append(img)
-
- return patches
-
-
-def imssave(imgs, img_path):
- """
- imgs: list, N images of size WxHxC
- """
- img_name, ext = os.path.splitext(os.path.basename(img_path))
-
- for i, img in enumerate(imgs):
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
- cv2.imwrite(new_path, img)
-
-
-def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
- """
- split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
- and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max)
- will be splitted.
- Args:
- original_dataroot:
- taget_dataroot:
- p_size: size of small images
- p_overlap: patch size in training is a good choice
- p_max: images with smaller size than (p_max)x(p_max) keep unchanged.
- """
- paths = get_image_paths(original_dataroot)
- for img_path in paths:
- # img_name, ext = os.path.splitext(os.path.basename(img_path))
- img = imread_uint(img_path, n_channels=n_channels)
- patches = patches_from_image(img, p_size, p_overlap, p_max)
- imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
- #if original_dataroot == taget_dataroot:
- #del img_path
-
-'''
-# --------------------------------------------
-# makedir
-# --------------------------------------------
-'''
-
-
-def mkdir(path):
- if not os.path.exists(path):
- os.makedirs(path)
-
-
-def mkdirs(paths):
- if isinstance(paths, str):
- mkdir(paths)
- else:
- for path in paths:
- mkdir(path)
-
-
-def mkdir_and_rename(path):
- if os.path.exists(path):
- new_name = path + '_archived_' + get_timestamp()
- print('Path already exists. Rename it to [{:s}]'.format(new_name))
- os.rename(path, new_name)
- os.makedirs(path)
-
-
-'''
-# --------------------------------------------
-# read image from path
-# opencv is fast, but read BGR numpy image
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# get uint8 image of size HxWxn_channles (RGB)
-# --------------------------------------------
-def imread_uint(path, n_channels=3):
- # input: path
- # output: HxWx3(RGB or GGG), or HxWx1 (G)
- if n_channels == 1:
- img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
- img = np.expand_dims(img, axis=2) # HxWx1
- elif n_channels == 3:
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
- if img.ndim == 2:
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
- else:
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
- return img
-
-
-# --------------------------------------------
-# matlab's imwrite
-# --------------------------------------------
-def imsave(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-def imwrite(img, img_path):
- img = np.squeeze(img)
- if img.ndim == 3:
- img = img[:, :, [2, 1, 0]]
- cv2.imwrite(img_path, img)
-
-
-
-# --------------------------------------------
-# get single image of size HxWxn_channles (BGR)
-# --------------------------------------------
-def read_img(path):
- # read image by cv2
- # return: Numpy float32, HWC, BGR, [0,1]
- img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
- img = img.astype(np.float32) / 255.
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- # some images have 4 channels
- if img.shape[2] > 3:
- img = img[:, :, :3]
- return img
-
-
-'''
-# --------------------------------------------
-# image format conversion
-# --------------------------------------------
-# numpy(single) <---> numpy(unit)
-# numpy(single) <---> tensor
-# numpy(unit) <---> tensor
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# numpy(single) [0, 1] <---> numpy(unit)
-# --------------------------------------------
-
-
-def uint2single(img):
-
- return np.float32(img/255.)
-
-
-def single2uint(img):
-
- return np.uint8((img.clip(0, 1)*255.).round())
-
-
-def uint162single(img):
-
- return np.float32(img/65535.)
-
-
-def single2uint16(img):
-
- return np.uint16((img.clip(0, 1)*65535.).round())
-
-
-# --------------------------------------------
-# numpy(unit) (HxWxC or HxW) <---> tensor
-# --------------------------------------------
-
-
-# convert uint to 4-dimensional torch tensor
-def uint2tensor4(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
-
-
-# convert uint to 3-dimensional torch tensor
-def uint2tensor3(img):
- if img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
-
-
-# convert 2/3/4-dimensional torch tensor to uint
-def tensor2uint(img):
- img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- return np.uint8((img*255.0).round())
-
-
-# --------------------------------------------
-# numpy(single) (HxWxC) <---> tensor
-# --------------------------------------------
-
-
-# convert single (HxWxC) to 3-dimensional torch tensor
-def single2tensor3(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
-
-
-# convert single (HxWxC) to 4-dimensional torch tensor
-def single2tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
-
-
-# convert torch tensor to single
-def tensor2single(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
-
- return img
-
-# convert torch tensor to single
-def tensor2single3(img):
- img = img.data.squeeze().float().cpu().numpy()
- if img.ndim == 3:
- img = np.transpose(img, (1, 2, 0))
- elif img.ndim == 2:
- img = np.expand_dims(img, axis=2)
- return img
-
-
-def single2tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
-
-
-def single32tensor5(img):
- return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
-
-
-def single42tensor4(img):
- return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
-
-
-# from skimage.io import imread, imsave
-def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
- '''
- Converts a torch Tensor into an image Numpy array of BGR channel order
- Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
- Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
- '''
- tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
- tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
- n_dim = tensor.dim()
- if n_dim == 4:
- n_img = len(tensor)
- img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 3:
- img_np = tensor.numpy()
- img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
- elif n_dim == 2:
- img_np = tensor.numpy()
- else:
- raise TypeError(
- 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
- if out_type == np.uint8:
- img_np = (img_np * 255.0).round()
- # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
- return img_np.astype(out_type)
-
-
-'''
-# --------------------------------------------
-# Augmentation, flipe and/or rotate
-# --------------------------------------------
-# The following two are enough.
-# (1) augmet_img: numpy image of WxHxC or WxH
-# (2) augment_img_tensor4: tensor image 1xCxWxH
-# --------------------------------------------
-'''
-
-
-def augment_img(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return np.flipud(np.rot90(img))
- elif mode == 2:
- return np.flipud(img)
- elif mode == 3:
- return np.rot90(img, k=3)
- elif mode == 4:
- return np.flipud(np.rot90(img, k=2))
- elif mode == 5:
- return np.rot90(img)
- elif mode == 6:
- return np.rot90(img, k=2)
- elif mode == 7:
- return np.flipud(np.rot90(img, k=3))
-
-
-def augment_img_tensor4(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- if mode == 0:
- return img
- elif mode == 1:
- return img.rot90(1, [2, 3]).flip([2])
- elif mode == 2:
- return img.flip([2])
- elif mode == 3:
- return img.rot90(3, [2, 3])
- elif mode == 4:
- return img.rot90(2, [2, 3]).flip([2])
- elif mode == 5:
- return img.rot90(1, [2, 3])
- elif mode == 6:
- return img.rot90(2, [2, 3])
- elif mode == 7:
- return img.rot90(3, [2, 3]).flip([2])
-
-
-def augment_img_tensor(img, mode=0):
- '''Kai Zhang (github: https://github.com/cszn)
- '''
- img_size = img.size()
- img_np = img.data.cpu().numpy()
- if len(img_size) == 3:
- img_np = np.transpose(img_np, (1, 2, 0))
- elif len(img_size) == 4:
- img_np = np.transpose(img_np, (2, 3, 1, 0))
- img_np = augment_img(img_np, mode=mode)
- img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
- if len(img_size) == 3:
- img_tensor = img_tensor.permute(2, 0, 1)
- elif len(img_size) == 4:
- img_tensor = img_tensor.permute(3, 2, 0, 1)
-
- return img_tensor.type_as(img)
-
-
-def augment_img_np3(img, mode=0):
- if mode == 0:
- return img
- elif mode == 1:
- return img.transpose(1, 0, 2)
- elif mode == 2:
- return img[::-1, :, :]
- elif mode == 3:
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 4:
- return img[:, ::-1, :]
- elif mode == 5:
- img = img[:, ::-1, :]
- img = img.transpose(1, 0, 2)
- return img
- elif mode == 6:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- return img
- elif mode == 7:
- img = img[:, ::-1, :]
- img = img[::-1, :, :]
- img = img.transpose(1, 0, 2)
- return img
-
-
-def augment_imgs(img_list, hflip=True, rot=True):
- # horizontal flip OR rotate
- hflip = hflip and random.random() < 0.5
- vflip = rot and random.random() < 0.5
- rot90 = rot and random.random() < 0.5
-
- def _augment(img):
- if hflip:
- img = img[:, ::-1, :]
- if vflip:
- img = img[::-1, :, :]
- if rot90:
- img = img.transpose(1, 0, 2)
- return img
-
- return [_augment(img) for img in img_list]
-
-
-'''
-# --------------------------------------------
-# modcrop and shave
-# --------------------------------------------
-'''
-
-
-def modcrop(img_in, scale):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- if img.ndim == 2:
- H, W = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r]
- elif img.ndim == 3:
- H, W, C = img.shape
- H_r, W_r = H % scale, W % scale
- img = img[:H - H_r, :W - W_r, :]
- else:
- raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
- return img
-
-
-def shave(img_in, border=0):
- # img_in: Numpy, HWC or HW
- img = np.copy(img_in)
- h, w = img.shape[:2]
- img = img[border:h-border, border:w-border]
- return img
-
-
-'''
-# --------------------------------------------
-# image processing process on numpy image
-# channel_convert(in_c, tar_type, img_list):
-# rgb2ycbcr(img, only_y=True):
-# bgr2ycbcr(img, only_y=True):
-# ycbcr2rgb(img):
-# --------------------------------------------
-'''
-
-
-def rgb2ycbcr(img, only_y=True):
- '''same as matlab rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
- [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def ycbcr2rgb(img):
- '''same as matlab ycbcr2rgb
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def bgr2ycbcr(img, only_y=True):
- '''bgr version of rgb2ycbcr
- only_y: only return Y channel
- Input:
- uint8, [0, 255]
- float, [0, 1]
- '''
- in_img_type = img.dtype
- img.astype(np.float32)
- if in_img_type != np.uint8:
- img *= 255.
- # convert
- if only_y:
- rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
- else:
- rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
- [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
- if in_img_type == np.uint8:
- rlt = rlt.round()
- else:
- rlt /= 255.
- return rlt.astype(in_img_type)
-
-
-def channel_convert(in_c, tar_type, img_list):
- # conversion among BGR, gray and y
- if in_c == 3 and tar_type == 'gray': # BGR to gray
- gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in gray_list]
- elif in_c == 3 and tar_type == 'y': # BGR to y
- y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
- return [np.expand_dims(img, axis=2) for img in y_list]
- elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
- return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
- else:
- return img_list
-
-
-'''
-# --------------------------------------------
-# metric, PSNR and SSIM
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# PSNR
-# --------------------------------------------
-def calculate_psnr(img1, img2, border=0):
- # img1 and img2 have range [0, 255]
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- mse = np.mean((img1 - img2)**2)
- if mse == 0:
- return float('inf')
- return 20 * math.log10(255.0 / math.sqrt(mse))
-
-
-# --------------------------------------------
-# SSIM
-# --------------------------------------------
-def calculate_ssim(img1, img2, border=0):
- '''calculate SSIM
- the same outputs as MATLAB's
- img1, img2: [0, 255]
- '''
- #img1 = img1.squeeze()
- #img2 = img2.squeeze()
- if not img1.shape == img2.shape:
- raise ValueError('Input images must have the same dimensions.')
- h, w = img1.shape[:2]
- img1 = img1[border:h-border, border:w-border]
- img2 = img2[border:h-border, border:w-border]
-
- if img1.ndim == 2:
- return ssim(img1, img2)
- elif img1.ndim == 3:
- if img1.shape[2] == 3:
- ssims = []
- for i in range(3):
- ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
- return np.array(ssims).mean()
- elif img1.shape[2] == 1:
- return ssim(np.squeeze(img1), np.squeeze(img2))
- else:
- raise ValueError('Wrong input image dimensions.')
-
-
-def ssim(img1, img2):
- C1 = (0.01 * 255)**2
- C2 = (0.03 * 255)**2
-
- img1 = img1.astype(np.float64)
- img2 = img2.astype(np.float64)
- kernel = cv2.getGaussianKernel(11, 1.5)
- window = np.outer(kernel, kernel.transpose())
-
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
- mu1_sq = mu1**2
- mu2_sq = mu2**2
- mu1_mu2 = mu1 * mu2
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
-
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
- (sigma1_sq + sigma2_sq + C2))
- return ssim_map.mean()
-
-
-'''
-# --------------------------------------------
-# matlab's bicubic imresize (numpy and torch) [0, 1]
-# --------------------------------------------
-'''
-
-
-# matlab 'imresize' function, now only support 'bicubic'
-def cubic(x):
- absx = torch.abs(x)
- absx2 = absx**2
- absx3 = absx**3
- return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
- (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
-
-
-def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
- if (scale < 1) and (antialiasing):
- # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
- kernel_width = kernel_width / scale
-
- # Output-space coordinates
- x = torch.linspace(1, out_length, out_length)
-
- # Input-space coordinates. Calculate the inverse mapping such that 0.5
- # in output space maps to 0.5 in input space, and 0.5+scale in output
- # space maps to 1.5 in input space.
- u = x / scale + 0.5 * (1 - 1 / scale)
-
- # What is the left-most pixel that can be involved in the computation?
- left = torch.floor(u - kernel_width / 2)
-
- # What is the maximum number of pixels that can be involved in the
- # computation? Note: it's OK to use an extra pixel here; if the
- # corresponding weights are all zero, it will be eliminated at the end
- # of this function.
- P = math.ceil(kernel_width) + 2
-
- # The indices of the input pixels involved in computing the k-th output
- # pixel are in row k of the indices matrix.
- indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
- 1, P).expand(out_length, P)
-
- # The weights used to compute the k-th output pixel are in row k of the
- # weights matrix.
- distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
- # apply cubic kernel
- if (scale < 1) and (antialiasing):
- weights = scale * cubic(distance_to_center * scale)
- else:
- weights = cubic(distance_to_center)
- # Normalize the weights matrix so that each row sums to 1.
- weights_sum = torch.sum(weights, 1).view(out_length, 1)
- weights = weights / weights_sum.expand(out_length, P)
-
- # If a column in weights is all zero, get rid of it. only consider the first and last column.
- weights_zero_tmp = torch.sum((weights == 0), 0)
- if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 1, P - 2)
- weights = weights.narrow(1, 1, P - 2)
- if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
- indices = indices.narrow(1, 0, P - 2)
- weights = weights.narrow(1, 0, P - 2)
- weights = weights.contiguous()
- indices = indices.contiguous()
- sym_len_s = -indices.min() + 1
- sym_len_e = indices.max() - in_length
- indices = indices + sym_len_s - 1
- return weights, indices, int(sym_len_s), int(sym_len_e)
-
-
-# --------------------------------------------
-# imresize for tensor image [0, 1]
-# --------------------------------------------
-def imresize(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: pytorch tensor, CHW or HW [0,1]
- # output: CHW or HW [0,1] w/o round
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(0)
- in_C, in_H, in_W = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
- img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:, :sym_len_Hs, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[:, -sym_len_He:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(in_C, out_H, in_W)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
- out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :, :sym_len_Ws]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, :, -sym_len_We:]
- inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(2, inv_idx)
- out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(in_C, out_H, out_W)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
- return out_2
-
-
-# --------------------------------------------
-# imresize for numpy image [0, 1]
-# --------------------------------------------
-def imresize_np(img, scale, antialiasing=True):
- # Now the scale should be the same for H and W
- # input: img: Numpy, HWC or HW [0,1]
- # output: HWC or HW [0,1] w/o round
- img = torch.from_numpy(img)
- need_squeeze = True if img.dim() == 2 else False
- if need_squeeze:
- img.unsqueeze_(2)
-
- in_H, in_W, in_C = img.size()
- out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
- kernel_width = 4
- kernel = 'cubic'
-
- # Return the desired dimension order for performing the resize. The
- # strategy is to perform the resize first along the dimension with the
- # smallest scale factor.
- # Now we do not support this.
-
- # get weights and indices
- weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
- in_H, out_H, scale, kernel, kernel_width, antialiasing)
- weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
- in_W, out_W, scale, kernel, kernel_width, antialiasing)
- # process H dimension
- # symmetric copying
- img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
- img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
-
- sym_patch = img[:sym_len_Hs, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
-
- sym_patch = img[-sym_len_He:, :, :]
- inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(0, inv_idx)
- img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
-
- out_1 = torch.FloatTensor(out_H, in_W, in_C)
- kernel_width = weights_H.size(1)
- for i in range(out_H):
- idx = int(indices_H[i][0])
- for j in range(out_C):
- out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
-
- # process W dimension
- # symmetric copying
- out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
- out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
-
- sym_patch = out_1[:, :sym_len_Ws, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
-
- sym_patch = out_1[:, -sym_len_We:, :]
- inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
- sym_patch_inv = sym_patch.index_select(1, inv_idx)
- out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
-
- out_2 = torch.FloatTensor(out_H, out_W, in_C)
- kernel_width = weights_W.size(1)
- for i in range(out_W):
- idx = int(indices_W[i][0])
- for j in range(out_C):
- out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
- if need_squeeze:
- out_2.squeeze_()
-
- return out_2.numpy()
-
-
-if __name__ == '__main__':
- print('---')
-# img = imread_uint('test.bmp', 3)
-# img = uint2single(img)
-# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py
deleted file mode 100644
index 876d7c5bd6e3..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/losses/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py
deleted file mode 100644
index 672c1e32a138..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/losses/contperceptual.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import torch
-import torch.nn as nn
-
-from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
-
-
-class LPIPSWithDiscriminator(nn.Module):
- def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
- disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
- perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
- disc_loss="hinge"):
-
- super().__init__()
- assert disc_loss in ["hinge", "vanilla"]
- self.kl_weight = kl_weight
- self.pixel_weight = pixelloss_weight
- self.perceptual_loss = LPIPS().eval()
- self.perceptual_weight = perceptual_weight
- # output log variance
- self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
-
- self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
- n_layers=disc_num_layers,
- use_actnorm=use_actnorm
- ).apply(weights_init)
- self.discriminator_iter_start = disc_start
- self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
- self.disc_factor = disc_factor
- self.discriminator_weight = disc_weight
- self.disc_conditional = disc_conditional
-
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
- if last_layer is not None:
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
- else:
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
-
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
- d_weight = d_weight * self.discriminator_weight
- return d_weight
-
- def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
- global_step, last_layer=None, cond=None, split="train",
- weights=None):
- rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
- if self.perceptual_weight > 0:
- p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
- rec_loss = rec_loss + self.perceptual_weight * p_loss
-
- nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
- weighted_nll_loss = nll_loss
- if weights is not None:
- weighted_nll_loss = weights*nll_loss
- weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
- nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
- kl_loss = posteriors.kl()
- kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
-
- # now the GAN part
- if optimizer_idx == 0:
- # generator update
- if cond is None:
- assert not self.disc_conditional
- logits_fake = self.discriminator(reconstructions.contiguous())
- else:
- assert self.disc_conditional
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
- g_loss = -torch.mean(logits_fake)
-
- if self.disc_factor > 0.0:
- try:
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
- except RuntimeError:
- assert not self.training
- d_weight = torch.tensor(0.0)
- else:
- d_weight = torch.tensor(0.0)
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
-
- log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
- "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
- "{}/d_weight".format(split): d_weight.detach(),
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
- "{}/g_loss".format(split): g_loss.detach().mean(),
- }
- return loss, log
-
- if optimizer_idx == 1:
- # second pass for discriminator update
- if cond is None:
- logits_real = self.discriminator(inputs.contiguous().detach())
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
- else:
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
-
- log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
- "{}/logits_real".format(split): logits_real.detach().mean(),
- "{}/logits_fake".format(split): logits_fake.detach().mean()
- }
- return d_loss, log
-
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py b/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py
deleted file mode 100644
index f69981769e4b..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/losses/vqperceptual.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import torch
-from torch import nn
-import torch.nn.functional as F
-from einops import repeat
-
-from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
-from taming.modules.losses.lpips import LPIPS
-from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
-
-
-def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
- assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
- loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
- loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
- loss_real = (weights * loss_real).sum() / weights.sum()
- loss_fake = (weights * loss_fake).sum() / weights.sum()
- d_loss = 0.5 * (loss_real + loss_fake)
- return d_loss
-
-def adopt_weight(weight, global_step, threshold=0, value=0.):
- if global_step < threshold:
- weight = value
- return weight
-
-
-def measure_perplexity(predicted_indices, n_embed):
- # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
- # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
- encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
- avg_probs = encodings.mean(0)
- perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
- cluster_use = torch.sum(avg_probs > 0)
- return perplexity, cluster_use
-
-def l1(x, y):
- return torch.abs(x-y)
-
-
-def l2(x, y):
- return torch.pow((x-y), 2)
-
-
-class VQLPIPSWithDiscriminator(nn.Module):
- def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
- disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
- perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
- disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
- pixel_loss="l1"):
- super().__init__()
- assert disc_loss in ["hinge", "vanilla"]
- assert perceptual_loss in ["lpips", "clips", "dists"]
- assert pixel_loss in ["l1", "l2"]
- self.codebook_weight = codebook_weight
- self.pixel_weight = pixelloss_weight
- if perceptual_loss == "lpips":
- print(f"{self.__class__.__name__}: Running with LPIPS.")
- self.perceptual_loss = LPIPS().eval()
- else:
- raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
- self.perceptual_weight = perceptual_weight
-
- if pixel_loss == "l1":
- self.pixel_loss = l1
- else:
- self.pixel_loss = l2
-
- self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
- n_layers=disc_num_layers,
- use_actnorm=use_actnorm,
- ndf=disc_ndf
- ).apply(weights_init)
- self.discriminator_iter_start = disc_start
- if disc_loss == "hinge":
- self.disc_loss = hinge_d_loss
- elif disc_loss == "vanilla":
- self.disc_loss = vanilla_d_loss
- else:
- raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
- print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
- self.disc_factor = disc_factor
- self.discriminator_weight = disc_weight
- self.disc_conditional = disc_conditional
- self.n_classes = n_classes
-
- def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
- if last_layer is not None:
- nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
- else:
- nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
- g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
-
- d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
- d_weight = d_weight * self.discriminator_weight
- return d_weight
-
- def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
- global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
- if not exists(codebook_loss):
- codebook_loss = torch.tensor([0.]).to(inputs.device)
- #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
- rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
- if self.perceptual_weight > 0:
- p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
- rec_loss = rec_loss + self.perceptual_weight * p_loss
- else:
- p_loss = torch.tensor([0.0])
-
- nll_loss = rec_loss
- #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
- nll_loss = torch.mean(nll_loss)
-
- # now the GAN part
- if optimizer_idx == 0:
- # generator update
- if cond is None:
- assert not self.disc_conditional
- logits_fake = self.discriminator(reconstructions.contiguous())
- else:
- assert self.disc_conditional
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
- g_loss = -torch.mean(logits_fake)
-
- try:
- d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
- except RuntimeError:
- assert not self.training
- d_weight = torch.tensor(0.0)
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
-
- log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
- "{}/quant_loss".format(split): codebook_loss.detach().mean(),
- "{}/nll_loss".format(split): nll_loss.detach().mean(),
- "{}/rec_loss".format(split): rec_loss.detach().mean(),
- "{}/p_loss".format(split): p_loss.detach().mean(),
- "{}/d_weight".format(split): d_weight.detach(),
- "{}/disc_factor".format(split): torch.tensor(disc_factor),
- "{}/g_loss".format(split): g_loss.detach().mean(),
- }
- if predicted_indices is not None:
- assert self.n_classes is not None
- with torch.no_grad():
- perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
- log[f"{split}/perplexity"] = perplexity
- log[f"{split}/cluster_usage"] = cluster_usage
- return loss, log
-
- if optimizer_idx == 1:
- # second pass for discriminator update
- if cond is None:
- logits_real = self.discriminator(inputs.contiguous().detach())
- logits_fake = self.discriminator(reconstructions.contiguous().detach())
- else:
- logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
- logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
-
- disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
- d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
-
- log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
- "{}/logits_real".format(split): logits_real.detach().mean(),
- "{}/logits_fake".format(split): logits_fake.detach().mean()
- }
- return d_loss, log
diff --git a/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py b/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py
deleted file mode 100644
index 5fc15bf9cfe0..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/modules/x_transformer.py
+++ /dev/null
@@ -1,641 +0,0 @@
-"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
-import torch
-from torch import nn, einsum
-import torch.nn.functional as F
-from functools import partial
-from inspect import isfunction
-from collections import namedtuple
-from einops import rearrange, repeat, reduce
-
-# constants
-
-DEFAULT_DIM_HEAD = 64
-
-Intermediates = namedtuple('Intermediates', [
- 'pre_softmax_attn',
- 'post_softmax_attn'
-])
-
-LayerIntermediates = namedtuple('Intermediates', [
- 'hiddens',
- 'attn_intermediates'
-])
-
-
-class AbsolutePositionalEmbedding(nn.Module):
- def __init__(self, dim, max_seq_len):
- super().__init__()
- self.emb = nn.Embedding(max_seq_len, dim)
- self.init_()
-
- def init_(self):
- nn.init.normal_(self.emb.weight, std=0.02)
-
- def forward(self, x):
- n = torch.arange(x.shape[1], device=x.device)
- return self.emb(n)[None, :, :]
-
-
-class FixedPositionalEmbedding(nn.Module):
- def __init__(self, dim):
- super().__init__()
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
- self.register_buffer('inv_freq', inv_freq)
-
- def forward(self, x, seq_dim=1, offset=0):
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
- sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
- emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
- return emb[None, :, :]
-
-
-# helpers
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def always(val):
- def inner(*args, **kwargs):
- return val
- return inner
-
-
-def not_equals(val):
- def inner(x):
- return x != val
- return inner
-
-
-def equals(val):
- def inner(x):
- return x == val
- return inner
-
-
-def max_neg_value(tensor):
- return -torch.finfo(tensor.dtype).max
-
-
-# keyword argument helpers
-
-def pick_and_pop(keys, d):
- values = list(map(lambda key: d.pop(key), keys))
- return dict(zip(keys, values))
-
-
-def group_dict_by_key(cond, d):
- return_val = [dict(), dict()]
- for key in d.keys():
- match = bool(cond(key))
- ind = int(not match)
- return_val[ind][key] = d[key]
- return (*return_val,)
-
-
-def string_begins_with(prefix, str):
- return str.startswith(prefix)
-
-
-def group_by_key_prefix(prefix, d):
- return group_dict_by_key(partial(string_begins_with, prefix), d)
-
-
-def groupby_prefix_and_trim(prefix, d):
- kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
- kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
- return kwargs_without_prefix, kwargs
-
-
-# classes
-class Scale(nn.Module):
- def __init__(self, value, fn):
- super().__init__()
- self.value = value
- self.fn = fn
-
- def forward(self, x, **kwargs):
- x, *rest = self.fn(x, **kwargs)
- return (x * self.value, *rest)
-
-
-class Rezero(nn.Module):
- def __init__(self, fn):
- super().__init__()
- self.fn = fn
- self.g = nn.Parameter(torch.zeros(1))
-
- def forward(self, x, **kwargs):
- x, *rest = self.fn(x, **kwargs)
- return (x * self.g, *rest)
-
-
-class ScaleNorm(nn.Module):
- def __init__(self, dim, eps=1e-5):
- super().__init__()
- self.scale = dim ** -0.5
- self.eps = eps
- self.g = nn.Parameter(torch.ones(1))
-
- def forward(self, x):
- norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
- return x / norm.clamp(min=self.eps) * self.g
-
-
-class RMSNorm(nn.Module):
- def __init__(self, dim, eps=1e-8):
- super().__init__()
- self.scale = dim ** -0.5
- self.eps = eps
- self.g = nn.Parameter(torch.ones(dim))
-
- def forward(self, x):
- norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
- return x / norm.clamp(min=self.eps) * self.g
-
-
-class Residual(nn.Module):
- def forward(self, x, residual):
- return x + residual
-
-
-class GRUGating(nn.Module):
- def __init__(self, dim):
- super().__init__()
- self.gru = nn.GRUCell(dim, dim)
-
- def forward(self, x, residual):
- gated_output = self.gru(
- rearrange(x, 'b n d -> (b n) d'),
- rearrange(residual, 'b n d -> (b n) d')
- )
-
- return gated_output.reshape_as(x)
-
-
-# feedforward
-
-class GEGLU(nn.Module):
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.proj = nn.Linear(dim_in, dim_out * 2)
-
- def forward(self, x):
- x, gate = self.proj(x).chunk(2, dim=-1)
- return x * F.gelu(gate)
-
-
-class FeedForward(nn.Module):
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
- super().__init__()
- inner_dim = int(dim * mult)
- dim_out = default(dim_out, dim)
- project_in = nn.Sequential(
- nn.Linear(dim, inner_dim),
- nn.GELU()
- ) if not glu else GEGLU(dim, inner_dim)
-
- self.net = nn.Sequential(
- project_in,
- nn.Dropout(dropout),
- nn.Linear(inner_dim, dim_out)
- )
-
- def forward(self, x):
- return self.net(x)
-
-
-# attention.
-class Attention(nn.Module):
- def __init__(
- self,
- dim,
- dim_head=DEFAULT_DIM_HEAD,
- heads=8,
- causal=False,
- mask=None,
- talking_heads=False,
- sparse_topk=None,
- use_entmax15=False,
- num_mem_kv=0,
- dropout=0.,
- on_attn=False
- ):
- super().__init__()
- if use_entmax15:
- raise NotImplementedError("Check out entmax activation instead of softmax activation!")
- self.scale = dim_head ** -0.5
- self.heads = heads
- self.causal = causal
- self.mask = mask
-
- inner_dim = dim_head * heads
-
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
- self.to_k = nn.Linear(dim, inner_dim, bias=False)
- self.to_v = nn.Linear(dim, inner_dim, bias=False)
- self.dropout = nn.Dropout(dropout)
-
- # talking heads
- self.talking_heads = talking_heads
- if talking_heads:
- self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
- self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
-
- # explicit topk sparse attention
- self.sparse_topk = sparse_topk
-
- # entmax
- #self.attn_fn = entmax15 if use_entmax15 else F.softmax
- self.attn_fn = F.softmax
-
- # add memory key / values
- self.num_mem_kv = num_mem_kv
- if num_mem_kv > 0:
- self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
- self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
-
- # attention on attention
- self.attn_on_attn = on_attn
- self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
-
- def forward(
- self,
- x,
- context=None,
- mask=None,
- context_mask=None,
- rel_pos=None,
- sinusoidal_emb=None,
- prev_attn=None,
- mem=None
- ):
- b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
- kv_input = default(context, x)
-
- q_input = x
- k_input = kv_input
- v_input = kv_input
-
- if exists(mem):
- k_input = torch.cat((mem, k_input), dim=-2)
- v_input = torch.cat((mem, v_input), dim=-2)
-
- if exists(sinusoidal_emb):
- # in shortformer, the query would start at a position offset depending on the past cached memory
- offset = k_input.shape[-2] - q_input.shape[-2]
- q_input = q_input + sinusoidal_emb(q_input, offset=offset)
- k_input = k_input + sinusoidal_emb(k_input)
-
- q = self.to_q(q_input)
- k = self.to_k(k_input)
- v = self.to_v(v_input)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- input_mask = None
- if any(map(exists, (mask, context_mask))):
- q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
- k_mask = q_mask if not exists(context) else context_mask
- k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
- q_mask = rearrange(q_mask, 'b i -> b () i ()')
- k_mask = rearrange(k_mask, 'b j -> b () () j')
- input_mask = q_mask * k_mask
-
- if self.num_mem_kv > 0:
- mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
- k = torch.cat((mem_k, k), dim=-2)
- v = torch.cat((mem_v, v), dim=-2)
- if exists(input_mask):
- input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
-
- dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
- mask_value = max_neg_value(dots)
-
- if exists(prev_attn):
- dots = dots + prev_attn
-
- pre_softmax_attn = dots
-
- if talking_heads:
- dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
-
- if exists(rel_pos):
- dots = rel_pos(dots)
-
- if exists(input_mask):
- dots.masked_fill_(~input_mask, mask_value)
- del input_mask
-
- if self.causal:
- i, j = dots.shape[-2:]
- r = torch.arange(i, device=device)
- mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
- mask = F.pad(mask, (j - i, 0), value=False)
- dots.masked_fill_(mask, mask_value)
- del mask
-
- if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
- top, _ = dots.topk(self.sparse_topk, dim=-1)
- vk = top[..., -1].unsqueeze(-1).expand_as(dots)
- mask = dots < vk
- dots.masked_fill_(mask, mask_value)
- del mask
-
- attn = self.attn_fn(dots, dim=-1)
- post_softmax_attn = attn
-
- attn = self.dropout(attn)
-
- if talking_heads:
- attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
-
- out = einsum('b h i j, b h j d -> b h i d', attn, v)
- out = rearrange(out, 'b h n d -> b n (h d)')
-
- intermediates = Intermediates(
- pre_softmax_attn=pre_softmax_attn,
- post_softmax_attn=post_softmax_attn
- )
-
- return self.to_out(out), intermediates
-
-
-class AttentionLayers(nn.Module):
- def __init__(
- self,
- dim,
- depth,
- heads=8,
- causal=False,
- cross_attend=False,
- only_cross=False,
- use_scalenorm=False,
- use_rmsnorm=False,
- use_rezero=False,
- rel_pos_num_buckets=32,
- rel_pos_max_distance=128,
- position_infused_attn=False,
- custom_layers=None,
- sandwich_coef=None,
- par_ratio=None,
- residual_attn=False,
- cross_residual_attn=False,
- macaron=False,
- pre_norm=True,
- gate_residual=False,
- **kwargs
- ):
- super().__init__()
- ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
- attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
-
- dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
-
- self.dim = dim
- self.depth = depth
- self.layers = nn.ModuleList([])
-
- self.has_pos_emb = position_infused_attn
- self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
- self.rotary_pos_emb = always(None)
-
- assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
- self.rel_pos = None
-
- self.pre_norm = pre_norm
-
- self.residual_attn = residual_attn
- self.cross_residual_attn = cross_residual_attn
-
- norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
- norm_class = RMSNorm if use_rmsnorm else norm_class
- norm_fn = partial(norm_class, dim)
-
- norm_fn = nn.Identity if use_rezero else norm_fn
- branch_fn = Rezero if use_rezero else None
-
- if cross_attend and not only_cross:
- default_block = ('a', 'c', 'f')
- elif cross_attend and only_cross:
- default_block = ('c', 'f')
- else:
- default_block = ('a', 'f')
-
- if macaron:
- default_block = ('f',) + default_block
-
- if exists(custom_layers):
- layer_types = custom_layers
- elif exists(par_ratio):
- par_depth = depth * len(default_block)
- assert 1 < par_ratio <= par_depth, 'par ratio out of range'
- default_block = tuple(filter(not_equals('f'), default_block))
- par_attn = par_depth // par_ratio
- depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
- par_width = (depth_cut + depth_cut // par_attn) // par_attn
- assert len(default_block) <= par_width, 'default block is too large for par_ratio'
- par_block = default_block + ('f',) * (par_width - len(default_block))
- par_head = par_block * par_attn
- layer_types = par_head + ('f',) * (par_depth - len(par_head))
- elif exists(sandwich_coef):
- assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
- layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
- else:
- layer_types = default_block * depth
-
- self.layer_types = layer_types
- self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
-
- for layer_type in self.layer_types:
- if layer_type == 'a':
- layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
- elif layer_type == 'c':
- layer = Attention(dim, heads=heads, **attn_kwargs)
- elif layer_type == 'f':
- layer = FeedForward(dim, **ff_kwargs)
- layer = layer if not macaron else Scale(0.5, layer)
- else:
- raise Exception(f'invalid layer type {layer_type}')
-
- if isinstance(layer, Attention) and exists(branch_fn):
- layer = branch_fn(layer)
-
- if gate_residual:
- residual_fn = GRUGating(dim)
- else:
- residual_fn = Residual()
-
- self.layers.append(nn.ModuleList([
- norm_fn(),
- layer,
- residual_fn
- ]))
-
- def forward(
- self,
- x,
- context=None,
- mask=None,
- context_mask=None,
- mems=None,
- return_hiddens=False
- ):
- hiddens = []
- intermediates = []
- prev_attn = None
- prev_cross_attn = None
-
- mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
-
- for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
- is_last = ind == (len(self.layers) - 1)
-
- if layer_type == 'a':
- hiddens.append(x)
- layer_mem = mems.pop(0)
-
- residual = x
-
- if self.pre_norm:
- x = norm(x)
-
- if layer_type == 'a':
- out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
- prev_attn=prev_attn, mem=layer_mem)
- elif layer_type == 'c':
- out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
- elif layer_type == 'f':
- out = block(x)
-
- x = residual_fn(out, residual)
-
- if layer_type in ('a', 'c'):
- intermediates.append(inter)
-
- if layer_type == 'a' and self.residual_attn:
- prev_attn = inter.pre_softmax_attn
- elif layer_type == 'c' and self.cross_residual_attn:
- prev_cross_attn = inter.pre_softmax_attn
-
- if not self.pre_norm and not is_last:
- x = norm(x)
-
- if return_hiddens:
- intermediates = LayerIntermediates(
- hiddens=hiddens,
- attn_intermediates=intermediates
- )
-
- return x, intermediates
-
- return x
-
-
-class Encoder(AttentionLayers):
- def __init__(self, **kwargs):
- assert 'causal' not in kwargs, 'cannot set causality on encoder'
- super().__init__(causal=False, **kwargs)
-
-
-
-class TransformerWrapper(nn.Module):
- def __init__(
- self,
- *,
- num_tokens,
- max_seq_len,
- attn_layers,
- emb_dim=None,
- max_mem_len=0.,
- emb_dropout=0.,
- num_memory_tokens=None,
- tie_embedding=False,
- use_pos_emb=True
- ):
- super().__init__()
- assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
-
- dim = attn_layers.dim
- emb_dim = default(emb_dim, dim)
-
- self.max_seq_len = max_seq_len
- self.max_mem_len = max_mem_len
- self.num_tokens = num_tokens
-
- self.token_emb = nn.Embedding(num_tokens, emb_dim)
- self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
- use_pos_emb and not attn_layers.has_pos_emb) else always(0)
- self.emb_dropout = nn.Dropout(emb_dropout)
-
- self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
- self.attn_layers = attn_layers
- self.norm = nn.LayerNorm(dim)
-
- self.init_()
-
- self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
-
- # memory tokens (like [cls]) from Memory Transformers paper
- num_memory_tokens = default(num_memory_tokens, 0)
- self.num_memory_tokens = num_memory_tokens
- if num_memory_tokens > 0:
- self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
-
- # let funnel encoder know number of memory tokens, if specified
- if hasattr(attn_layers, 'num_memory_tokens'):
- attn_layers.num_memory_tokens = num_memory_tokens
-
- def init_(self):
- nn.init.normal_(self.token_emb.weight, std=0.02)
-
- def forward(
- self,
- x,
- return_embeddings=False,
- mask=None,
- return_mems=False,
- return_attn=False,
- mems=None,
- **kwargs
- ):
- b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
- x = self.token_emb(x)
- x += self.pos_emb(x)
- x = self.emb_dropout(x)
-
- x = self.project_emb(x)
-
- if num_mem > 0:
- mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
- x = torch.cat((mem, x), dim=1)
-
- # auto-handle masking after appending memory tokens
- if exists(mask):
- mask = F.pad(mask, (num_mem, 0), value=True)
-
- x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
- x = self.norm(x)
-
- mem, x = x[:, :num_mem], x[:, num_mem:]
-
- out = self.to_logits(x) if not return_embeddings else x
-
- if return_mems:
- hiddens = intermediates.hiddens
- new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
- new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
- return out, new_mems
-
- if return_attn:
- attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
- return out, attn_maps
-
- return out
-
diff --git a/examples/tutorial/stable_diffusion/ldm/util.py b/examples/tutorial/stable_diffusion/ldm/util.py
deleted file mode 100644
index 8ba38853e7a0..000000000000
--- a/examples/tutorial/stable_diffusion/ldm/util.py
+++ /dev/null
@@ -1,203 +0,0 @@
-import importlib
-
-import torch
-import numpy as np
-from collections import abc
-from einops import rearrange
-from functools import partial
-
-import multiprocessing as mp
-from threading import Thread
-from queue import Queue
-
-from inspect import isfunction
-from PIL import Image, ImageDraw, ImageFont
-
-
-def log_txt_as_img(wh, xc, size=10):
- # wh a tuple of (width, height)
- # xc a list of captions to plot
- b = len(xc)
- txts = list()
- for bi in range(b):
- txt = Image.new("RGB", wh, color="white")
- draw = ImageDraw.Draw(txt)
- font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
- nc = int(40 * (wh[0] / 256))
- lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
-
- try:
- draw.text((0, 0), lines, fill="black", font=font)
- except UnicodeEncodeError:
- print("Cant encode string for logging. Skipping.")
-
- txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
- txts.append(txt)
- txts = np.stack(txts)
- txts = torch.tensor(txts)
- return txts
-
-
-def ismap(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] > 3)
-
-
-def isimage(x):
- if not isinstance(x, torch.Tensor):
- return False
- return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
-
-
-def exists(x):
- return x is not None
-
-
-def default(val, d):
- if exists(val):
- return val
- return d() if isfunction(d) else d
-
-
-def mean_flat(tensor):
- """
- https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
- Take the mean over all non-batch dimensions.
- """
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def count_params(model, verbose=False):
- total_params = sum(p.numel() for p in model.parameters())
- if verbose:
- print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
- return total_params
-
-
-def instantiate_from_config(config):
- if not "target" in config:
- if config == '__is_first_stage__':
- return None
- elif config == "__is_unconditional__":
- return None
- raise KeyError("Expected key `target` to instantiate.")
- return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
- module, cls = string.rsplit(".", 1)
- if reload:
- module_imp = importlib.import_module(module)
- importlib.reload(module_imp)
- return getattr(importlib.import_module(module, package=None), cls)
-
-
-def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
- # create dummy dataset instance
-
- # run prefetching
- if idx_to_fn:
- res = func(data, worker_id=idx)
- else:
- res = func(data)
- Q.put([idx, res])
- Q.put("Done")
-
-
-def parallel_data_prefetch(
- func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
-):
- # if target_data_type not in ["ndarray", "list"]:
- # raise ValueError(
- # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
- # )
- if isinstance(data, np.ndarray) and target_data_type == "list":
- raise ValueError("list expected but function got ndarray.")
- elif isinstance(data, abc.Iterable):
- if isinstance(data, dict):
- print(
- f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
- )
- data = list(data.values())
- if target_data_type == "ndarray":
- data = np.asarray(data)
- else:
- data = list(data)
- else:
- raise TypeError(
- f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
- )
-
- if cpu_intensive:
- Q = mp.Queue(1000)
- proc = mp.Process
- else:
- Q = Queue(1000)
- proc = Thread
- # spawn processes
- if target_data_type == "ndarray":
- arguments = [
- [func, Q, part, i, use_worker_id]
- for i, part in enumerate(np.array_split(data, n_proc))
- ]
- else:
- step = (
- int(len(data) / n_proc + 1)
- if len(data) % n_proc != 0
- else int(len(data) / n_proc)
- )
- arguments = [
- [func, Q, part, i, use_worker_id]
- for i, part in enumerate(
- [data[i: i + step] for i in range(0, len(data), step)]
- )
- ]
- processes = []
- for i in range(n_proc):
- p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
- processes += [p]
-
- # start processes
- print(f"Start prefetching...")
- import time
-
- start = time.time()
- gather_res = [[] for _ in range(n_proc)]
- try:
- for p in processes:
- p.start()
-
- k = 0
- while k < n_proc:
- # get result
- res = Q.get()
- if res == "Done":
- k += 1
- else:
- gather_res[res[0]] = res[1]
-
- except Exception as e:
- print("Exception: ", e)
- for p in processes:
- p.terminate()
-
- raise e
- finally:
- for p in processes:
- p.join()
- print(f"Prefetching complete. [{time.time() - start} sec.]")
-
- if target_data_type == 'ndarray':
- if not isinstance(gather_res[0], np.ndarray):
- return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
-
- # order outputs
- return np.concatenate(gather_res, axis=0)
- elif target_data_type == 'list':
- out = []
- for r in gather_res:
- out.extend(r)
- return out
- else:
- return gather_res
diff --git a/examples/tutorial/stable_diffusion/main.py b/examples/tutorial/stable_diffusion/main.py
deleted file mode 100644
index 7cd00e4c0c26..000000000000
--- a/examples/tutorial/stable_diffusion/main.py
+++ /dev/null
@@ -1,830 +0,0 @@
-import argparse, os, sys, datetime, glob, importlib, csv
-import numpy as np
-import time
-import torch
-import torchvision
-import pytorch_lightning as pl
-
-from packaging import version
-from omegaconf import OmegaConf
-from torch.utils.data import random_split, DataLoader, Dataset, Subset
-from functools import partial
-from PIL import Image
-# from pytorch_lightning.strategies.colossalai import ColossalAIStrategy
-# from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
-from colossalai.nn.optimizer import HybridAdam
-from prefetch_generator import BackgroundGenerator
-
-from pytorch_lightning import seed_everything
-from pytorch_lightning.trainer import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
-from pytorch_lightning.utilities import rank_zero_info
-from diffusers.models.unet_2d import UNet2DModel
-
-from clip.model import Bottleneck
-from transformers.models.clip.modeling_clip import CLIPTextTransformer
-
-from ldm.data.base import Txt2ImgIterableBaseDataset
-from ldm.util import instantiate_from_config
-import clip
-from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel
-import kornia
-
-from ldm.modules.x_transformer import *
-from ldm.modules.encoders.modules import *
-from taming.modules.diffusionmodules.model import ResnetBlock
-from taming.modules.transformer.mingpt import *
-from taming.modules.transformer.permuter import *
-
-
-from ldm.modules.ema import LitEma
-from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
-from ldm.models.autoencoder import AutoencoderKL
-from ldm.models.autoencoder import *
-from ldm.models.diffusion.ddim import *
-from ldm.modules.diffusionmodules.openaimodel import *
-from ldm.modules.diffusionmodules.model import *
-from ldm.modules.diffusionmodules.model import Decoder, Encoder, Up_module, Down_module, Mid_module, temb_module
-from ldm.modules.attention import enable_flash_attention
-
-class DataLoaderX(DataLoader):
-
- def __iter__(self):
- return BackgroundGenerator(super().__iter__())
-
-
-def get_parser(**parser_kwargs):
- def str2bool(v):
- if isinstance(v, bool):
- return v
- if v.lower() in ("yes", "true", "t", "y", "1"):
- return True
- elif v.lower() in ("no", "false", "f", "n", "0"):
- return False
- else:
- raise argparse.ArgumentTypeError("Boolean value expected.")
-
- parser = argparse.ArgumentParser(**parser_kwargs)
- parser.add_argument(
- "-n",
- "--name",
- type=str,
- const=True,
- default="",
- nargs="?",
- help="postfix for logdir",
- )
- parser.add_argument(
- "-r",
- "--resume",
- type=str,
- const=True,
- default="",
- nargs="?",
- help="resume from logdir or checkpoint in logdir",
- )
- parser.add_argument(
- "-b",
- "--base",
- nargs="*",
- metavar="base_config.yaml",
- help="paths to base configs. Loaded from left-to-right. "
- "Parameters can be overwritten or added with command-line options of the form `--key value`.",
- default=list(),
- )
- parser.add_argument(
- "-t",
- "--train",
- type=str2bool,
- const=True,
- default=False,
- nargs="?",
- help="train",
- )
- parser.add_argument(
- "--no-test",
- type=str2bool,
- const=True,
- default=False,
- nargs="?",
- help="disable test",
- )
- parser.add_argument(
- "-p",
- "--project",
- help="name of new or path to existing project"
- )
- parser.add_argument(
- "-d",
- "--debug",
- type=str2bool,
- nargs="?",
- const=True,
- default=False,
- help="enable post-mortem debugging",
- )
- parser.add_argument(
- "-s",
- "--seed",
- type=int,
- default=23,
- help="seed for seed_everything",
- )
- parser.add_argument(
- "-f",
- "--postfix",
- type=str,
- default="",
- help="post-postfix for default name",
- )
- parser.add_argument(
- "-l",
- "--logdir",
- type=str,
- default="logs",
- help="directory for logging dat shit",
- )
- parser.add_argument(
- "--scale_lr",
- type=str2bool,
- nargs="?",
- const=True,
- default=True,
- help="scale base-lr by ngpu * batch_size * n_accumulate",
- )
- parser.add_argument(
- "--use_fp16",
- type=str2bool,
- nargs="?",
- const=True,
- default=True,
- help="whether to use fp16",
- )
- parser.add_argument(
- "--flash",
- type=str2bool,
- const=True,
- default=False,
- nargs="?",
- help="whether to use flash attention",
- )
- return parser
-
-
-def nondefault_trainer_args(opt):
- parser = argparse.ArgumentParser()
- parser = Trainer.add_argparse_args(parser)
- args = parser.parse_args([])
- return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
-
-
-class WrappedDataset(Dataset):
- """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
-
- def __init__(self, dataset):
- self.data = dataset
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, idx):
- return self.data[idx]
-
-
-def worker_init_fn(_):
- worker_info = torch.utils.data.get_worker_info()
-
- dataset = worker_info.dataset
- worker_id = worker_info.id
-
- if isinstance(dataset, Txt2ImgIterableBaseDataset):
- split_size = dataset.num_records // worker_info.num_workers
- # reset num_records to the true number to retain reliable length information
- dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
- current_id = np.random.choice(len(np.random.get_state()[1]), 1)
- return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
- else:
- return np.random.seed(np.random.get_state()[1][0] + worker_id)
-
-
-class DataModuleFromConfig(pl.LightningDataModule):
- def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
- wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
- shuffle_val_dataloader=False):
- super().__init__()
- self.batch_size = batch_size
- self.dataset_configs = dict()
- self.num_workers = num_workers if num_workers is not None else batch_size * 2
- self.use_worker_init_fn = use_worker_init_fn
- if train is not None:
- self.dataset_configs["train"] = train
- self.train_dataloader = self._train_dataloader
- if validation is not None:
- self.dataset_configs["validation"] = validation
- self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
- if test is not None:
- self.dataset_configs["test"] = test
- self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
- if predict is not None:
- self.dataset_configs["predict"] = predict
- self.predict_dataloader = self._predict_dataloader
- self.wrap = wrap
-
- def prepare_data(self):
- for data_cfg in self.dataset_configs.values():
- instantiate_from_config(data_cfg)
-
- def setup(self, stage=None):
- self.datasets = dict(
- (k, instantiate_from_config(self.dataset_configs[k]))
- for k in self.dataset_configs)
- if self.wrap:
- for k in self.datasets:
- self.datasets[k] = WrappedDataset(self.datasets[k])
-
- def _train_dataloader(self):
- is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
- if is_iterable_dataset or self.use_worker_init_fn:
- init_fn = worker_init_fn
- else:
- init_fn = None
- return DataLoaderX(self.datasets["train"], batch_size=self.batch_size,
- num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
- worker_init_fn=init_fn)
-
- def _val_dataloader(self, shuffle=False):
- if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
- init_fn = worker_init_fn
- else:
- init_fn = None
- return DataLoaderX(self.datasets["validation"],
- batch_size=self.batch_size,
- num_workers=self.num_workers,
- worker_init_fn=init_fn,
- shuffle=shuffle)
-
- def _test_dataloader(self, shuffle=False):
- is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
- if is_iterable_dataset or self.use_worker_init_fn:
- init_fn = worker_init_fn
- else:
- init_fn = None
-
- # do not shuffle dataloader for iterable dataset
- shuffle = shuffle and (not is_iterable_dataset)
-
- return DataLoaderX(self.datasets["test"], batch_size=self.batch_size,
- num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
-
- def _predict_dataloader(self, shuffle=False):
- if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
- init_fn = worker_init_fn
- else:
- init_fn = None
- return DataLoaderX(self.datasets["predict"], batch_size=self.batch_size,
- num_workers=self.num_workers, worker_init_fn=init_fn)
-
-
-class SetupCallback(Callback):
- def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
- super().__init__()
- self.resume = resume
- self.now = now
- self.logdir = logdir
- self.ckptdir = ckptdir
- self.cfgdir = cfgdir
- self.config = config
- self.lightning_config = lightning_config
-
- def on_keyboard_interrupt(self, trainer, pl_module):
- if trainer.global_rank == 0:
- print("Summoning checkpoint.")
- ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
- trainer.save_checkpoint(ckpt_path)
-
- # def on_pretrain_routine_start(self, trainer, pl_module):
- def on_fit_start(self, trainer, pl_module):
- if trainer.global_rank == 0:
- # Create logdirs and save configs
- os.makedirs(self.logdir, exist_ok=True)
- os.makedirs(self.ckptdir, exist_ok=True)
- os.makedirs(self.cfgdir, exist_ok=True)
-
- if "callbacks" in self.lightning_config:
- if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
- os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
- print("Project config")
- print(OmegaConf.to_yaml(self.config))
- OmegaConf.save(self.config,
- os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
-
- print("Lightning config")
- print(OmegaConf.to_yaml(self.lightning_config))
- OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
- os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
-
- else:
- # ModelCheckpoint callback created log directory --- remove it
- if not self.resume and os.path.exists(self.logdir):
- dst, name = os.path.split(self.logdir)
- dst = os.path.join(dst, "child_runs", name)
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
- try:
- os.rename(self.logdir, dst)
- except FileNotFoundError:
- pass
-
-
-class ImageLogger(Callback):
- def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
- rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
- log_images_kwargs=None):
- super().__init__()
- self.rescale = rescale
- self.batch_freq = batch_frequency
- self.max_images = max_images
- self.logger_log_images = {
- pl.loggers.CSVLogger: self._testtube,
- }
- self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
- if not increase_log_steps:
- self.log_steps = [self.batch_freq]
- self.clamp = clamp
- self.disabled = disabled
- self.log_on_batch_idx = log_on_batch_idx
- self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
- self.log_first_step = log_first_step
-
- @rank_zero_only
- def _testtube(self, pl_module, images, batch_idx, split):
- for k in images:
- grid = torchvision.utils.make_grid(images[k])
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
-
- tag = f"{split}/{k}"
- pl_module.logger.experiment.add_image(
- tag, grid,
- global_step=pl_module.global_step)
-
- @rank_zero_only
- def log_local(self, save_dir, split, images,
- global_step, current_epoch, batch_idx):
- root = os.path.join(save_dir, "images", split)
- for k in images:
- grid = torchvision.utils.make_grid(images[k], nrow=4)
- if self.rescale:
- grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
- grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
- grid = grid.numpy()
- grid = (grid * 255).astype(np.uint8)
- filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
- k,
- global_step,
- current_epoch,
- batch_idx)
- path = os.path.join(root, filename)
- os.makedirs(os.path.split(path)[0], exist_ok=True)
- Image.fromarray(grid).save(path)
-
- def log_img(self, pl_module, batch, batch_idx, split="train"):
- check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
- if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
- hasattr(pl_module, "log_images") and
- callable(pl_module.log_images) and
- self.max_images > 0):
- logger = type(pl_module.logger)
-
- is_train = pl_module.training
- if is_train:
- pl_module.eval()
-
- with torch.no_grad():
- images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
-
- for k in images:
- N = min(images[k].shape[0], self.max_images)
- images[k] = images[k][:N]
- if isinstance(images[k], torch.Tensor):
- images[k] = images[k].detach().cpu()
- if self.clamp:
- images[k] = torch.clamp(images[k], -1., 1.)
-
- self.log_local(pl_module.logger.save_dir, split, images,
- pl_module.global_step, pl_module.current_epoch, batch_idx)
-
- logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
- logger_log_images(pl_module, images, pl_module.global_step, split)
-
- if is_train:
- pl_module.train()
-
- def check_frequency(self, check_idx):
- if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
- check_idx > 0 or self.log_first_step):
- try:
- self.log_steps.pop(0)
- except IndexError as e:
- print(e)
- pass
- return True
- return False
-
- def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
- # if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
- # self.log_img(pl_module, batch, batch_idx, split="train")
- pass
-
- def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
- if not self.disabled and pl_module.global_step > 0:
- self.log_img(pl_module, batch, batch_idx, split="val")
- if hasattr(pl_module, 'calibrate_grad_norm'):
- if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
- self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
-
-
-class CUDACallback(Callback):
- # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
-
- def on_train_start(self, trainer, pl_module):
- rank_zero_info("Training is starting")
-
- def on_train_end(self, trainer, pl_module):
- rank_zero_info("Training is ending")
-
- def on_train_epoch_start(self, trainer, pl_module):
- # Reset the memory use counter
- torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index)
- torch.cuda.synchronize(trainer.strategy.root_device.index)
- self.start_time = time.time()
-
- def on_train_epoch_end(self, trainer, pl_module):
- torch.cuda.synchronize(trainer.strategy.root_device.index)
- max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20
- epoch_time = time.time() - self.start_time
-
- try:
- max_memory = trainer.strategy.reduce(max_memory)
- epoch_time = trainer.strategy.reduce(epoch_time)
-
- rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
- rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
- except AttributeError:
- pass
-
-
-if __name__ == "__main__":
- # custom parser to specify config files, train, test and debug mode,
- # postfix, resume.
- # `--key value` arguments are interpreted as arguments to the trainer.
- # `nested.key=value` arguments are interpreted as config parameters.
- # configs are merged from left-to-right followed by command line parameters.
-
- # model:
- # base_learning_rate: float
- # target: path to lightning module
- # params:
- # key: value
- # data:
- # target: main.DataModuleFromConfig
- # params:
- # batch_size: int
- # wrap: bool
- # train:
- # target: path to train dataset
- # params:
- # key: value
- # validation:
- # target: path to validation dataset
- # params:
- # key: value
- # test:
- # target: path to test dataset
- # params:
- # key: value
- # lightning: (optional, has sane defaults and can be specified on cmdline)
- # trainer:
- # additional arguments to trainer
- # logger:
- # logger to instantiate
- # modelcheckpoint:
- # modelcheckpoint to instantiate
- # callbacks:
- # callback1:
- # target: importpath
- # params:
- # key: value
-
- now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-
- # add cwd for convenience and to make classes in this file available when
- # running as `python main.py`
- # (in particular `main.DataModuleFromConfig`)
- sys.path.append(os.getcwd())
-
- parser = get_parser()
- parser = Trainer.add_argparse_args(parser)
-
- opt, unknown = parser.parse_known_args()
- if opt.name and opt.resume:
- raise ValueError(
- "-n/--name and -r/--resume cannot be specified both."
- "If you want to resume training in a new log folder, "
- "use -n/--name in combination with --resume_from_checkpoint"
- )
- if opt.flash:
- enable_flash_attention()
- if opt.resume:
- if not os.path.exists(opt.resume):
- raise ValueError("Cannot find {}".format(opt.resume))
- if os.path.isfile(opt.resume):
- paths = opt.resume.split("/")
- # idx = len(paths)-paths[::-1].index("logs")+1
- # logdir = "/".join(paths[:idx])
- logdir = "/".join(paths[:-2])
- ckpt = opt.resume
- else:
- assert os.path.isdir(opt.resume), opt.resume
- logdir = opt.resume.rstrip("/")
- ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
-
- opt.resume_from_checkpoint = ckpt
- base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
- opt.base = base_configs + opt.base
- _tmp = logdir.split("/")
- nowname = _tmp[-1]
- else:
- if opt.name:
- name = "_" + opt.name
- elif opt.base:
- cfg_fname = os.path.split(opt.base[0])[-1]
- cfg_name = os.path.splitext(cfg_fname)[0]
- name = "_" + cfg_name
- else:
- name = ""
- nowname = now + name + opt.postfix
- logdir = os.path.join(opt.logdir, nowname)
-
- ckptdir = os.path.join(logdir, "checkpoints")
- cfgdir = os.path.join(logdir, "configs")
- seed_everything(opt.seed)
-
- try:
- # init and save configs
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
- cli = OmegaConf.from_dotlist(unknown)
- config = OmegaConf.merge(*configs, cli)
- lightning_config = config.pop("lightning", OmegaConf.create())
- # merge trainer cli with config
- trainer_config = lightning_config.get("trainer", OmegaConf.create())
-
- for k in nondefault_trainer_args(opt):
- trainer_config[k] = getattr(opt, k)
-
- print(trainer_config)
- if not trainer_config["accelerator"] == "gpu":
- del trainer_config["accelerator"]
- cpu = True
- print("Running on CPU")
- else:
- cpu = False
- print("Running on GPU")
- trainer_opt = argparse.Namespace(**trainer_config)
- lightning_config.trainer = trainer_config
-
- # model
- use_fp16 = trainer_config.get("precision", 32) == 16
- if use_fp16:
- config.model["params"].update({"use_fp16": True})
- print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
- else:
- config.model["params"].update({"use_fp16": False})
- print("Using FP16 = {}".format(config.model["params"]["use_fp16"]))
-
- model = instantiate_from_config(config.model)
- # trainer and callbacks
- trainer_kwargs = dict()
-
- # config the logger
- # default logger configs
- default_logger_cfgs = {
- "wandb": {
- "target": "pytorch_lightning.loggers.WandbLogger",
- "params": {
- "name": nowname,
- "save_dir": logdir,
- "offline": opt.debug,
- "id": nowname,
- }
- },
- "tensorboard":{
- "target": "pytorch_lightning.loggers.TensorBoardLogger",
- "params":{
- "save_dir": logdir,
- "name": "diff_tb",
- "log_graph": True
- }
- }
- }
-
- default_logger_cfg = default_logger_cfgs["tensorboard"]
- if "logger" in lightning_config:
- logger_cfg = lightning_config.logger
- else:
- logger_cfg = default_logger_cfg
- logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
- trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
-
- # config the strategy, defualt is ddp
- if "strategy" in trainer_config:
- strategy_cfg = trainer_config["strategy"]
- print("Using strategy: {}".format(strategy_cfg["target"]))
- else:
- strategy_cfg = {
- "target": "pytorch_lightning.strategies.DDPStrategy",
- "params": {
- "find_unused_parameters": False
- }
- }
- print("Using strategy: DDPStrategy")
-
- trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
-
- # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
- # specify which metric is used to determine best models
- default_modelckpt_cfg = {
- "target": "pytorch_lightning.callbacks.ModelCheckpoint",
- "params": {
- "dirpath": ckptdir,
- "filename": "{epoch:06}",
- "verbose": True,
- "save_last": True,
- }
- }
- if hasattr(model, "monitor"):
- print(f"Monitoring {model.monitor} as checkpoint metric.")
- default_modelckpt_cfg["params"]["monitor"] = model.monitor
- default_modelckpt_cfg["params"]["save_top_k"] = 3
-
- if "modelcheckpoint" in lightning_config:
- modelckpt_cfg = lightning_config.modelcheckpoint
- else:
- modelckpt_cfg = OmegaConf.create()
- modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
- print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
- if version.parse(pl.__version__) < version.parse('1.4.0'):
- trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
-
- # add callback which sets up log directory
- default_callbacks_cfg = {
- "setup_callback": {
- "target": "main.SetupCallback",
- "params": {
- "resume": opt.resume,
- "now": now,
- "logdir": logdir,
- "ckptdir": ckptdir,
- "cfgdir": cfgdir,
- "config": config,
- "lightning_config": lightning_config,
- }
- },
- "image_logger": {
- "target": "main.ImageLogger",
- "params": {
- "batch_frequency": 750,
- "max_images": 4,
- "clamp": True
- }
- },
- "learning_rate_logger": {
- "target": "main.LearningRateMonitor",
- "params": {
- "logging_interval": "step",
- # "log_momentum": True
- }
- },
- "cuda_callback": {
- "target": "main.CUDACallback"
- },
- }
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
- default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
-
- if "callbacks" in lightning_config:
- callbacks_cfg = lightning_config.callbacks
- else:
- callbacks_cfg = OmegaConf.create()
-
- if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
- print(
- 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
- default_metrics_over_trainsteps_ckpt_dict = {
- 'metrics_over_trainsteps_checkpoint':
- {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
- 'params': {
- "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
- "filename": "{epoch:06}-{step:09}",
- "verbose": True,
- 'save_top_k': -1,
- 'every_n_train_steps': 10000,
- 'save_weights_only': True
- }
- }
- }
- default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
-
- callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
- if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
- callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
- elif 'ignore_keys_callback' in callbacks_cfg:
- del callbacks_cfg['ignore_keys_callback']
-
- trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
-
- trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
- trainer.logdir = logdir ###
-
- # data
- data = instantiate_from_config(config.data)
- # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
- # calling these ourselves should not be necessary but it is.
- # lightning still takes care of proper multiprocessing though
- data.prepare_data()
- data.setup()
- print("#### Data #####")
- for k in data.datasets:
- print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
-
- # configure learning rate
- bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
- if not cpu:
- ngpu = trainer_config["devices"]
- else:
- ngpu = 1
- if 'accumulate_grad_batches' in lightning_config.trainer:
- accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
- else:
- accumulate_grad_batches = 1
- print(f"accumulate_grad_batches = {accumulate_grad_batches}")
- lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
- if opt.scale_lr:
- model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
- print(
- "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
- model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))
- else:
- model.learning_rate = base_lr
- print("++++ NOT USING LR SCALING ++++")
- print(f"Setting learning rate to {model.learning_rate:.2e}")
-
-
- # allow checkpointing via USR1
- def melk(*args, **kwargs):
- # run all checkpoint hooks
- if trainer.global_rank == 0:
- print("Summoning checkpoint.")
- ckpt_path = os.path.join(ckptdir, "last.ckpt")
- trainer.save_checkpoint(ckpt_path)
-
-
- def divein(*args, **kwargs):
- if trainer.global_rank == 0:
- import pudb;
- pudb.set_trace()
-
-
- import signal
-
- signal.signal(signal.SIGUSR1, melk)
- signal.signal(signal.SIGUSR2, divein)
-
- # run
- if opt.train:
- try:
- for name, m in model.named_parameters():
- print(name)
- trainer.fit(model, data)
- except Exception:
- melk()
- raise
- # if not opt.no_test and not trainer.interrupted:
- # trainer.test(model, data)
- except Exception:
- if opt.debug and trainer.global_rank == 0:
- try:
- import pudb as debugger
- except ImportError:
- import pdb as debugger
- debugger.post_mortem()
- raise
- finally:
- # move newly created debug project to debug_runs
- if opt.debug and not opt.resume and trainer.global_rank == 0:
- dst, name = os.path.split(logdir)
- dst = os.path.join(dst, "debug_runs", name)
- os.makedirs(os.path.split(dst)[0], exist_ok=True)
- os.rename(logdir, dst)
- if trainer.global_rank == 0:
- print(trainer.profiler.summary())
diff --git a/examples/tutorial/stable_diffusion/requirements.txt b/examples/tutorial/stable_diffusion/requirements.txt
deleted file mode 100644
index a57003562a3b..000000000000
--- a/examples/tutorial/stable_diffusion/requirements.txt
+++ /dev/null
@@ -1,22 +0,0 @@
-albumentations==0.4.3
-diffusers
-pudb==2019.2
-datasets
-invisible-watermark
-imageio==2.9.0
-imageio-ffmpeg==0.4.2
-omegaconf==2.1.1
-multiprocess
-test-tube>=0.7.5
-streamlit>=0.73.1
-einops==0.3.0
-torch-fidelity==0.3.0
-transformers==4.19.2
-torchmetrics==0.6.0
-kornia==0.6
-opencv-python==4.6.0.66
-prefetch_generator
-colossalai
--e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
--e git+https://github.com/openai/CLIP.git@main#egg=clip
--e .
diff --git a/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh b/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh
deleted file mode 100644
index a8d79e99ccdf..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/download_first_stages.sh
+++ /dev/null
@@ -1,41 +0,0 @@
-#!/bin/bash
-wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
-wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
-wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
-wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
-wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
-wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
-wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
-wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
-wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
-
-
-
-cd models/first_stage_models/kl-f4
-unzip -o model.zip
-
-cd ../kl-f8
-unzip -o model.zip
-
-cd ../kl-f16
-unzip -o model.zip
-
-cd ../kl-f32
-unzip -o model.zip
-
-cd ../vq-f4
-unzip -o model.zip
-
-cd ../vq-f4-noattn
-unzip -o model.zip
-
-cd ../vq-f8
-unzip -o model.zip
-
-cd ../vq-f8-n256
-unzip -o model.zip
-
-cd ../vq-f16
-unzip -o model.zip
-
-cd ../..
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/scripts/download_models.sh b/examples/tutorial/stable_diffusion/scripts/download_models.sh
deleted file mode 100644
index 84297d7b8b9a..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/download_models.sh
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/bin/bash
-wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
-wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
-wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
-wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
-wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
-wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
-wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
-wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
-wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
-wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
-wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
-
-
-
-cd models/ldm/celeba256
-unzip -o celeba-256.zip
-
-cd ../ffhq256
-unzip -o ffhq-256.zip
-
-cd ../lsun_churches256
-unzip -o lsun_churches-256.zip
-
-cd ../lsun_beds256
-unzip -o lsun_beds-256.zip
-
-cd ../text2img256
-unzip -o model.zip
-
-cd ../cin256
-unzip -o model.zip
-
-cd ../semantic_synthesis512
-unzip -o model.zip
-
-cd ../semantic_synthesis256
-unzip -o model.zip
-
-cd ../bsr_sr
-unzip -o model.zip
-
-cd ../layout2img-openimages256
-unzip -o model.zip
-
-cd ../inpainting_big
-unzip -o model.zip
-
-cd ../..
diff --git a/examples/tutorial/stable_diffusion/scripts/img2img.py b/examples/tutorial/stable_diffusion/scripts/img2img.py
deleted file mode 100644
index 421e2151d9e9..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/img2img.py
+++ /dev/null
@@ -1,293 +0,0 @@
-"""make variations of input image"""
-
-import argparse, os, sys, glob
-import PIL
-import torch
-import numpy as np
-from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
-from torch import autocast
-from contextlib import nullcontext
-import time
-from pytorch_lightning import seed_everything
-
-from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
-
-
-def chunk(it, size):
- it = iter(it)
- return iter(lambda: tuple(islice(it, size)), ())
-
-
-def load_model_from_config(config, ckpt, verbose=False):
- print(f"Loading model from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- print("missing keys:")
- print(m)
- if len(u) > 0 and verbose:
- print("unexpected keys:")
- print(u)
-
- model.cuda()
- model.eval()
- return model
-
-
-def load_img(path):
- image = Image.open(path).convert("RGB")
- w, h = image.size
- print(f"loaded input image of size ({w}, {h}) from {path}")
- w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
- image = np.array(image).astype(np.float32) / 255.0
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image)
- return 2.*image - 1.
-
-
-def main():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--prompt",
- type=str,
- nargs="?",
- default="a painting of a virus monster playing guitar",
- help="the prompt to render"
- )
-
- parser.add_argument(
- "--init-img",
- type=str,
- nargs="?",
- help="path to the input image"
- )
-
- parser.add_argument(
- "--outdir",
- type=str,
- nargs="?",
- help="dir to write results to",
- default="outputs/img2img-samples"
- )
-
- parser.add_argument(
- "--skip_grid",
- action='store_true',
- help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
- )
-
- parser.add_argument(
- "--skip_save",
- action='store_true',
- help="do not save indiviual samples. For speed measurements.",
- )
-
- parser.add_argument(
- "--ddim_steps",
- type=int,
- default=50,
- help="number of ddim sampling steps",
- )
-
- parser.add_argument(
- "--plms",
- action='store_true',
- help="use plms sampling",
- )
- parser.add_argument(
- "--fixed_code",
- action='store_true',
- help="if enabled, uses the same starting code across all samples ",
- )
-
- parser.add_argument(
- "--ddim_eta",
- type=float,
- default=0.0,
- help="ddim eta (eta=0.0 corresponds to deterministic sampling",
- )
- parser.add_argument(
- "--n_iter",
- type=int,
- default=1,
- help="sample this often",
- )
- parser.add_argument(
- "--C",
- type=int,
- default=4,
- help="latent channels",
- )
- parser.add_argument(
- "--f",
- type=int,
- default=8,
- help="downsampling factor, most often 8 or 16",
- )
- parser.add_argument(
- "--n_samples",
- type=int,
- default=2,
- help="how many samples to produce for each given prompt. A.k.a batch size",
- )
- parser.add_argument(
- "--n_rows",
- type=int,
- default=0,
- help="rows in the grid (default: n_samples)",
- )
- parser.add_argument(
- "--scale",
- type=float,
- default=5.0,
- help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
- )
-
- parser.add_argument(
- "--strength",
- type=float,
- default=0.75,
- help="strength for noising/unnoising. 1.0 corresponds to full destruction of information in init image",
- )
- parser.add_argument(
- "--from-file",
- type=str,
- help="if specified, load prompts from this file",
- )
- parser.add_argument(
- "--config",
- type=str,
- default="configs/stable-diffusion/v1-inference.yaml",
- help="path to config which constructs model",
- )
- parser.add_argument(
- "--ckpt",
- type=str,
- default="models/ldm/stable-diffusion-v1/model.ckpt",
- help="path to checkpoint of model",
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="the seed (for reproducible sampling)",
- )
- parser.add_argument(
- "--precision",
- type=str,
- help="evaluate at this precision",
- choices=["full", "autocast"],
- default="autocast"
- )
-
- opt = parser.parse_args()
- seed_everything(opt.seed)
-
- config = OmegaConf.load(f"{opt.config}")
- model = load_model_from_config(config, f"{opt.ckpt}")
-
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model = model.to(device)
-
- if opt.plms:
- raise NotImplementedError("PLMS sampler not (yet) supported")
- sampler = PLMSSampler(model)
- else:
- sampler = DDIMSampler(model)
-
- os.makedirs(opt.outdir, exist_ok=True)
- outpath = opt.outdir
-
- batch_size = opt.n_samples
- n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
- if not opt.from_file:
- prompt = opt.prompt
- assert prompt is not None
- data = [batch_size * [prompt]]
-
- else:
- print(f"reading prompts from {opt.from_file}")
- with open(opt.from_file, "r") as f:
- data = f.read().splitlines()
- data = list(chunk(data, batch_size))
-
- sample_path = os.path.join(outpath, "samples")
- os.makedirs(sample_path, exist_ok=True)
- base_count = len(os.listdir(sample_path))
- grid_count = len(os.listdir(outpath)) - 1
-
- assert os.path.isfile(opt.init_img)
- init_image = load_img(opt.init_img).to(device)
- init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
- init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
-
- sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
-
- assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
- t_enc = int(opt.strength * opt.ddim_steps)
- print(f"target t_enc is {t_enc} steps")
-
- precision_scope = autocast if opt.precision == "autocast" else nullcontext
- with torch.no_grad():
- with precision_scope("cuda"):
- with model.ema_scope():
- tic = time.time()
- all_samples = list()
- for n in trange(opt.n_iter, desc="Sampling"):
- for prompts in tqdm(data, desc="data"):
- uc = None
- if opt.scale != 1.0:
- uc = model.get_learned_conditioning(batch_size * [""])
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = model.get_learned_conditioning(prompts)
-
- # encode (scaled latent)
- z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
- # decode it
- samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
- unconditional_conditioning=uc,)
-
- x_samples = model.decode_first_stage(samples)
- x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
- if not opt.skip_save:
- for x_sample in x_samples:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- Image.fromarray(x_sample.astype(np.uint8)).save(
- os.path.join(sample_path, f"{base_count:05}.png"))
- base_count += 1
- all_samples.append(x_samples)
-
- if not opt.skip_grid:
- # additionally, save as grid
- grid = torch.stack(all_samples, 0)
- grid = rearrange(grid, 'n b c h w -> (n b) c h w')
- grid = make_grid(grid, nrow=n_rows)
-
- # to image
- grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
- Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
- grid_count += 1
-
- toc = time.time()
-
- print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
- f" \nEnjoy.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/tutorial/stable_diffusion/scripts/inpaint.py b/examples/tutorial/stable_diffusion/scripts/inpaint.py
deleted file mode 100644
index d6e6387a9a3b..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/inpaint.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import argparse, os, sys, glob
-from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm
-import numpy as np
-import torch
-from main import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-
-
-def make_batch(image, mask, device):
- image = np.array(Image.open(image).convert("RGB"))
- image = image.astype(np.float32)/255.0
- image = image[None].transpose(0,3,1,2)
- image = torch.from_numpy(image)
-
- mask = np.array(Image.open(mask).convert("L"))
- mask = mask.astype(np.float32)/255.0
- mask = mask[None,None]
- mask[mask < 0.5] = 0
- mask[mask >= 0.5] = 1
- mask = torch.from_numpy(mask)
-
- masked_image = (1-mask)*image
-
- batch = {"image": image, "mask": mask, "masked_image": masked_image}
- for k in batch:
- batch[k] = batch[k].to(device=device)
- batch[k] = batch[k]*2.0-1.0
- return batch
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--indir",
- type=str,
- nargs="?",
- help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
- )
- parser.add_argument(
- "--outdir",
- type=str,
- nargs="?",
- help="dir to write results to",
- )
- parser.add_argument(
- "--steps",
- type=int,
- default=50,
- help="number of ddim sampling steps",
- )
- opt = parser.parse_args()
-
- masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
- images = [x.replace("_mask.png", ".png") for x in masks]
- print(f"Found {len(masks)} inputs.")
-
- config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
- model = instantiate_from_config(config.model)
- model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
- strict=False)
-
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model = model.to(device)
- sampler = DDIMSampler(model)
-
- os.makedirs(opt.outdir, exist_ok=True)
- with torch.no_grad():
- with model.ema_scope():
- for image, mask in tqdm(zip(images, masks)):
- outpath = os.path.join(opt.outdir, os.path.split(image)[1])
- batch = make_batch(image, mask, device=device)
-
- # encode masked image and concat downsampled mask
- c = model.cond_stage_model.encode(batch["masked_image"])
- cc = torch.nn.functional.interpolate(batch["mask"],
- size=c.shape[-2:])
- c = torch.cat((c, cc), dim=1)
-
- shape = (c.shape[1]-1,)+c.shape[2:]
- samples_ddim, _ = sampler.sample(S=opt.steps,
- conditioning=c,
- batch_size=c.shape[0],
- shape=shape,
- verbose=False)
- x_samples_ddim = model.decode_first_stage(samples_ddim)
-
- image = torch.clamp((batch["image"]+1.0)/2.0,
- min=0.0, max=1.0)
- mask = torch.clamp((batch["mask"]+1.0)/2.0,
- min=0.0, max=1.0)
- predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
- min=0.0, max=1.0)
-
- inpainted = (1-mask)*image+mask*predicted_image
- inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
- Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
diff --git a/examples/tutorial/stable_diffusion/scripts/knn2img.py b/examples/tutorial/stable_diffusion/scripts/knn2img.py
deleted file mode 100644
index e6eaaecab53e..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/knn2img.py
+++ /dev/null
@@ -1,398 +0,0 @@
-import argparse, os, sys, glob
-import clip
-import torch
-import torch.nn as nn
-import numpy as np
-from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange, repeat
-from torchvision.utils import make_grid
-import scann
-import time
-from multiprocessing import cpu_count
-
-from ldm.util import instantiate_from_config, parallel_data_prefetch
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
-from ldm.modules.encoders.modules import FrozenClipImageEmbedder, FrozenCLIPTextEmbedder
-
-DATABASES = [
- "openimages",
- "artbench-art_nouveau",
- "artbench-baroque",
- "artbench-expressionism",
- "artbench-impressionism",
- "artbench-post_impressionism",
- "artbench-realism",
- "artbench-romanticism",
- "artbench-renaissance",
- "artbench-surrealism",
- "artbench-ukiyo_e",
-]
-
-
-def chunk(it, size):
- it = iter(it)
- return iter(lambda: tuple(islice(it, size)), ())
-
-
-def load_model_from_config(config, ckpt, verbose=False):
- print(f"Loading model from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- print("missing keys:")
- print(m)
- if len(u) > 0 and verbose:
- print("unexpected keys:")
- print(u)
-
- model.cuda()
- model.eval()
- return model
-
-
-class Searcher(object):
- def __init__(self, database, retriever_version='ViT-L/14'):
- assert database in DATABASES
- # self.database = self.load_database(database)
- self.database_name = database
- self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
- self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
- self.retriever = self.load_retriever(version=retriever_version)
- self.database = {'embedding': [],
- 'img_id': [],
- 'patch_coords': []}
- self.load_database()
- self.load_searcher()
-
- def train_searcher(self, k,
- metric='dot_product',
- searcher_savedir=None):
-
- print('Start training searcher')
- searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
- np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
- k, metric)
- self.searcher = searcher.score_brute_force().build()
- print('Finish training searcher')
-
- if searcher_savedir is not None:
- print(f'Save trained searcher under "{searcher_savedir}"')
- os.makedirs(searcher_savedir, exist_ok=True)
- self.searcher.serialize(searcher_savedir)
-
- def load_single_file(self, saved_embeddings):
- compressed = np.load(saved_embeddings)
- self.database = {key: compressed[key] for key in compressed.files}
- print('Finished loading of clip embeddings.')
-
- def load_multi_files(self, data_archive):
- out_data = {key: [] for key in self.database}
- for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
- for key in d.files:
- out_data[key].append(d[key])
-
- return out_data
-
- def load_database(self):
-
- print(f'Load saved patch embedding from "{self.database_path}"')
- file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
-
- if len(file_content) == 1:
- self.load_single_file(file_content[0])
- elif len(file_content) > 1:
- data = [np.load(f) for f in file_content]
- prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
- n_proc=min(len(data), cpu_count()), target_data_type='dict')
-
- self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
- self.database}
- else:
- raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
-
- print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
-
- def load_retriever(self, version='ViT-L/14', ):
- model = FrozenClipImageEmbedder(model=version)
- if torch.cuda.is_available():
- model.cuda()
- model.eval()
- return model
-
- def load_searcher(self):
- print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
- self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
- print('Finished loading searcher.')
-
- def search(self, x, k):
- if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
- self.train_searcher(k) # quickly fit searcher on the fly for small databases
- assert self.searcher is not None, 'Cannot search with uninitialized searcher'
- if isinstance(x, torch.Tensor):
- x = x.detach().cpu().numpy()
- if len(x.shape) == 3:
- x = x[:, 0]
- query_embeddings = x / np.linalg.norm(x, axis=1)[:, np.newaxis]
-
- start = time.time()
- nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
- end = time.time()
-
- out_embeddings = self.database['embedding'][nns]
- out_img_ids = self.database['img_id'][nns]
- out_pc = self.database['patch_coords'][nns]
-
- out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
- 'img_ids': out_img_ids,
- 'patch_coords': out_pc,
- 'queries': x,
- 'exec_time': end - start,
- 'nns': nns,
- 'q_embeddings': query_embeddings}
-
- return out
-
- def __call__(self, x, n):
- return self.search(x, n)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- # TODO: add n_neighbors and modes (text-only, text-image-retrieval, image-image retrieval etc)
- # TODO: add 'image variation' mode when knn=0 but a single image is given instead of a text prompt?
- parser.add_argument(
- "--prompt",
- type=str,
- nargs="?",
- default="a painting of a virus monster playing guitar",
- help="the prompt to render"
- )
-
- parser.add_argument(
- "--outdir",
- type=str,
- nargs="?",
- help="dir to write results to",
- default="outputs/txt2img-samples"
- )
-
- parser.add_argument(
- "--skip_grid",
- action='store_true',
- help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
- )
-
- parser.add_argument(
- "--ddim_steps",
- type=int,
- default=50,
- help="number of ddim sampling steps",
- )
-
- parser.add_argument(
- "--n_repeat",
- type=int,
- default=1,
- help="number of repeats in CLIP latent space",
- )
-
- parser.add_argument(
- "--plms",
- action='store_true',
- help="use plms sampling",
- )
-
- parser.add_argument(
- "--ddim_eta",
- type=float,
- default=0.0,
- help="ddim eta (eta=0.0 corresponds to deterministic sampling",
- )
- parser.add_argument(
- "--n_iter",
- type=int,
- default=1,
- help="sample this often",
- )
-
- parser.add_argument(
- "--H",
- type=int,
- default=768,
- help="image height, in pixel space",
- )
-
- parser.add_argument(
- "--W",
- type=int,
- default=768,
- help="image width, in pixel space",
- )
-
- parser.add_argument(
- "--n_samples",
- type=int,
- default=3,
- help="how many samples to produce for each given prompt. A.k.a batch size",
- )
-
- parser.add_argument(
- "--n_rows",
- type=int,
- default=0,
- help="rows in the grid (default: n_samples)",
- )
-
- parser.add_argument(
- "--scale",
- type=float,
- default=5.0,
- help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
- )
-
- parser.add_argument(
- "--from-file",
- type=str,
- help="if specified, load prompts from this file",
- )
-
- parser.add_argument(
- "--config",
- type=str,
- default="configs/retrieval-augmented-diffusion/768x768.yaml",
- help="path to config which constructs model",
- )
-
- parser.add_argument(
- "--ckpt",
- type=str,
- default="models/rdm/rdm768x768/model.ckpt",
- help="path to checkpoint of model",
- )
-
- parser.add_argument(
- "--clip_type",
- type=str,
- default="ViT-L/14",
- help="which CLIP model to use for retrieval and NN encoding",
- )
- parser.add_argument(
- "--database",
- type=str,
- default='artbench-surrealism',
- choices=DATABASES,
- help="The database used for the search, only applied when --use_neighbors=True",
- )
- parser.add_argument(
- "--use_neighbors",
- default=False,
- action='store_true',
- help="Include neighbors in addition to text prompt for conditioning",
- )
- parser.add_argument(
- "--knn",
- default=10,
- type=int,
- help="The number of included neighbors, only applied when --use_neighbors=True",
- )
-
- opt = parser.parse_args()
-
- config = OmegaConf.load(f"{opt.config}")
- model = load_model_from_config(config, f"{opt.ckpt}")
-
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model = model.to(device)
-
- clip_text_encoder = FrozenCLIPTextEmbedder(opt.clip_type).to(device)
-
- if opt.plms:
- sampler = PLMSSampler(model)
- else:
- sampler = DDIMSampler(model)
-
- os.makedirs(opt.outdir, exist_ok=True)
- outpath = opt.outdir
-
- batch_size = opt.n_samples
- n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
- if not opt.from_file:
- prompt = opt.prompt
- assert prompt is not None
- data = [batch_size * [prompt]]
-
- else:
- print(f"reading prompts from {opt.from_file}")
- with open(opt.from_file, "r") as f:
- data = f.read().splitlines()
- data = list(chunk(data, batch_size))
-
- sample_path = os.path.join(outpath, "samples")
- os.makedirs(sample_path, exist_ok=True)
- base_count = len(os.listdir(sample_path))
- grid_count = len(os.listdir(outpath)) - 1
-
- print(f"sampling scale for cfg is {opt.scale:.2f}")
-
- searcher = None
- if opt.use_neighbors:
- searcher = Searcher(opt.database)
-
- with torch.no_grad():
- with model.ema_scope():
- for n in trange(opt.n_iter, desc="Sampling"):
- all_samples = list()
- for prompts in tqdm(data, desc="data"):
- print("sampling prompts:", prompts)
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = clip_text_encoder.encode(prompts)
- uc = None
- if searcher is not None:
- nn_dict = searcher(c, opt.knn)
- c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
- if opt.scale != 1.0:
- uc = torch.zeros_like(c)
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
- samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
- conditioning=c,
- batch_size=c.shape[0],
- shape=shape,
- verbose=False,
- unconditional_guidance_scale=opt.scale,
- unconditional_conditioning=uc,
- eta=opt.ddim_eta,
- )
-
- x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-
- for x_sample in x_samples_ddim:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- Image.fromarray(x_sample.astype(np.uint8)).save(
- os.path.join(sample_path, f"{base_count:05}.png"))
- base_count += 1
- all_samples.append(x_samples_ddim)
-
- if not opt.skip_grid:
- # additionally, save as grid
- grid = torch.stack(all_samples, 0)
- grid = rearrange(grid, 'n b c h w -> (n b) c h w')
- grid = make_grid(grid, nrow=n_rows)
-
- # to image
- grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
- Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
- grid_count += 1
-
- print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")
diff --git a/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py b/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py
deleted file mode 100644
index 876fe3c3642f..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/sample_diffusion.py
+++ /dev/null
@@ -1,313 +0,0 @@
-import argparse, os, sys, glob, datetime, yaml
-import torch
-import time
-import numpy as np
-from tqdm import trange
-
-from omegaconf import OmegaConf
-from PIL import Image
-
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.util import instantiate_from_config
-
-rescale = lambda x: (x + 1.) / 2.
-
-def custom_to_pil(x):
- x = x.detach().cpu()
- x = torch.clamp(x, -1., 1.)
- x = (x + 1.) / 2.
- x = x.permute(1, 2, 0).numpy()
- x = (255 * x).astype(np.uint8)
- x = Image.fromarray(x)
- if not x.mode == "RGB":
- x = x.convert("RGB")
- return x
-
-
-def custom_to_np(x):
- # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
- sample = x.detach().cpu()
- sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
- sample = sample.permute(0, 2, 3, 1)
- sample = sample.contiguous()
- return sample
-
-
-def logs2pil(logs, keys=["sample"]):
- imgs = dict()
- for k in logs:
- try:
- if len(logs[k].shape) == 4:
- img = custom_to_pil(logs[k][0, ...])
- elif len(logs[k].shape) == 3:
- img = custom_to_pil(logs[k])
- else:
- print(f"Unknown format for key {k}. ")
- img = None
- except:
- img = None
- imgs[k] = img
- return imgs
-
-
-@torch.no_grad()
-def convsample(model, shape, return_intermediates=True,
- verbose=True,
- make_prog_row=False):
-
-
- if not make_prog_row:
- return model.p_sample_loop(None, shape,
- return_intermediates=return_intermediates, verbose=verbose)
- else:
- return model.progressive_denoising(
- None, shape, verbose=True
- )
-
-
-@torch.no_grad()
-def convsample_ddim(model, steps, shape, eta=1.0
- ):
- ddim = DDIMSampler(model)
- bs = shape[0]
- shape = shape[1:]
- samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
- return samples, intermediates
-
-
-@torch.no_grad()
-def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
-
-
- log = dict()
-
- shape = [batch_size,
- model.model.diffusion_model.in_channels,
- model.model.diffusion_model.image_size,
- model.model.diffusion_model.image_size]
-
- with model.ema_scope("Plotting"):
- t0 = time.time()
- if vanilla:
- sample, progrow = convsample(model, shape,
- make_prog_row=True)
- else:
- sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
- eta=eta)
-
- t1 = time.time()
-
- x_sample = model.decode_first_stage(sample)
-
- log["sample"] = x_sample
- log["time"] = t1 - t0
- log['throughput'] = sample.shape[0] / (t1 - t0)
- print(f'Throughput for this batch: {log["throughput"]}')
- return log
-
-def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
- if vanilla:
- print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
- else:
- print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
-
-
- tstart = time.time()
- n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
- # path = logdir
- if model.cond_stage_model is None:
- all_images = []
-
- print(f"Running unconditional sampling for {n_samples} samples")
- for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
- logs = make_convolutional_sample(model, batch_size=batch_size,
- vanilla=vanilla, custom_steps=custom_steps,
- eta=eta)
- n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
- all_images.extend([custom_to_np(logs["sample"])])
- if n_saved >= n_samples:
- print(f'Finish after generating {n_saved} samples')
- break
- all_img = np.concatenate(all_images, axis=0)
- all_img = all_img[:n_samples]
- shape_str = "x".join([str(x) for x in all_img.shape])
- nppath = os.path.join(nplog, f"{shape_str}-samples.npz")
- np.savez(nppath, all_img)
-
- else:
- raise NotImplementedError('Currently only sampling for unconditional models supported.')
-
- print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
-
-
-def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
- for k in logs:
- if k == key:
- batch = logs[key]
- if np_path is None:
- for x in batch:
- img = custom_to_pil(x)
- imgpath = os.path.join(path, f"{key}_{n_saved:06}.png")
- img.save(imgpath)
- n_saved += 1
- else:
- npbatch = custom_to_np(batch)
- shape_str = "x".join([str(x) for x in npbatch.shape])
- nppath = os.path.join(np_path, f"{n_saved}-{shape_str}-samples.npz")
- np.savez(nppath, npbatch)
- n_saved += npbatch.shape[0]
- return n_saved
-
-
-def get_parser():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "-r",
- "--resume",
- type=str,
- nargs="?",
- help="load from logdir or checkpoint in logdir",
- )
- parser.add_argument(
- "-n",
- "--n_samples",
- type=int,
- nargs="?",
- help="number of samples to draw",
- default=50000
- )
- parser.add_argument(
- "-e",
- "--eta",
- type=float,
- nargs="?",
- help="eta for ddim sampling (0.0 yields deterministic sampling)",
- default=1.0
- )
- parser.add_argument(
- "-v",
- "--vanilla_sample",
- default=False,
- action='store_true',
- help="vanilla sampling (default option is DDIM sampling)?",
- )
- parser.add_argument(
- "-l",
- "--logdir",
- type=str,
- nargs="?",
- help="extra logdir",
- default="none"
- )
- parser.add_argument(
- "-c",
- "--custom_steps",
- type=int,
- nargs="?",
- help="number of steps for ddim and fastdpm sampling",
- default=50
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- nargs="?",
- help="the bs",
- default=10
- )
- return parser
-
-
-def load_model_from_config(config, sd):
- model = instantiate_from_config(config)
- model.load_state_dict(sd,strict=False)
- model.cuda()
- model.eval()
- return model
-
-
-def load_model(config, ckpt, gpu, eval_mode):
- if ckpt:
- print(f"Loading model from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- global_step = pl_sd["global_step"]
- else:
- pl_sd = {"state_dict": None}
- global_step = None
- model = load_model_from_config(config.model,
- pl_sd["state_dict"])
-
- return model, global_step
-
-
-if __name__ == "__main__":
- now = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
- sys.path.append(os.getcwd())
- command = " ".join(sys.argv)
-
- parser = get_parser()
- opt, unknown = parser.parse_known_args()
- ckpt = None
-
- if not os.path.exists(opt.resume):
- raise ValueError("Cannot find {}".format(opt.resume))
- if os.path.isfile(opt.resume):
- # paths = opt.resume.split("/")
- try:
- logdir = '/'.join(opt.resume.split('/')[:-1])
- # idx = len(paths)-paths[::-1].index("logs")+1
- print(f'Logdir is {logdir}')
- except ValueError:
- paths = opt.resume.split("/")
- idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
- logdir = "/".join(paths[:idx])
- ckpt = opt.resume
- else:
- assert os.path.isdir(opt.resume), f"{opt.resume} is not a directory"
- logdir = opt.resume.rstrip("/")
- ckpt = os.path.join(logdir, "model.ckpt")
-
- base_configs = sorted(glob.glob(os.path.join(logdir, "config.yaml")))
- opt.base = base_configs
-
- configs = [OmegaConf.load(cfg) for cfg in opt.base]
- cli = OmegaConf.from_dotlist(unknown)
- config = OmegaConf.merge(*configs, cli)
-
- gpu = True
- eval_mode = True
-
- if opt.logdir != "none":
- locallog = logdir.split(os.sep)[-1]
- if locallog == "": locallog = logdir.split(os.sep)[-2]
- print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
- logdir = os.path.join(opt.logdir, locallog)
-
- print(config)
-
- model, global_step = load_model(config, ckpt, gpu, eval_mode)
- print(f"global step: {global_step}")
- print(75 * "=")
- print("logging to:")
- logdir = os.path.join(logdir, "samples", f"{global_step:08}", now)
- imglogdir = os.path.join(logdir, "img")
- numpylogdir = os.path.join(logdir, "numpy")
-
- os.makedirs(imglogdir)
- os.makedirs(numpylogdir)
- print(logdir)
- print(75 * "=")
-
- # write config out
- sampling_file = os.path.join(logdir, "sampling_config.yaml")
- sampling_conf = vars(opt)
-
- with open(sampling_file, 'w') as f:
- yaml.dump(sampling_conf, f, default_flow_style=False)
- print(sampling_conf)
-
-
- run(model, imglogdir, eta=opt.eta,
- vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
- batch_size=opt.batch_size, nplog=numpylogdir)
-
- print("done.")
diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py b/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py
deleted file mode 100644
index a32e66d44cf2..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/tests/test_checkpoint.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import os
-import sys
-from copy import deepcopy
-
-import yaml
-from datetime import datetime
-
-from diffusers import StableDiffusionPipeline
-import torch
-from ldm.util import instantiate_from_config
-from main import get_parser
-
-if __name__ == "__main__":
- with torch.no_grad():
- yaml_path = "../../train_colossalai.yaml"
- with open(yaml_path, 'r', encoding='utf-8') as f:
- config = f.read()
- base_config = yaml.load(config, Loader=yaml.FullLoader)
- unet_config = base_config['model']['params']['unet_config']
- diffusion_model = instantiate_from_config(unet_config).to("cuda:0")
-
- pipe = StableDiffusionPipeline.from_pretrained(
- "/data/scratch/diffuser/stable-diffusion-v1-4"
- ).to("cuda:0")
- dif_model_2 = pipe.unet
-
- random_input_ = torch.rand((4, 4, 32, 32)).to("cuda:0")
- random_input_2 = torch.clone(random_input_).to("cuda:0")
- time_stamp = torch.randint(20, (4,)).to("cuda:0")
- time_stamp2 = torch.clone(time_stamp).to("cuda:0")
- context_ = torch.rand((4, 77, 768)).to("cuda:0")
- context_2 = torch.clone(context_).to("cuda:0")
-
- out_1 = diffusion_model(random_input_, time_stamp, context_)
- out_2 = dif_model_2(random_input_2, time_stamp2, context_2)
- print(out_1.shape)
- print(out_2['sample'].shape)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py b/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py
deleted file mode 100644
index f93f8a6e7076..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/tests/test_watermark.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import cv2
-import fire
-from imwatermark import WatermarkDecoder
-
-
-def testit(img_path):
- bgr = cv2.imread(img_path)
- decoder = WatermarkDecoder('bytes', 136)
- watermark = decoder.decode(bgr, 'dwtDct')
- try:
- dec = watermark.decode('utf-8')
- except:
- dec = "null"
- print(dec)
-
-
-if __name__ == "__main__":
- fire.Fire(testit)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/scripts/train_searcher.py b/examples/tutorial/stable_diffusion/scripts/train_searcher.py
deleted file mode 100644
index 1e7904889c01..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/train_searcher.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import os, sys
-import numpy as np
-import scann
-import argparse
-import glob
-from multiprocessing import cpu_count
-from tqdm import tqdm
-
-from ldm.util import parallel_data_prefetch
-
-
-def search_bruteforce(searcher):
- return searcher.score_brute_force().build()
-
-
-def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
- partioning_trainsize, num_leaves, num_leaves_to_search):
- return searcher.tree(num_leaves=num_leaves,
- num_leaves_to_search=num_leaves_to_search,
- training_sample_size=partioning_trainsize). \
- score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
-
-
-def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
- return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
- reorder_k).build()
-
-def load_datapool(dpath):
-
-
- def load_single_file(saved_embeddings):
- compressed = np.load(saved_embeddings)
- database = {key: compressed[key] for key in compressed.files}
- return database
-
- def load_multi_files(data_archive):
- database = {key: [] for key in data_archive[0].files}
- for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
- for key in d.files:
- database[key].append(d[key])
-
- return database
-
- print(f'Load saved patch embedding from "{dpath}"')
- file_content = glob.glob(os.path.join(dpath, '*.npz'))
-
- if len(file_content) == 1:
- data_pool = load_single_file(file_content[0])
- elif len(file_content) > 1:
- data = [np.load(f) for f in file_content]
- prefetched_data = parallel_data_prefetch(load_multi_files, data,
- n_proc=min(len(data), cpu_count()), target_data_type='dict')
-
- data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
- else:
- raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')
-
- print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
- return data_pool
-
-
-def train_searcher(opt,
- metric='dot_product',
- partioning_trainsize=None,
- reorder_k=None,
- # todo tune
- aiq_thld=0.2,
- dims_per_block=2,
- num_leaves=None,
- num_leaves_to_search=None,):
-
- data_pool = load_datapool(opt.database)
- k = opt.knn
-
- if not reorder_k:
- reorder_k = 2 * k
-
- # normalize
- # embeddings =
- searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
- pool_size = data_pool['embedding'].shape[0]
-
- print(*(['#'] * 100))
- print('Initializing scaNN searcher with the following values:')
- print(f'k: {k}')
- print(f'metric: {metric}')
- print(f'reorder_k: {reorder_k}')
- print(f'anisotropic_quantization_threshold: {aiq_thld}')
- print(f'dims_per_block: {dims_per_block}')
- print(*(['#'] * 100))
- print('Start training searcher....')
- print(f'N samples in pool is {pool_size}')
-
- # this reflects the recommended design choices proposed at
- # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
- if pool_size < 2e4:
- print('Using brute force search.')
- searcher = search_bruteforce(searcher)
- elif 2e4 <= pool_size and pool_size < 1e5:
- print('Using asymmetric hashing search and reordering.')
- searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
- else:
- print('Using using partioning, asymmetric hashing search and reordering.')
-
- if not partioning_trainsize:
- partioning_trainsize = data_pool['embedding'].shape[0] // 10
- if not num_leaves:
- num_leaves = int(np.sqrt(pool_size))
-
- if not num_leaves_to_search:
- num_leaves_to_search = max(num_leaves // 20, 1)
-
- print('Partitioning params:')
- print(f'num_leaves: {num_leaves}')
- print(f'num_leaves_to_search: {num_leaves_to_search}')
- # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
- searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
- partioning_trainsize, num_leaves, num_leaves_to_search)
-
- print('Finish training searcher')
- searcher_savedir = opt.target_path
- os.makedirs(searcher_savedir, exist_ok=True)
- searcher.serialize(searcher_savedir)
- print(f'Saved trained searcher under "{searcher_savedir}"')
-
-if __name__ == '__main__':
- sys.path.append(os.getcwd())
- parser = argparse.ArgumentParser()
- parser.add_argument('--database',
- '-d',
- default='data/rdm/retrieval_databases/openimages',
- type=str,
- help='path to folder containing the clip feature of the database')
- parser.add_argument('--target_path',
- '-t',
- default='data/rdm/searchers/openimages',
- type=str,
- help='path to the target folder where the searcher shall be stored.')
- parser.add_argument('--knn',
- '-k',
- default=20,
- type=int,
- help='number of nearest neighbors, for which the searcher shall be optimized')
-
- opt, _ = parser.parse_known_args()
-
- train_searcher(opt,)
\ No newline at end of file
diff --git a/examples/tutorial/stable_diffusion/scripts/txt2img.py b/examples/tutorial/stable_diffusion/scripts/txt2img.py
deleted file mode 100644
index 59c16a1db871..000000000000
--- a/examples/tutorial/stable_diffusion/scripts/txt2img.py
+++ /dev/null
@@ -1,344 +0,0 @@
-import argparse, os, sys, glob
-import cv2
-import torch
-import numpy as np
-from omegaconf import OmegaConf
-from PIL import Image
-from tqdm import tqdm, trange
-from imwatermark import WatermarkEncoder
-from itertools import islice
-from einops import rearrange
-from torchvision.utils import make_grid
-import time
-from pytorch_lightning import seed_everything
-from torch import autocast
-from contextlib import contextmanager, nullcontext
-
-from ldm.util import instantiate_from_config
-from ldm.models.diffusion.ddim import DDIMSampler
-from ldm.models.diffusion.plms import PLMSSampler
-
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from transformers import AutoFeatureExtractor
-
-
-# load safety model
-safety_model_id = "CompVis/stable-diffusion-safety-checker"
-safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
-
-def chunk(it, size):
- it = iter(it)
- return iter(lambda: tuple(islice(it, size)), ())
-
-
-def numpy_to_pil(images):
- """
- Convert a numpy image or a batch of images to a PIL image.
- """
- if images.ndim == 3:
- images = images[None, ...]
- images = (images * 255).round().astype("uint8")
- pil_images = [Image.fromarray(image) for image in images]
-
- return pil_images
-
-
-def load_model_from_config(config, ckpt, verbose=False):
- print(f"Loading model from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- print("missing keys:")
- print(m)
- if len(u) > 0 and verbose:
- print("unexpected keys:")
- print(u)
-
- model.cuda()
- model.eval()
- return model
-
-
-def put_watermark(img, wm_encoder=None):
- if wm_encoder is not None:
- img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
- img = wm_encoder.encode(img, 'dwtDct')
- img = Image.fromarray(img[:, :, ::-1])
- return img
-
-
-def load_replacement(x):
- try:
- hwc = x.shape
- y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
- y = (np.array(y)/255.0).astype(x.dtype)
- assert y.shape == x.shape
- return y
- except Exception:
- return x
-
-
-def check_safety(x_image):
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
- assert x_checked_image.shape[0] == len(has_nsfw_concept)
- for i in range(len(has_nsfw_concept)):
- if has_nsfw_concept[i]:
- x_checked_image[i] = load_replacement(x_checked_image[i])
- return x_checked_image, has_nsfw_concept
-
-
-def main():
- parser = argparse.ArgumentParser()
-
- parser.add_argument(
- "--prompt",
- type=str,
- nargs="?",
- default="a painting of a virus monster playing guitar",
- help="the prompt to render"
- )
- parser.add_argument(
- "--outdir",
- type=str,
- nargs="?",
- help="dir to write results to",
- default="outputs/txt2img-samples"
- )
- parser.add_argument(
- "--skip_grid",
- action='store_true',
- help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
- )
- parser.add_argument(
- "--skip_save",
- action='store_true',
- help="do not save individual samples. For speed measurements.",
- )
- parser.add_argument(
- "--ddim_steps",
- type=int,
- default=50,
- help="number of ddim sampling steps",
- )
- parser.add_argument(
- "--plms",
- action='store_true',
- help="use plms sampling",
- )
- parser.add_argument(
- "--laion400m",
- action='store_true',
- help="uses the LAION400M model",
- )
- parser.add_argument(
- "--fixed_code",
- action='store_true',
- help="if enabled, uses the same starting code across samples ",
- )
- parser.add_argument(
- "--ddim_eta",
- type=float,
- default=0.0,
- help="ddim eta (eta=0.0 corresponds to deterministic sampling",
- )
- parser.add_argument(
- "--n_iter",
- type=int,
- default=2,
- help="sample this often",
- )
- parser.add_argument(
- "--H",
- type=int,
- default=512,
- help="image height, in pixel space",
- )
- parser.add_argument(
- "--W",
- type=int,
- default=512,
- help="image width, in pixel space",
- )
- parser.add_argument(
- "--C",
- type=int,
- default=4,
- help="latent channels",
- )
- parser.add_argument(
- "--f",
- type=int,
- default=8,
- help="downsampling factor",
- )
- parser.add_argument(
- "--n_samples",
- type=int,
- default=3,
- help="how many samples to produce for each given prompt. A.k.a. batch size",
- )
- parser.add_argument(
- "--n_rows",
- type=int,
- default=0,
- help="rows in the grid (default: n_samples)",
- )
- parser.add_argument(
- "--scale",
- type=float,
- default=7.5,
- help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
- )
- parser.add_argument(
- "--from-file",
- type=str,
- help="if specified, load prompts from this file",
- )
- parser.add_argument(
- "--config",
- type=str,
- default="configs/stable-diffusion/v1-inference.yaml",
- help="path to config which constructs model",
- )
- parser.add_argument(
- "--ckpt",
- type=str,
- default="models/ldm/stable-diffusion-v1/model.ckpt",
- help="path to checkpoint of model",
- )
- parser.add_argument(
- "--seed",
- type=int,
- default=42,
- help="the seed (for reproducible sampling)",
- )
- parser.add_argument(
- "--precision",
- type=str,
- help="evaluate at this precision",
- choices=["full", "autocast"],
- default="autocast"
- )
- opt = parser.parse_args()
-
- if opt.laion400m:
- print("Falling back to LAION 400M model...")
- opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
- opt.ckpt = "models/ldm/text2img-large/model.ckpt"
- opt.outdir = "outputs/txt2img-samples-laion400m"
-
- seed_everything(opt.seed)
-
- config = OmegaConf.load(f"{opt.config}")
- model = load_model_from_config(config, f"{opt.ckpt}")
-
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model = model.to(device)
-
- if opt.plms:
- sampler = PLMSSampler(model)
- else:
- sampler = DDIMSampler(model)
-
- os.makedirs(opt.outdir, exist_ok=True)
- outpath = opt.outdir
-
- print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
- wm = "StableDiffusionV1"
- wm_encoder = WatermarkEncoder()
- wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
-
- batch_size = opt.n_samples
- n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
- if not opt.from_file:
- prompt = opt.prompt
- assert prompt is not None
- data = [batch_size * [prompt]]
-
- else:
- print(f"reading prompts from {opt.from_file}")
- with open(opt.from_file, "r") as f:
- data = f.read().splitlines()
- data = list(chunk(data, batch_size))
-
- sample_path = os.path.join(outpath, "samples")
- os.makedirs(sample_path, exist_ok=True)
- base_count = len(os.listdir(sample_path))
- grid_count = len(os.listdir(outpath)) - 1
-
- start_code = None
- if opt.fixed_code:
- start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
-
- precision_scope = autocast if opt.precision=="autocast" else nullcontext
- with torch.no_grad():
- with precision_scope("cuda"):
- with model.ema_scope():
- tic = time.time()
- all_samples = list()
- for n in trange(opt.n_iter, desc="Sampling"):
- for prompts in tqdm(data, desc="data"):
- uc = None
- if opt.scale != 1.0:
- uc = model.get_learned_conditioning(batch_size * [""])
- if isinstance(prompts, tuple):
- prompts = list(prompts)
- c = model.get_learned_conditioning(prompts)
- shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
- samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
- conditioning=c,
- batch_size=opt.n_samples,
- shape=shape,
- verbose=False,
- unconditional_guidance_scale=opt.scale,
- unconditional_conditioning=uc,
- eta=opt.ddim_eta,
- x_T=start_code)
-
- x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
-
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
-
- x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
-
- if not opt.skip_save:
- for x_sample in x_checked_image_torch:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
- img = Image.fromarray(x_sample.astype(np.uint8))
- img = put_watermark(img, wm_encoder)
- img.save(os.path.join(sample_path, f"{base_count:05}.png"))
- base_count += 1
-
- if not opt.skip_grid:
- all_samples.append(x_checked_image_torch)
-
- if not opt.skip_grid:
- # additionally, save as grid
- grid = torch.stack(all_samples, 0)
- grid = rearrange(grid, 'n b c h w -> (n b) c h w')
- grid = make_grid(grid, nrow=n_rows)
-
- # to image
- grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
- img = Image.fromarray(grid.astype(np.uint8))
- img = put_watermark(img, wm_encoder)
- img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
- grid_count += 1
-
- toc = time.time()
-
- print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
- f" \nEnjoy.")
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/tutorial/stable_diffusion/train.sh b/examples/tutorial/stable_diffusion/train.sh
deleted file mode 100644
index 63abcadbf62b..000000000000
--- a/examples/tutorial/stable_diffusion/train.sh
+++ /dev/null
@@ -1,4 +0,0 @@
-HF_DATASETS_OFFLINE=1
-TRANSFORMERS_OFFLINE=1
-
-python main.py --logdir /tmp -t --postfix test -b configs/train_colossalai.yaml
diff --git a/inference b/inference
index 6dadc2a4f293..56b35f3c06ea 160000
--- a/inference
+++ b/inference
@@ -1 +1 @@
-Subproject commit 6dadc2a4f293f4314280d6250463d986536e46ea
+Subproject commit 56b35f3c06eaac11b1bee633d1e836563f74bcea
diff --git a/op_builder/README.md b/op_builder/README.md
index 057da1038555..b7ac6107300c 100644
--- a/op_builder/README.md
+++ b/op_builder/README.md
@@ -15,17 +15,18 @@ Method 2 is good because it allows the user to only build the kernel they actual
## PyTorch Extensions in Colossal-AI
-As mentioned in the section above, our aim is to make these two methods coherently supported in Colossal-AI, meaning that for a kernel should be either built in `setup.py` or during runtime.
-There are mainly two functions used to build extensions.
+The project DeepSpeed (https://github.com/microsoft/DeepSpeed) has proposed a [solution](https://github.com/microsoft/DeepSpeed/tree/master/op_builder)) to support kernel-build during either installation or runtime.
+We have adapted from DeepSpeed's solution to build extensions. The extension build requries two main functions from PyTorch:
1. `torch.utils.cpp_extension.CUDAExtension`: used to build extensions in `setup.py` during `pip install`.
2. `torch.utils.cpp_extension.load`: used to build and load extension during runtime
Please note that the extension build by `CUDAExtension` cannot be loaded by the `load` function and `load` will run its own build again (correct me if I am wrong).
-We have implemented the following conventions:
+Based on the DeepSpeed's work, we have make several modifications and improvements:
1. All pre-built kernels (those installed with `setup.py`) will be found in `colossalai._C`
2. All runtime-built kernels will be found in the default torch extension path, i.e. ~/.cache/colossalai/torch_extensions. (If we put the built kernels in the installed site-package directory, this will make pip uninstall incomplete)
+3. Once a kernel is loaded, we will cache it in the builder to avoid repeated kernel loading.
When loading the built kernel, we will first check if the pre-built one exists. If not, the runtime build will be triggered.
diff --git a/op_builder/builder.py b/op_builder/builder.py
index dc9ea8e115d8..b9f44decc119 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -1,3 +1,7 @@
+# This code has been adapted from the DeepSpeed library.
+# Copyright (c) Microsoft Corporation.
+
+# Licensed under the MIT License.
import importlib
import os
import time
@@ -5,6 +9,8 @@
from pathlib import Path
from typing import List
+from .utils import check_cuda_availability, check_system_pytorch_cuda_match, print_rank_0
+
class Builder(ABC):
"""
@@ -20,6 +26,9 @@ def __init__(self, name: str, prebuilt_import_path: str):
self.prebuilt_import_path = prebuilt_import_path
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
+ # we store the op as an attribute to avoid repeated building and loading
+ self.cached_op_module = None
+
assert prebuilt_import_path.startswith('colossalai._C'), \
f'The prebuilt_import_path should start with colossalai._C, but got {self.prebuilt_import_path}'
@@ -100,6 +109,35 @@ def import_op(self):
"""
return importlib.import_module(self.prebuilt_import_path)
+ def check_runtime_build_environment(self):
+ """
+ Check whether the system environment is ready for extension compilation.
+ """
+ try:
+ import torch
+ from torch.utils.cpp_extension import CUDA_HOME
+ TORCH_AVAILABLE = True
+ except ImportError:
+ TORCH_AVAILABLE = False
+ CUDA_HOME = None
+
+ if not TORCH_AVAILABLE:
+ raise ModuleNotFoundError(
+ "PyTorch is not found. You need to install PyTorch first in order to build CUDA extensions")
+
+ if CUDA_HOME is None:
+ raise RuntimeError(
+ "CUDA_HOME is not found. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions"
+ )
+
+ # make sure CUDA is available for compilation during
+ cuda_available = check_cuda_availability()
+ if not cuda_available:
+ raise RuntimeError("CUDA is not available on your system as torch.cuda.is_avaible() returns False.")
+
+ # make sure system CUDA and pytorch CUDA match, an error will raised inside the function if not
+ check_system_pytorch_cuda_match(CUDA_HOME)
+
def load(self, verbose=True):
"""
load the kernel during runtime. If the kernel is not built during pip install, it will build the kernel.
@@ -111,16 +149,27 @@ def load(self, verbose=True):
Args:
verbose (bool, optional): show detailed info. Defaults to True.
"""
- from torch.utils.cpp_extension import load
- start_build = time.time()
+ # if the kernel has be compiled and cached, we directly use it
+ if self.cached_op_module is not None:
+ return self.cached_op_module
try:
+ # if the kernel has been pre-built during installation
+ # we just directly import it
op_module = self.import_op()
if verbose:
- print(f"OP {self.prebuilt_import_path} already exists, skip building.")
+ print_rank_0(
+ f"[extension] OP {self.prebuilt_import_path} has been compileed ahead of time, skip building.")
except ImportError:
+ # check environment
+ self.check_runtime_build_environment()
+
+ # time the kernel compilation
+ start_build = time.time()
+
# construct the build directory
import torch
+ from torch.utils.cpp_extension import load
torch_version_major = torch.__version__.split('.')[0]
torch_version_minor = torch.__version__.split('.')[1]
torch_cuda_version = torch.version.cuda
@@ -130,9 +179,7 @@ def load(self, verbose=True):
Path(build_directory).mkdir(parents=True, exist_ok=True)
if verbose:
- print("=========================================================================================")
- print(f"No pre-built kernel is found, build and load the {self.name} kernel during runtime now")
- print("=========================================================================================")
+ print_rank_0(f"[extension] Compiling or loading the JIT-built {self.name} kernel during runtime now")
# load the kernel
op_module = load(name=self.name,
@@ -144,9 +191,14 @@ def load(self, verbose=True):
build_directory=build_directory,
verbose=verbose)
- build_duration = time.time() - start_build
- if verbose:
- print(f"Time to load {self.name} op: {build_duration} seconds")
+ build_duration = time.time() - start_build
+
+ # log jit compilation time
+ if verbose:
+ print_rank_0(f"[extension] Time to compile or load {self.name} op: {build_duration} seconds")
+
+ # cache the built/loaded kernel
+ self.cached_op_module = op_module
return op_module
diff --git a/op_builder/utils.py b/op_builder/utils.py
index b6bada99efe5..4029703e4829 100644
--- a/op_builder/utils.py
+++ b/op_builder/utils.py
@@ -1,29 +1,203 @@
+import os
import re
import subprocess
+import warnings
from typing import List
-def get_cuda_bare_metal_version(cuda_dir):
- raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
- output = raw_output.split()
- release_idx = output.index("release") + 1
- release = output[release_idx].split(".")
- bare_metal_major = release[0]
- bare_metal_minor = release[1][0]
+def print_rank_0(message: str) -> None:
+ """
+ Print on only one process to avoid spamming.
+ """
+ try:
+ import torch.distributed as dist
+ if not dist.is_initialized():
+ is_main_rank = True
+ else:
+ is_main_rank = dist.get_rank() == 0
+ except ImportError:
+ is_main_rank = True
+
+ if is_main_rank:
+ print(message)
+
+
+def get_cuda_version_in_pytorch() -> List[int]:
+ """
+ This function returns the CUDA version in the PyTorch build.
+
+ Returns:
+ The CUDA version required by PyTorch, in the form of tuple (major, minor).
+ """
+ import torch
+
+ try:
+ torch_cuda_major = torch.version.cuda.split(".")[0]
+ torch_cuda_minor = torch.version.cuda.split(".")[1]
+ except:
+ raise ValueError(
+ "[extension] Cannot retrive the CUDA version in the PyTorch binary given by torch.version.cuda")
+ return torch_cuda_major, torch_cuda_minor
+
+
+def get_cuda_bare_metal_version(cuda_dir) -> List[int]:
+ """
+ Get the System CUDA version from nvcc.
+
+ Args:
+ cuda_dir (str): the directory for CUDA Toolkit.
+
+ Returns:
+ The CUDA version required by PyTorch, in the form of tuple (major, minor).
+ """
+ nvcc_path = os.path.join(cuda_dir, 'bin/nvcc')
+
+ if cuda_dir is None:
+ raise ValueError(
+ f"[extension] The argument cuda_dir is None, but expected to be a string. Please make sure your have exported the environment variable CUDA_HOME correctly."
+ )
+
+ # check for nvcc path
+ if not os.path.exists(nvcc_path):
+ raise FileNotFoundError(
+ f"[extension] The nvcc compiler is not found in {nvcc_path}, please make sure you have set the correct value for CUDA_HOME."
+ )
+
+ # parse the nvcc -v output to obtain the system cuda version
+ try:
+ raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
+ output = raw_output.split()
+ release_idx = output.index("release") + 1
+ release = output[release_idx].split(".")
+ bare_metal_major = release[0]
+ bare_metal_minor = release[1][0]
+ except:
+ raise ValueError(
+ f"[extension] Failed to parse the nvcc output to obtain the system CUDA bare metal version. The output for 'nvcc -v' is \n{raw_output}"
+ )
+
+ return bare_metal_major, bare_metal_minor
+
+
+def check_system_pytorch_cuda_match(cuda_dir):
+ bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+ torch_cuda_major, torch_cuda_minor = get_cuda_version_in_pytorch()
+
+ if bare_metal_major != torch_cuda_major:
+ raise Exception(
+ f'[extension] Failed to build PyTorch extension because the detected CUDA version ({bare_metal_major}.{bare_metal_minor}) '
+ f'mismatches the version that was used to compile PyTorch ({torch_cuda_major}.{torch_cuda_minor}).'
+ 'Please make sure you have set the CUDA_HOME correctly and installed the correct PyTorch in https://pytorch.org/get-started/locally/ .'
+ )
+
+ print(bare_metal_minor != torch_cuda_minor)
+ if bare_metal_minor != torch_cuda_minor:
+ warnings.warn(
+ f"[extension] The CUDA version on the system ({bare_metal_major}.{bare_metal_minor}) does not match with the version ({torch_cuda_major}.{torch_cuda_minor}) torch was compiled with. "
+ "The mismatch is found in the minor version. As the APIs are compatible, we will allow compilation to proceed. "
+ "If you encounter any issue when using the built kernel, please try to build it again with fully matched CUDA versions"
+ )
+ return True
+
+
+def get_pytorch_version() -> List[int]:
+ """
+ This functions finds the PyTorch version.
+
+ Returns:
+ A tuple of integers in the form of (major, minor, patch).
+ """
+ import torch
+ torch_version = torch.__version__.split('+')[0]
+ TORCH_MAJOR = int(torch_version.split('.')[0])
+ TORCH_MINOR = int(torch_version.split('.')[1])
+ TORCH_PATCH = int(torch_version.split('.')[2])
+ return TORCH_MAJOR, TORCH_MINOR, TORCH_PATCH
- return raw_output, bare_metal_major, bare_metal_minor
-def get_cuda_cc_flag() -> List:
- """get_cuda_cc_flag
+def check_pytorch_version(min_major_version, min_minor_version) -> bool:
+ """
+ Compare the current PyTorch version with the minium required version.
+
+ Args:
+ min_major_version (int): the minimum major version of PyTorch required
+ min_minor_version (int): the minimum minor version of PyTorch required
+
+ Returns:
+ A boolean value. The value is True if the current pytorch version is acceptable and False otherwise.
+ """
+ # get pytorch version
+ torch_major, torch_minor, _ = get_pytorch_version()
+
+ # if the
+ if torch_major < min_major_version or (torch_major == min_major_version and torch_minor < min_minor_version):
+ raise RuntimeError(
+ f"[extension] Colossal-AI requires Pytorch {min_major_version}.{min_minor_version} or newer.\n"
+ "The latest stable release can be obtained from https://pytorch.org/get-started/locally/")
+
+
+def check_cuda_availability():
+ """
+ Check if CUDA is available on the system.
+
+ Returns:
+ A boolean value. True if CUDA is available and False otherwise.
+ """
+ import torch
+ return torch.cuda.is_available()
+
+
+def set_cuda_arch_list(cuda_dir):
+ """
+ This function sets the PyTorch TORCH_CUDA_ARCH_LIST variable for ahead-of-time extension compilation.
+ Ahead-of-time compilation occurs when CUDA_EXT=1 is set when running 'pip install'.
+ """
+ cuda_available = check_cuda_availability()
- cc flag for your GPU arch
+ # we only need to set this when CUDA is not available for cross-compilation
+ if not cuda_available:
+ warnings.warn(
+ '\n[extension] PyTorch did not find available GPUs on this system.\n'
+ 'If your intention is to cross-compile, this is not an error.\n'
+ 'By default, Colossal-AI will cross-compile for \n'
+ '1. Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
+ '2. Volta (compute capability 7.0)\n'
+ '3. Turing (compute capability 7.5),\n'
+ '4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n'
+ '\nIf you wish to cross-compile for a single specific architecture,\n'
+ 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
+
+ if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
+ bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
+
+ arch_list = ['6.0', '6.1', '6.2', '7.0', '7.5']
+
+ if int(bare_metal_major) == 11:
+ if int(bare_metal_minor) == 0:
+ arch_list.append('8.0')
+ else:
+ arch_list.append('8.0')
+ arch_list.append('8.6')
+
+ arch_list_str = ';'.join(arch_list)
+ os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
+ return False
+ return True
+
+
+def get_cuda_cc_flag() -> List[str]:
+ """
+ This function produces the cc flags for your GPU arch
+
+ Returns:
+ The CUDA cc flags for compilation.
"""
# only import torch when needed
# this is to avoid importing torch when building on a machine without torch pre-installed
# one case is to build wheel for pypi release
import torch
-
+
cc_flag = []
for arch in torch.cuda.get_arch_list():
res = re.search(r'sm_(\d+)', arch)
@@ -31,12 +205,19 @@ def get_cuda_cc_flag() -> List:
arch_cap = res[1]
if int(arch_cap) >= 60:
cc_flag.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
-
return cc_flag
-def append_nvcc_threads(nvcc_extra_args):
+
+def append_nvcc_threads(nvcc_extra_args: List[str]) -> List[str]:
+ """
+ This function appends the threads flag to your nvcc args.
+
+ Returns:
+ The nvcc compilation flags including the threads flag.
+ """
from torch.utils.cpp_extension import CUDA_HOME
- _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
+
+ bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
return nvcc_extra_args + ["--threads", "4"]
return nvcc_extra_args
diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt
index f9e8960d2eaf..05c0e6ac5e5c 100644
--- a/requirements/requirements-test.txt
+++ b/requirements/requirements-test.txt
@@ -1,5 +1,7 @@
+diffusers
fbgemm-gpu==0.2.0
pytest
+pytest-cov
torchvision
transformers
timm
@@ -8,5 +10,5 @@ torchaudio
torchrec==0.2.0
contexttimer
einops
-triton==2.0.0.dev20221011
+triton==2.0.0.dev20221202
git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index cc99257a93e5..8e619ac24477 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -8,3 +8,4 @@ click
fabric
contexttimer
ninja
+torch
diff --git a/setup.py b/setup.py
index 38d5fa91cecd..89a7b0de461b 100644
--- a/setup.py
+++ b/setup.py
@@ -1,110 +1,92 @@
import os
-import re
+import sys
+from datetime import datetime
+from typing import List
from setuptools import find_packages, setup
-from op_builder.utils import get_cuda_bare_metal_version
+from op_builder.utils import (
+ check_cuda_availability,
+ check_pytorch_version,
+ check_system_pytorch_cuda_match,
+ get_cuda_bare_metal_version,
+ get_pytorch_version,
+ set_cuda_arch_list,
+)
try:
import torch
- from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
- print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
- TORCH_MAJOR = int(torch.__version__.split('.')[0])
- TORCH_MINOR = int(torch.__version__.split('.')[1])
-
- if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 10):
- raise RuntimeError("Colossal-AI requires Pytorch 1.10 or newer.\n"
- "The latest stable release can be obtained from https://pytorch.org/")
+ from torch.utils.cpp_extension import CUDA_HOME, BuildExtension
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
CUDA_HOME = None
+# Some constants for installation checks
+MIN_PYTORCH_VERSION_MAJOR = 1
+MIN_PYTORCH_VERSION_MINOR = 10
+THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+BUILD_CUDA_EXT = int(os.environ.get('CUDA_EXT', '0')) == 1
+IS_NIGHTLY = int(os.environ.get('NIGHTLY', '0')) == 1
-# ninja build does not work unless include_dirs are abs path
-this_dir = os.path.dirname(os.path.abspath(__file__))
-build_cuda_ext = False
+# a variable to store the op builder
ext_modules = []
-if int(os.environ.get('CUDA_EXT', '0')) == 1:
+# we do not support windows currently
+if sys.platform == 'win32':
+ raise RuntimeError("Windows is not supported yet. Please try again within the Windows Subsystem for Linux (WSL).")
+
+
+# check for CUDA extension dependencies
+def environment_check_for_cuda_extension_build():
if not TORCH_AVAILABLE:
- raise ModuleNotFoundError("PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions")
+ raise ModuleNotFoundError(
+ "[extension] PyTorch is not found while CUDA_EXT=1. You need to install PyTorch first in order to build CUDA extensions"
+ )
if not CUDA_HOME:
- raise RuntimeError("CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions")
-
- build_cuda_ext = True
-
-
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
- raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
- torch_binary_major = torch.version.cuda.split(".")[0]
- torch_binary_minor = torch.version.cuda.split(".")[1]
-
- print("\nCompiling cuda extensions with")
- print(raw_output + "from " + cuda_dir + "/bin\n")
-
- if bare_metal_major != torch_binary_major:
- print(f'The detected CUDA version ({raw_output}) mismatches the version that was used to compile PyTorch '
- f'({torch.version.cuda}). CUDA extension will not be installed.')
- return False
-
- if bare_metal_minor != torch_binary_minor:
- print("\nWarning: Cuda extensions are being compiled with a version of Cuda that does "
- "not match the version used to compile Pytorch binaries. "
- f"Pytorch binaries were compiled with Cuda {torch.version.cuda}.\n"
- "In some cases, a minor-version mismatch will not cause later errors: "
- "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. ")
- return True
-
-
-def check_cuda_availability(cuda_dir):
- if not torch.cuda.is_available():
- # https://github.com/NVIDIA/apex/issues/486
- # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query
- # torch.cuda.get_device_capability(), which will fail if you are compiling in an environment
- # without visible GPUs (e.g. during an nvidia-docker build command).
- print(
- '\nWarning: Torch did not find available GPUs on this system.\n',
- 'If your intention is to cross-compile, this is not an error.\n'
- 'By default, Colossal-AI will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n'
- 'Volta (compute capability 7.0), Turing (compute capability 7.5),\n'
- 'and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n'
- 'If you wish to cross-compile for a single specific architecture,\n'
- 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n')
- if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
- _, bare_metal_major, _ = get_cuda_bare_metal_version(cuda_dir)
- if int(bare_metal_major) == 11:
- os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
- else:
- os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
- return False
-
- if cuda_dir is None:
- print("nvcc was not found. CUDA extension will not be installed. If you're installing within a container from "
- "https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
- return False
- return True
-
-
-def append_nvcc_threads(nvcc_extra_args):
- _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
- if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
- return nvcc_extra_args + ["--threads", "4"]
- return nvcc_extra_args
-
-
-def fetch_requirements(path):
+ raise RuntimeError(
+ "[extension] CUDA_HOME is not found while CUDA_EXT=1. You need to export CUDA_HOME environment vairable or install CUDA Toolkit first in order to build CUDA extensions"
+ )
+
+ check_system_pytorch_cuda_match(CUDA_HOME)
+ check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)
+ check_cuda_availability()
+
+
+def fetch_requirements(path) -> List[str]:
+ """
+ This function reads the requirements file.
+
+ Args:
+ path (str): the path to the requirements file.
+
+ Returns:
+ The lines in the requirements file.
+ """
with open(path, 'r') as fd:
return [r.strip() for r in fd.readlines()]
-def fetch_readme():
+def fetch_readme() -> str:
+ """
+ This function reads the README.md file in the current directory.
+
+ Returns:
+ The lines in the README file.
+ """
with open('README.md', encoding='utf-8') as f:
return f.read()
-def get_version():
+def get_version() -> str:
+ """
+ This function reads the version.txt and generates the colossalai/version.py file.
+
+ Returns:
+ The library version stored in version.txt.
+ """
+
setup_file_path = os.path.abspath(__file__)
project_path = os.path.dirname(setup_file_path)
version_txt_path = os.path.join(project_path, 'version.txt')
@@ -112,36 +94,65 @@ def get_version():
with open(version_txt_path) as f:
version = f.read().strip()
- if build_cuda_ext:
- torch_version = '.'.join(torch.__version__.split('.')[:2])
- cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME)[1:])
- version += f'+torch{torch_version}cu{cuda_version}'
# write version into version.py
with open(version_py_path, 'w') as f:
f.write(f"__version__ = '{version}'\n")
- return version
+ # look for pytorch and cuda version
+ if BUILD_CUDA_EXT:
+ torch_major, torch_minor, _ = get_pytorch_version()
+ torch_version = f'{torch_major}.{torch_minor}'
+ cuda_version = '.'.join(get_cuda_bare_metal_version(CUDA_HOME))
+ else:
+ torch_version = None
+ cuda_version = None
+
+ # write the version into the python file
+ if torch_version:
+ f.write(f'torch = "{torch_version}"\n')
+ else:
+ f.write('torch = None\n')
+
+ if cuda_version:
+ f.write(f'cuda = "{cuda_version}"\n')
+ else:
+ f.write('cuda = None\n')
+ return version
-if build_cuda_ext:
- build_cuda_ext = check_cuda_availability(CUDA_HOME) and check_cuda_torch_binary_vs_bare_metal(CUDA_HOME)
-if build_cuda_ext:
- # Set up macros for forward/backward compatibility hack around
- # https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
- # and
- # https://github.com/NVIDIA/apex/issues/456
- # https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac
+if BUILD_CUDA_EXT:
+ environment_check_for_cuda_extension_build()
+ set_cuda_arch_list(CUDA_HOME)
from op_builder import ALL_OPS
+ op_names = []
+
+ # load all builders
for name, builder_cls in ALL_OPS.items():
- print(f'===== Building Extension {name} =====')
+ op_names.append(name)
ext_modules.append(builder_cls().builder())
-setup(name='colossalai',
- version=get_version(),
+ # show log
+ op_name_list = ', '.join(op_names)
+ print(f"[extension] loaded builders for {op_name_list}")
+
+# always put not nightly branch as the if branch
+# otherwise github will treat colossalai-nightly as the project name
+# and it will mess up with the dependency graph insights
+if not IS_NIGHTLY:
+ version = get_version()
+ package_name = 'colossalai'
+else:
+ # use date as the nightly version
+ version = datetime.today().strftime('%Y.%m.%d')
+ package_name = 'colossalai-nightly'
+
+setup(name=package_name,
+ version=version,
packages=find_packages(exclude=(
+ 'op_builder',
'benchmark',
'docker',
'tests',
@@ -179,4 +190,9 @@ def get_version():
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: System :: Distributed Computing',
],
- package_data={'colossalai': ['_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*', 'kernel/cuda_native/csrc/kernels/include/*']})
+ package_data={
+ 'colossalai': [
+ '_C/*.pyi', 'kernel/cuda_native/csrc/*', 'kernel/cuda_native/csrc/kernel/*',
+ 'kernel/cuda_native/csrc/kernels/include/*'
+ ]
+ })
diff --git a/tests/kit/__init__.py b/tests/kit/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/kit/model_zoo/__init__.py b/tests/kit/model_zoo/__init__.py
new file mode 100644
index 000000000000..466a2a558829
--- /dev/null
+++ b/tests/kit/model_zoo/__init__.py
@@ -0,0 +1,4 @@
+from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
+from .registry import model_zoo
+
+__all__ = ['model_zoo']
diff --git a/tests/kit/model_zoo/diffusers/__init__.py b/tests/kit/model_zoo/diffusers/__init__.py
new file mode 100644
index 000000000000..288f626a4539
--- /dev/null
+++ b/tests/kit/model_zoo/diffusers/__init__.py
@@ -0,0 +1 @@
+from .diffusers import *
diff --git a/tests/kit/model_zoo/diffusers/diffusers.py b/tests/kit/model_zoo/diffusers/diffusers.py
new file mode 100644
index 000000000000..8aa3f4c6741f
--- /dev/null
+++ b/tests/kit/model_zoo/diffusers/diffusers.py
@@ -0,0 +1,73 @@
+from functools import partial
+
+import diffusers
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+BATCH_SIZE = 2
+SEQ_LENGTH = 5
+HEIGHT = 224
+WIDTH = 224
+IN_CHANNELS = 3
+LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
+TIME_STEP = 3
+
+data_vae_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32))
+data_unet_fn = lambda: dict(sample=torch.randn(2, 3, 32, 32), timestep=3)
+
+identity_output = lambda x: x
+
+
+def data_clip_model():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+ return dict(input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids)
+
+
+def data_clip_text():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+def data_clip_vision():
+ pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
+ return dict(pixel_values=pixel_values)
+
+
+model_zoo.register(name='diffusers_auto_encoder_kl',
+ model_fn=diffusers.AutoencoderKL,
+ data_gen_fn=data_vae_fn,
+ output_transform_fn=identity_output)
+
+model_zoo.register(name='diffusers_vq_model',
+ model_fn=diffusers.VQModel,
+ data_gen_fn=data_vae_fn,
+ output_transform_fn=identity_output)
+
+model_zoo.register(name='diffusers_clip_model',
+ model_fn=partial(transformers.CLIPModel, config=transformers.CLIPConfig()),
+ data_gen_fn=data_clip_model,
+ output_transform_fn=identity_output)
+
+model_zoo.register(name='diffusers_clip_text_model',
+ model_fn=partial(transformers.CLIPTextModel, config=transformers.CLIPTextConfig()),
+ data_gen_fn=data_clip_text,
+ output_transform_fn=identity_output)
+
+model_zoo.register(name='diffusers_clip_vision_model',
+ model_fn=partial(transformers.CLIPVisionModel, config=transformers.CLIPVisionConfig()),
+ data_gen_fn=data_clip_vision,
+ output_transform_fn=identity_output)
+
+model_zoo.register(name='diffusers_unet2d_model',
+ model_fn=diffusers.UNet2DModel,
+ data_gen_fn=data_unet_fn,
+ output_transform_fn=identity_output)
diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py
new file mode 100644
index 000000000000..7470327a65b6
--- /dev/null
+++ b/tests/kit/model_zoo/registry.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+from dataclasses import dataclass
+from typing import Callable
+
+__all__ = ['ModelZooRegistry', 'ModelAttributem', 'model_zoo']
+
+
+@dataclass
+class ModelAttribute:
+ """
+ Attributes of a model.
+
+ Args:
+ has_control_flow (bool): Whether the model contains branching in its forward method.
+ has_stochastic_depth_prob (bool): Whether the model contains stochastic depth probability. Often seen in the torchvision models.
+ """
+ has_control_flow: bool = False
+ has_stochastic_depth_prob: bool = False
+
+
+class ModelZooRegistry(dict):
+ """
+ A registry to map model names to model and data generation functions.
+ """
+
+ def register(self,
+ name: str,
+ model_fn: Callable,
+ data_gen_fn: Callable,
+ output_transform_fn: Callable,
+ model_attribute: ModelAttribute = None):
+ """
+ Register a model and data generation function.
+
+ Examples:
+ >>> # Register
+ >>> model_zoo = ModelZooRegistry()
+ >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
+ >>> # Run the model
+ >>> data = resnresnet18_data_gen() # do not input any argument
+ >>> model = resnet18() # do not input any argument
+ >>> out = model(**data)
+
+ Args:
+ name (str): Name of the model.
+ model_fn (callable): A function that returns a model. **It must not contain any arguments.**
+ output_transform_fn (callable): A function that transforms the output of the model into Dict.
+ data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
+ model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
+ """
+ self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
+
+ def get_sub_registry(self, keyword: str):
+ """
+ Get a sub registry with models that contain the keyword.
+
+ Args:
+ keyword (str): Keyword to filter models.
+ """
+ new_dict = dict()
+
+ for k, v in self.items():
+ if keyword in k:
+ new_dict[k] = v
+ return new_dict
+
+
+model_zoo = ModelZooRegistry()
diff --git a/tests/kit/model_zoo/timm/__init__.py b/tests/kit/model_zoo/timm/__init__.py
new file mode 100644
index 000000000000..c9c85319448d
--- /dev/null
+++ b/tests/kit/model_zoo/timm/__init__.py
@@ -0,0 +1 @@
+from .timm import *
diff --git a/tests/kit/model_zoo/timm/timm.py b/tests/kit/model_zoo/timm/timm.py
new file mode 100644
index 000000000000..b29ac12a6b53
--- /dev/null
+++ b/tests/kit/model_zoo/timm/timm.py
@@ -0,0 +1,159 @@
+import timm.models as tm
+import torch
+
+from ..registry import ModelAttribute, model_zoo
+
+## ==============
+# Register models without control flow
+## ==============
+data_gen_fn = lambda: dict(x=torch.rand(2, 3, 224, 224))
+output_transform_fn = lambda x: dict(output=x)
+
+model_zoo.register(name='timm_resnet',
+ model_fn=tm.resnest.resnest50d,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_beit',
+ model_fn=tm.beit.beit_base_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_cait',
+ model_fn=tm.cait.cait_s24_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_convmixer',
+ model_fn=tm.convmixer.convmixer_768_32,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_efficientnetv2',
+ model_fn=tm.efficientnet.efficientnetv2_m,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_resmlp',
+ model_fn=tm.resmlp_12_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_vision_transformer',
+ model_fn=tm.vision_transformer.vit_base_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_deit',
+ model_fn=tm.deit_base_distilled_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_beitv2',
+ model_fn=tm.beitv2_base_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_coat',
+ model_fn=tm.coat.coat_lite_mini,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='timm_deit3',
+ model_fn=tm.deit3_base_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='timm_eca_nfnet',
+ model_fn=tm.eca_nfnet_l0,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_efficientformer',
+ model_fn=tm.efficientformer_l1,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_ese_vovnet19b_dw',
+ model_fn=tm.ese_vovnet19b_dw,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_gmixer_12_224',
+ model_fn=tm.gmixer_12_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_gmlp_b16_224',
+ model_fn=tm.gmlp_b16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_hardcorenas_a',
+ model_fn=tm.hardcorenas_a,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_hrnet_w18_small',
+ model_fn=tm.hrnet_w18_small,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_inception_v3',
+ model_fn=tm.inception_v3,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_mixer_b16_224',
+ model_fn=tm.mixer_b16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_nf_ecaresnet101',
+ model_fn=tm.nf_ecaresnet101,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_nf_regnet_b0',
+ model_fn=tm.nf_regnet_b0,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_regnetv_040',
+ model_fn=tm.regnetv_040,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_skresnet18',
+ model_fn=tm.skresnet18,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_tnt_b_patch16_224',
+ model_fn=tm.tnt_b_patch16_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_wide_resnet50_2',
+ model_fn=tm.wide_resnet50_2,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_convit',
+ model_fn=tm.convit_base,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='timm_dm_nfnet',
+ model_fn=tm.dm_nfnet_f0,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+# ==============
+# Register models with control flow
+# ==============
+model_zoo.register(name='timm_convnext',
+ model_fn=tm.convnext.convnext_base,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='timm_vgg',
+ model_fn=tm.vgg.vgg11,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='timm_dpn',
+ model_fn=tm.dpn.dpn68,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='timm_densenet',
+ model_fn=tm.densenet.densenet121,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='timm_rexnet',
+ model_fn=tm.rexnet.rexnet_100,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='timm_swin_transformer',
+ model_fn=tm.swin_transformer.swin_base_patch4_window7_224,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/torchaudio/__init__.py b/tests/kit/model_zoo/torchaudio/__init__.py
new file mode 100644
index 000000000000..082eb9ebb89c
--- /dev/null
+++ b/tests/kit/model_zoo/torchaudio/__init__.py
@@ -0,0 +1 @@
+from .torchaudio import *
diff --git a/tests/kit/model_zoo/torchaudio/torchaudio.py b/tests/kit/model_zoo/torchaudio/torchaudio.py
new file mode 100644
index 000000000000..74611720292f
--- /dev/null
+++ b/tests/kit/model_zoo/torchaudio/torchaudio.py
@@ -0,0 +1,130 @@
+import torch
+import torchaudio.models as tm
+
+from ..registry import ModelAttribute, model_zoo
+
+INPUT_DIM = 80
+IN_FEATURES = 16
+N_TIME = 20
+KERNEL_SIZE = 5
+HOP_LENGTH = 20
+N_CLASSES = 10
+N_FREQ = 16
+N_MELS = 80
+
+
+def conformer_data_gen_fn():
+ lengths = torch.randint(1, 400, (4,))
+ input = torch.rand(4, int(lengths.max()), INPUT_DIM)
+ return dict(input=input, lengths=lengths)
+
+
+transformer_output_transform_fn = lambda outputs: dict(frames=outputs[0], lengths=outputs[1])
+
+model_zoo.register(name='torchaudio_conformer',
+ model_fn=lambda: tm.Conformer(
+ input_dim=INPUT_DIM, num_heads=4, ffn_dim=128, num_layers=4, depthwise_conv_kernel_size=31),
+ data_gen_fn=conformer_data_gen_fn,
+ output_transform_fn=transformer_output_transform_fn)
+
+single_output_transform_fn = lambda output: dict(output=output)
+
+model_zoo.register(name='torchaudio_convtasnet',
+ model_fn=tm.ConvTasNet,
+ data_gen_fn=lambda: dict(input=torch.rand(4, 1, 8)),
+ output_transform_fn=single_output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_deepspeech',
+ model_fn=lambda: tm.DeepSpeech(IN_FEATURES, n_hidden=128, n_class=4),
+ data_gen_fn=lambda: dict(x=torch.rand(4, 1, 10, IN_FEATURES)),
+ output_transform_fn=single_output_transform_fn)
+
+
+def emformer_data_gen_fn():
+ input = torch.rand(4, 400, IN_FEATURES)
+ lengths = torch.randint(1, 200, (4,))
+ return dict(input=input, lengths=lengths)
+
+
+model_zoo.register(
+ name='torchaudio_emformer',
+ model_fn=lambda: tm.Emformer(input_dim=IN_FEATURES, num_heads=4, ffn_dim=128, num_layers=4, segment_length=4),
+ data_gen_fn=emformer_data_gen_fn,
+ output_transform_fn=transformer_output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_wav2letter_waveform',
+ model_fn=lambda: tm.Wav2Letter(input_type='waveform', num_features=40),
+ data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+ output_transform_fn=single_output_transform_fn)
+
+model_zoo.register(name='torchaudio_wav2letter_mfcc',
+ model_fn=lambda: tm.Wav2Letter(input_type='mfcc', num_features=40),
+ data_gen_fn=lambda: dict(x=torch.rand(4, 40, 400)),
+ output_transform_fn=single_output_transform_fn)
+
+
+def wavernn_data_gen_fn():
+ waveform = torch.rand(4, 1, (N_TIME - KERNEL_SIZE + 1) * HOP_LENGTH)
+ specgram = torch.rand(4, 1, N_FREQ, N_TIME)
+ return dict(waveform=waveform, specgram=specgram)
+
+
+model_zoo.register(name='torchaudio_wavernn',
+ model_fn=lambda: tm.WaveRNN(upsample_scales=[2, 2, 5],
+ n_classes=N_CLASSES,
+ hop_length=HOP_LENGTH,
+ kernel_size=KERNEL_SIZE,
+ n_freq=N_FREQ,
+ n_res_block=2,
+ n_rnn=64,
+ n_fc=64,
+ n_hidden=16,
+ n_output=16),
+ data_gen_fn=wavernn_data_gen_fn,
+ output_transform_fn=single_output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+
+def tacotron_data_gen_fn():
+ n_batch = 4
+ max_text_length = 100
+ max_mel_specgram_length = 300
+ tokens = torch.randint(0, 148, (n_batch, max_text_length))
+ token_lengths = max_text_length * torch.ones((n_batch,))
+ mel_specgram = torch.rand(n_batch, N_MELS, max_mel_specgram_length)
+ mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
+ return dict(tokens=tokens,
+ token_lengths=token_lengths,
+ mel_specgram=mel_specgram,
+ mel_specgram_lengths=mel_specgram_lengths)
+
+
+model_zoo.register(
+ name='torchaudio_tacotron',
+ model_fn=lambda: tm.Tacotron2(n_mels=N_MELS),
+ data_gen_fn=tacotron_data_gen_fn,
+ output_transform_fn=lambda outputs: dict(
+ spectrogram_before=outputs[0], spectrogram_after=outputs[1], stop_tokens=outputs[2], attn_weights=outputs[3]),
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+
+def wav2vec_data_gen_fn():
+ batch_size, num_frames = 4, 400
+ waveforms = torch.randn(batch_size, num_frames)
+ lengths = torch.randint(0, num_frames, (batch_size,))
+ return dict(waveforms=waveforms, lengths=lengths)
+
+
+model_zoo.register(name='torchaudio_wav2vec2_base',
+ model_fn=tm.wav2vec2_base,
+ data_gen_fn=wav2vec_data_gen_fn,
+ output_transform_fn=transformer_output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='torchaudio_hubert_base',
+ model_fn=tm.hubert_base,
+ data_gen_fn=wav2vec_data_gen_fn,
+ output_transform_fn=transformer_output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/torchrec/__init__.py b/tests/kit/model_zoo/torchrec/__init__.py
new file mode 100644
index 000000000000..43952e6998cf
--- /dev/null
+++ b/tests/kit/model_zoo/torchrec/__init__.py
@@ -0,0 +1 @@
+from .torchrec import *
diff --git a/tests/kit/model_zoo/torchrec/torchrec.py b/tests/kit/model_zoo/torchrec/torchrec.py
new file mode 100644
index 000000000000..dda563155fca
--- /dev/null
+++ b/tests/kit/model_zoo/torchrec/torchrec.py
@@ -0,0 +1,142 @@
+from collections import namedtuple
+from functools import partial
+
+import torch
+from torchrec.models import deepfm, dlrm
+from torchrec.modules.embedding_configs import EmbeddingBagConfig
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+
+from ..registry import ModelAttribute, model_zoo
+
+BATCH = 2
+SHAPE = 10
+
+
+def gen_kt():
+ KT = KeyedTensor(keys=["f1", "f2"], length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
+ return KT
+
+
+# KeyedJaggedTensor
+def gen_kjt():
+ KJT = KeyedJaggedTensor.from_offsets_sync(keys=["f1", "f2"],
+ values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
+ offsets=torch.tensor([0, 2, 4, 6, 8]))
+ return KJT
+
+
+data_gen_fn = lambda: dict(features=torch.rand((BATCH, SHAPE)))
+
+
+def interaction_arch_data_gen_fn():
+ KT = gen_kt()
+ return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KT)
+
+
+def simple_dfm_data_gen_fn():
+ KJT = gen_kjt()
+ return dict(dense_features=torch.rand((BATCH, SHAPE)), sparse_features=KJT)
+
+
+def sparse_arch_data_gen_fn():
+ KJT = gen_kjt()
+ return dict(features=KJT)
+
+
+def output_transform_fn(x):
+ if isinstance(x, KeyedTensor):
+ output = dict()
+ for key in x.keys():
+ output[key] = x[key]
+ return output
+ else:
+ return dict(output=x)
+
+
+def output_transform_fn(x):
+ if isinstance(x, KeyedTensor):
+ output = dict()
+ for key in x.keys():
+ output[key] = x[key]
+ return output
+ else:
+ return dict(output=x)
+
+
+def get_ebc():
+ # EmbeddingBagCollection
+ eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
+ eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
+ return EmbeddingBagCollection(tables=[eb1_config, eb2_config], device=torch.device('cpu'))
+
+
+def sparse_arch_model_fn():
+ ebc = get_ebc()
+ return deepfm.SparseArch(ebc)
+
+
+def simple_deep_fmnn_model_fn():
+ ebc = get_ebc()
+ return deepfm.SimpleDeepFMNN(SHAPE, ebc, SHAPE, SHAPE)
+
+
+def dlrm_model_fn():
+ ebc = get_ebc()
+ return dlrm.DLRM(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
+
+
+def dlrm_sparsearch_model_fn():
+ ebc = get_ebc()
+ return dlrm.SparseArch(ebc)
+
+
+model_zoo.register(name='deepfm_densearch',
+ model_fn=partial(deepfm.DenseArch, SHAPE, SHAPE, SHAPE),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='deepfm_interactionarch',
+ model_fn=partial(deepfm.FMInteractionArch, SHAPE * 3, ["f1", "f2"], SHAPE),
+ data_gen_fn=interaction_arch_data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='deepfm_overarch',
+ model_fn=partial(deepfm.OverArch, SHAPE),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='deepfm_simpledeepfmnn',
+ model_fn=simple_deep_fmnn_model_fn,
+ data_gen_fn=simple_dfm_data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='deepfm_sparsearch',
+ model_fn=sparse_arch_model_fn,
+ data_gen_fn=sparse_arch_data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='dlrm',
+ model_fn=dlrm_model_fn,
+ data_gen_fn=simple_dfm_data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='dlrm_densearch',
+ model_fn=partial(dlrm.DenseArch, SHAPE, [SHAPE, SHAPE]),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='dlrm_interactionarch',
+ model_fn=partial(dlrm.InteractionArch, 2),
+ data_gen_fn=interaction_arch_data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='dlrm_overarch',
+ model_fn=partial(dlrm.OverArch, SHAPE, [5, 1]),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='dlrm_sparsearch',
+ model_fn=dlrm_sparsearch_model_fn,
+ data_gen_fn=sparse_arch_data_gen_fn,
+ output_transform_fn=output_transform_fn)
diff --git a/tests/kit/model_zoo/torchvision/__init__.py b/tests/kit/model_zoo/torchvision/__init__.py
new file mode 100644
index 000000000000..55d58f97b5d4
--- /dev/null
+++ b/tests/kit/model_zoo/torchvision/__init__.py
@@ -0,0 +1 @@
+from .torchvision import *
diff --git a/tests/kit/model_zoo/torchvision/torchvision.py b/tests/kit/model_zoo/torchvision/torchvision.py
new file mode 100644
index 000000000000..62bda93d5a75
--- /dev/null
+++ b/tests/kit/model_zoo/torchvision/torchvision.py
@@ -0,0 +1,131 @@
+from collections import namedtuple
+
+import torch
+import torchvision
+import torchvision.models as tm
+from packaging import version
+
+from ..registry import ModelAttribute, model_zoo
+
+data_gen_fn = lambda: dict(x=torch.rand(4, 3, 224, 224))
+output_transform_fn = lambda x: dict(output=x)
+
+# special data gen fn
+inception_v3_data_gen_fn = lambda: dict(x=torch.rand(4, 3, 299, 299))
+
+
+# special model fn
+def swin_s():
+ from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer
+
+ # adapted from torchvision.models.swin_transformer.swin_small
+ weights = None
+ weights = Swin_T_Weights.verify(weights)
+ progress = True
+
+ return _swin_transformer(
+ patch_size=[4, 4],
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=[7, 7],
+ stochastic_depth_prob=0, # it is originally 0.2, but we set it to 0 to make it deterministic
+ weights=weights,
+ progress=progress,
+ )
+
+
+# special output transform fn
+google_net_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.GoogLeNetOutputs
+ ) else dict(output=x)
+swin_s_output_output_transform_fn = lambda x: {f'output{idx}': val
+ for idx, val in enumerate(x)} if isinstance(x, tuple) else dict(output=x)
+inception_v3_output_transform_fn = lambda x: dict(output=x.logits) if isinstance(x, torchvision.models.InceptionOutputs
+ ) else dict(output=x)
+
+model_zoo.register(name='torchvision_alexnet',
+ model_fn=tm.alexnet,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_densenet121',
+ model_fn=tm.densenet121,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_efficientnet_b0',
+ model_fn=tm.efficientnet_b0,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
+model_zoo.register(name='torchvision_googlenet',
+ model_fn=tm.googlenet,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=google_net_output_transform_fn)
+model_zoo.register(name='torchvision_inception_v3',
+ model_fn=tm.inception_v3,
+ data_gen_fn=inception_v3_data_gen_fn,
+ output_transform_fn=inception_v3_output_transform_fn)
+model_zoo.register(name='torchvision_mobilenet_v2',
+ model_fn=tm.mobilenet_v2,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_mobilenet_v3_small',
+ model_fn=tm.mobilenet_v3_small,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_mnasnet0_5',
+ model_fn=tm.mnasnet0_5,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_resnet18',
+ model_fn=tm.resnet18,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_regnet_x_16gf',
+ model_fn=tm.regnet_x_16gf,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_resnext50_32x4d',
+ model_fn=tm.resnext50_32x4d,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_shufflenet_v2_x0_5',
+ model_fn=tm.shufflenet_v2_x0_5,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_squeezenet1_0',
+ model_fn=tm.squeezenet1_0,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+model_zoo.register(name='torchvision_vgg11',
+ model_fn=tm.vgg11,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+model_zoo.register(name='torchvision_wide_resnet50_2',
+ model_fn=tm.wide_resnet50_2,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+
+if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+ model_zoo.register(name='torchvision_vit_b_16',
+ model_fn=tm.vit_b_16,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn)
+ model_zoo.register(name='torchvision_convnext_base',
+ model_fn=tm.convnext_base,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
+
+if version.parse(torchvision.__version__) >= version.parse('0.13.0'):
+ model_zoo.register(
+ name='torchvision_swin_s',
+ model_fn=swin_s,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=swin_s_output_output_transform_fn,
+ )
+ model_zoo.register(name='torchvision_efficientnet_v2_s',
+ model_fn=tm.efficientnet_v2_s,
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_stochastic_depth_prob=True))
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
new file mode 100644
index 000000000000..f56ff7ad84eb
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -0,0 +1,5 @@
+from .albert import *
+from .bert import *
+from .gpt import *
+from .opt import *
+from .t5 import *
diff --git a/tests/kit/model_zoo/transformers/albert.py b/tests/kit/model_zoo/transformers/albert.py
new file mode 100644
index 000000000000..e85f564e376a
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/albert.py
@@ -0,0 +1,85 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence ALBERT
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen_fn():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.AlbertConfig(embedding_size=128,
+ hidden_size=128,
+ num_hidden_layers=2,
+ num_attention_heads=4,
+ intermediate_size=256)
+
+model_zoo.register(name='transformers_albert',
+ model_fn=lambda: transformers.AlbertModel(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_pretraining',
+ model_fn=lambda: transformers.AlbertForPreTraining(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_masked_lm',
+ model_fn=lambda: transformers.AlbertForMaskedLM(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_sequence_classification',
+ model_fn=lambda: transformers.AlbertForSequenceClassification(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_token_classification',
+ model_fn=lambda: transformers.AlbertForTokenClassification(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+# ===============================
+# Register multi-sentence ALBERT
+# ===============================
+
+
+def data_gen_for_qa():
+ question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+ inputs = tokenizer(question, text, return_tensors="pt")
+ return inputs
+
+
+def data_gen_for_mcq():
+ prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ choice0 = "It is eaten with a fork and a knife."
+ choice1 = "It is eaten while held in the hand."
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+ encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+ encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+ return encoding
+
+
+model_zoo.register(name='transformers_albert_for_question_answering',
+ model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
+ data_gen_fn=data_gen_for_qa,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_albert_for_multiple_choice',
+ model_fn=lambda: transformers.AlbertForMultipleChoice(config),
+ data_gen_fn=data_gen_for_mcq,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/bert.py b/tests/kit/model_zoo/transformers/bert.py
new file mode 100644
index 000000000000..99135704da70
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/bert.py
@@ -0,0 +1,88 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence BERT
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen_fn():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
+
+# register the BERT variants
+model_zoo.register(name='transformers_bert',
+ model_fn=lambda: transformers.BertModel(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_pretraining',
+ model_fn=lambda: transformers.BertForPreTraining(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_lm_head_model',
+ model_fn=lambda: transformers.BertLMHeadModel(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_masked_lm',
+ model_fn=lambda: transformers.BertForMaskedLM(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_sequence_classification',
+ model_fn=lambda: transformers.BertForSequenceClassification(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_token_classification',
+ model_fn=lambda: transformers.BertForTokenClassification(config),
+ data_gen_fn=data_gen_fn,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+
+
+# ===============================
+# Register multi-sentence BERT
+# ===============================
+def data_gen_for_next_sentence():
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+ prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ next_sentence = "The sky is blue due to the shorter wavelength of blue light."
+ encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
+ return encoding
+
+
+def data_gen_for_mcq():
+ tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
+ prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+ choice0 = "It is eaten with a fork and a knife."
+ choice1 = "It is eaten while held in the hand."
+ encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
+ encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
+ return encoding
+
+
+# register the following models
+model_zoo.register(name='transformers_bert_for_next_sentence',
+ model_fn=lambda: transformers.BertForNextSentencePrediction(config),
+ data_gen_fn=data_gen_for_next_sentence,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_mcq',
+ model_fn=lambda: transformers.BertForMultipleChoice(config),
+ data_gen_fn=data_gen_for_mcq,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
new file mode 100644
index 000000000000..5ed4fbe70dc9
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -0,0 +1,57 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence GPT
+# ===============================
+BATCH_SIZE = 1 # it can only be 1 as GPT cannot handle batch sizes > 1 if no padding token is defined.
+SEQ_LENGTH = 16
+
+
+def data_gen():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+def seq_classification_data_gen():
+ # batch sizes should be 1 if no padding token is defined.
+ input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+ token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
+
+# register the following models
+model_zoo.register(name='transformers_gpt',
+ model_fn=lambda: transformers.GPT2Model(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_lm',
+ model_fn=lambda: transformers.GPT2LMHeadModel(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_double_heads',
+ model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_for_token_classification',
+ model_fn=lambda: transformers.GPT2ForTokenClassification(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_for_sequence_classification',
+ model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
+ data_gen_fn=seq_classification_data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/opt.py b/tests/kit/model_zoo/transformers/opt.py
new file mode 100644
index 000000000000..d9c4a0b3c23c
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/opt.py
@@ -0,0 +1,35 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence OPT
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
+
+# register the following models
+# transformers.OPTModel,
+# transformers.OPTForCausalLM,
+model_zoo.register(name='transformers_opt',
+ model_fn=lambda: transformers.OPTModel(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_opt_for_causal_lm',
+ model_fn=lambda: transformers.OPTForCausalLM(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
new file mode 100644
index 000000000000..b81bcad90db8
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -0,0 +1,46 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence T5
+# ===============================
+BATCH_SIZE = 2
+SEQ_LENGTH = 16
+
+
+def data_gen():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+
+
+def data_gen_for_encoder_only():
+ input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+ return dict(input_ids=input_ids)
+
+
+output_transform_fn = lambda x: x
+
+config = transformers.T5Config(d_model=128, num_layers=2)
+
+# register the following models
+# transformers.T5Model,
+# transformers.T5ForConditionalGeneration,
+# transformers.T5EncoderModel,
+model_zoo.register(name='transformers_t5',
+ model_fn=lambda: transformers.T5Model(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_t5_for_conditional_generation',
+ model_fn=lambda: transformers.T5ForConditionalGeneration(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_t5_encoder_model',
+ model_fn=lambda: transformers.T5EncoderModel(config),
+ data_gen_fn=data_gen_for_encoder_only,
+ output_transform_fn=output_transform_fn,
+ model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py
index 7f6f0c86ad8e..c01de469b8f1 100644
--- a/tests/test_amp/test_naive_fp16.py
+++ b/tests/test_amp/test_naive_fp16.py
@@ -24,7 +24,6 @@ def run_naive_amp():
In this test, we compare the naive fp16 optimizer implemented in colossalai
and fp32 torch optimizer
"""
-
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
diff --git a/tests/test_analyzer/test_fx/__init__.py b/tests/test_analyzer/test_fx/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_analyzer/test_fx/test_bias_addition.py b/tests/test_analyzer/test_fx/test_bias_addition.py
new file mode 100644
index 000000000000..61951e9a5da9
--- /dev/null
+++ b/tests/test_analyzer/test_fx/test_bias_addition.py
@@ -0,0 +1,121 @@
+import pytest
+import torch
+from packaging import version
+from torch.utils.checkpoint import checkpoint
+
+from colossalai.testing.utils import parameterize
+
+try:
+ from colossalai._analyzer.fx import symbolic_trace
+except:
+ pass
+
+
+class LinearModel(torch.nn.Module):
+
+ def __init__(self, in_features, out_features, bias):
+ super().__init__()
+ self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
+
+ def forward(self, x):
+ x = self.linear(x)
+ return x
+
+
+class ConvModel(torch.nn.Module):
+
+ def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channel,
+ out_channels,
+ kernel_size,
+ bias=bias,
+ padding=1,
+ stride=2,
+ dilation=2,
+ groups=3)
+ self.conv_transpose = torch.nn.ConvTranspose2d(in_channel,
+ out_channels,
+ kernel_size,
+ bias=bias,
+ padding=1,
+ stride=2,
+ dilation=2,
+ groups=3)
+
+ def forward(self, x, select=0):
+ if select == 0:
+ x = self.conv(x)
+ else:
+ x = self.conv_transpose(x)
+ return x
+
+
+class SiuModel(torch.nn.Module):
+
+ def __init__(self, bias) -> None:
+ super().__init__()
+ self.linear = LinearModel(3, 3, bias)
+ self.conv = ConvModel(3, 6, 3, bias)
+
+ def forward(self, x, select=torch.Tensor([0])):
+ x = self.linear(x)
+ if select:
+ x = checkpoint(self.conv, x, 0)
+ else:
+ x = checkpoint(self.conv, x, 1)
+
+ return x
+
+
+class AddmmModel(torch.nn.Module):
+
+ def __init__(self, alpha, beta) -> None:
+ super().__init__()
+ self.alpha = alpha
+ self.beta = beta
+
+ def forward(self, x):
+ x = torch.addmm(x, x, x, alpha=self.alpha, beta=self.beta)
+ return x
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize("bias", [True, False])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+@parameterize("select", [torch.Tensor([0]), torch.Tensor([1])])
+def test_siu_model(bias, bias_addition_split, shape, select):
+ model = SiuModel(bias=bias)
+ x = torch.rand(shape)
+ gm = symbolic_trace(model,
+ meta_args={'x': x},
+ concrete_args={'select': select},
+ trace_act_ckpt=True,
+ bias_addition_split=bias_addition_split)
+ assert torch.allclose(model(x, select), gm(x)), 'original model and traced model should be the same!'
+ if bias and bias_addition_split:
+ assert '+' in gm.code, 'bias addition should be split!'
+ else:
+ assert '+' not in gm.code, 'bias addition should not be split!'
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize("alpha", [1, 2])
+@parameterize("beta", [1, 2])
+@parameterize("bias_addition_split", [True, False])
+@parameterize("shape", [(3, 3), (5, 5)])
+def test_addmm_model(alpha, beta, bias_addition_split, shape):
+ model = AddmmModel(alpha=alpha, beta=beta)
+ x = torch.rand(shape)
+ gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True, bias_addition_split=bias_addition_split)
+ assert torch.allclose(model(x), gm(x)), 'original model and traced model should be the same!'
+ if (alpha == 1 and beta == 1) or not bias_addition_split:
+ assert '*' not in gm.code, 'bias addition should not be split!'
+ elif bias_addition_split:
+ assert '+' in gm.code, 'bias addition should be split!'
+
+
+if __name__ == '__main__':
+ test_siu_model()
+ test_addmm_model()
diff --git a/tests/test_analyzer/test_fx/test_mod_dir.py b/tests/test_analyzer/test_fx/test_mod_dir.py
new file mode 100644
index 000000000000..15e0c2ec21c7
--- /dev/null
+++ b/tests/test_analyzer/test_fx/test_mod_dir.py
@@ -0,0 +1,78 @@
+import pytest
+import torch
+
+try:
+ from colossalai._analyzer.fx import symbolic_trace
+except:
+ pass
+
+
+class LinearModel(torch.nn.Module):
+
+ def __init__(self, in_features, out_features, bias):
+ super().__init__()
+ self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
+
+ def forward(self, x):
+ x = self.linear(x)
+ return x
+
+
+class ConvModel(torch.nn.Module):
+
+ def __init__(self, in_channel, out_channels, kernel_size, bias) -> None:
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channel,
+ out_channels,
+ kernel_size,
+ bias=bias,
+ padding=1,
+ stride=2,
+ dilation=2,
+ groups=3)
+ self.conv_transpose = torch.nn.ConvTranspose2d(out_channels,
+ out_channels,
+ kernel_size,
+ bias=bias,
+ padding=1,
+ stride=2,
+ dilation=2,
+ groups=3)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.conv_transpose(x)
+ return x
+
+
+class AModel(torch.nn.Module):
+
+ def __init__(self, bias) -> None:
+ super().__init__()
+ self.linear_1 = LinearModel(3, 3, bias)
+ self.linear_2 = LinearModel(3, 3, bias)
+ self.conv = ConvModel(3, 6, 3, bias)
+
+ def forward(self, x):
+ for i in range(x.shape[0]):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+ x = self.conv(x)
+ return x
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("bias_addition_split", [True, False])
+@pytest.mark.parametrize("shape", [(3, 3, 3), (3, 3, 3, 3)])
+def test_mod_dir(bias, bias_addition_split, shape):
+ model = AModel(bias=bias)
+ x = torch.rand(shape)
+ gm = symbolic_trace(model, meta_args={'x': x}, bias_addition_split=bias_addition_split)
+ for node in gm.graph.nodes:
+ assert len(node.meta['info'].mod_dir), f"{node} should have non-trivial ``mod_dir``."
+ print(node, node.meta['info'].mod_dir)
+
+
+if __name__ == '__main__':
+ test_mod_dir(True, True, (3, 3, 3))
diff --git a/tests/test_analyzer/test_fx/test_nested_ckpt.py b/tests/test_analyzer/test_fx/test_nested_ckpt.py
new file mode 100644
index 000000000000..c31aab6752f8
--- /dev/null
+++ b/tests/test_analyzer/test_fx/test_nested_ckpt.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+import pytest
+
+try:
+ from colossalai._analyzer.fx import symbolic_trace
+except:
+ pass
+
+
+class MyModule(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.a = nn.Linear(10, 10)
+ self.b = nn.Linear(10, 10)
+ self.c = nn.Linear(10, 10)
+ self.d = nn.Linear(10, 10)
+ self.e = nn.Linear(10, 10)
+
+ def checkpoint_0(self, x):
+ return checkpoint(self.checkpoint_0_0, x) + checkpoint(self.checkpoint_0_1, x) + self.e(x)
+
+ def checkpoint_0_0(self, x):
+ return checkpoint(self.checkpoint_0_0_0, x) + checkpoint(self.checkpoint_0_0_1, x)
+
+ def checkpoint_0_0_0(self, x):
+ return self.a(x) + checkpoint(self.checkpoint_0_0_0_0, x, use_reentrant=False)
+
+ def checkpoint_0_0_0_0(self, x):
+ return self.b(x)
+
+ def checkpoint_0_0_1(self, x):
+ return self.b(x) + self.c(x)
+
+ def checkpoint_0_1(self, x):
+ return self.d(x)
+
+ def forward(self, x):
+ return checkpoint(self.checkpoint_0, x)
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+def test_nested_ckpt():
+ model = MyModule()
+ x = torch.rand(10, 10)
+ gm = symbolic_trace(model, meta_args={'x': x}, trace_act_ckpt=True)
+ assert torch.allclose(gm(x), model(x)), "The traced model should generate the same output as the original model."
+ for ckpt_def in filter(lambda s: s.startswith('checkpoint'), dir(model)):
+ assert ckpt_def in gm.code, f"Checkpoint {ckpt_def} should be in the traced code.\n Traced code = {gm.code}"
+
+
+if __name__ == "__main__":
+ test_nested_ckpt()
diff --git a/tests/test_analyzer/test_fx/test_shape_prop.py b/tests/test_analyzer/test_fx/test_shape_prop.py
new file mode 100644
index 000000000000..08f4ff2cbd1f
--- /dev/null
+++ b/tests/test_analyzer/test_fx/test_shape_prop.py
@@ -0,0 +1,65 @@
+import pytest
+import torch
+import torchvision.models as tm
+from packaging import version
+
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
+
+try:
+ from colossalai._analyzer._subclasses import MetaTensorMode
+ from colossalai._analyzer.fx import symbolic_trace
+ from colossalai._analyzer.fx.passes.shape_prop import shape_prop_pass
+ from colossalai._analyzer.fx.symbolic_profile import register_shape_impl
+
+ @register_shape_impl(torch.nn.functional.linear)
+ def linear_impl(*args, **kwargs):
+ assert True
+ return torch.nn.functional.linear(*args, **kwargs)
+except:
+ pass
+
+
+def _check_gm_validity(gm: torch.fx.GraphModule):
+ for node in gm.graph.nodes:
+ assert node.meta['info'].outputs, f'In {gm.__class__.__name__}, {node} has no output shape.'
+ if node.op in [
+ 'call_module', # can apply to params
+ 'call_function', # can apply to params
+ 'call_method', # can apply to params
+ ]:
+ assert hasattr(node.meta['info'], 'inputs'), f'In {gm.__class__.__name__}, {node} has no input shape.'
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
+def test_torchvision_shape_prop(m):
+ with MetaTensorMode():
+ model = m()
+ data = torch.rand(100, 3, 224, 224)
+ meta_args = {
+ "x": data,
+ }
+ gm = symbolic_trace(model, meta_args=meta_args)
+ shape_prop_pass(gm, data)
+ _check_gm_validity(gm)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
+def test_timm_shape_prop(m):
+ with MetaTensorMode():
+ model = m()
+ data = torch.rand(100, 3, 224, 224)
+ meta_args = {
+ "x": data,
+ }
+
+ gm = symbolic_trace(model, meta_args=meta_args)
+ shape_prop_pass(gm, data)
+ _check_gm_validity(gm)
+
+
+if __name__ == "__main__":
+ test_torchvision_shape_prop()
+ test_timm_shape_prop()
diff --git a/tests/test_analyzer/test_fx/test_symbolic_profile.py b/tests/test_analyzer/test_fx/test_symbolic_profile.py
new file mode 100644
index 000000000000..be781599f14b
--- /dev/null
+++ b/tests/test_analyzer/test_fx/test_symbolic_profile.py
@@ -0,0 +1,51 @@
+import pytest
+import torch
+import torchvision.models as tm
+from packaging import version
+
+from colossalai.testing.utils import parameterize
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
+
+try:
+ from colossalai._analyzer._subclasses import MetaTensorMode
+ from colossalai._analyzer.fx import symbolic_profile, symbolic_trace
+except:
+ pass
+
+
+def _check_gm_validity(gm: torch.fx.GraphModule):
+ for node in gm.graph.nodes:
+ assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tm_models)
+def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
+ with MetaTensorMode():
+ model = m()
+ data = torch.rand(8, 3, 224, 224)
+ meta_args = {
+ "x": data,
+ }
+ gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
+ symbolic_profile(gm, data, verbose=verbose)
+ _check_gm_validity(gm)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@parameterize('m', tmm_models)
+def test_timm_profile(m, verbose=False, bias_addition_split=False):
+ with MetaTensorMode():
+ model = m()
+ data = torch.rand(8, 3, 224, 224)
+ meta_args = {
+ "x": data,
+ }
+ gm = symbolic_trace(model, meta_args=meta_args, bias_addition_split=bias_addition_split)
+ symbolic_profile(gm, data, verbose=verbose)
+ _check_gm_validity(gm)
+
+
+if __name__ == "__main__":
+ test_torchvision_profile()
+ test_timm_profile()
diff --git a/tests/test_analyzer/test_fx/zoo.py b/tests/test_analyzer/test_fx/zoo.py
new file mode 100644
index 000000000000..a96aa3949134
--- /dev/null
+++ b/tests/test_analyzer/test_fx/zoo.py
@@ -0,0 +1,53 @@
+import timm.models as tmm
+import torchvision.models as tm
+
+# input shape: (batch_size, 3, 224, 224)
+tm_models = [
+ tm.alexnet,
+ tm.convnext_base,
+ tm.densenet121,
+ # tm.efficientnet_v2_s,
+ # tm.googlenet, # output bad case
+ # tm.inception_v3, # bad case
+ tm.mobilenet_v2,
+ tm.mobilenet_v3_small,
+ tm.mnasnet0_5,
+ tm.resnet18,
+ tm.regnet_x_16gf,
+ tm.resnext50_32x4d,
+ tm.shufflenet_v2_x0_5,
+ tm.squeezenet1_0,
+ # tm.swin_s, # fx bad case
+ tm.vgg11,
+ tm.vit_b_16,
+ tm.wide_resnet50_2,
+]
+
+tmm_models = [
+ tmm.beit_base_patch16_224,
+ tmm.beitv2_base_patch16_224,
+ tmm.cait_s24_224,
+ tmm.coat_lite_mini,
+ tmm.convit_base,
+ tmm.deit3_base_patch16_224,
+ tmm.dm_nfnet_f0,
+ tmm.eca_nfnet_l0,
+ tmm.efficientformer_l1,
+ # tmm.ese_vovnet19b_dw,
+ tmm.gmixer_12_224,
+ tmm.gmlp_b16_224,
+ # tmm.hardcorenas_a,
+ tmm.hrnet_w18_small,
+ tmm.inception_v3,
+ tmm.mixer_b16_224,
+ tmm.nf_ecaresnet101,
+ tmm.nf_regnet_b0,
+ # tmm.pit_b_224, # pretrained only
+ # tmm.regnetv_040,
+ # tmm.skresnet18,
+ # tmm.swin_base_patch4_window7_224, # fx bad case
+ # tmm.tnt_b_patch16_224, # bad case
+ tmm.vgg11,
+ tmm.vit_base_patch16_18x2_224,
+ tmm.wide_resnet50_2,
+]
diff --git a/tests/test_analyzer/test_subclasses/__init__.py b/tests/test_analyzer/test_subclasses/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_analyzer/test_subclasses/test_aten.py b/tests/test_analyzer/test_subclasses/test_aten.py
new file mode 100644
index 000000000000..591a8d617580
--- /dev/null
+++ b/tests/test_analyzer/test_subclasses/test_aten.py
@@ -0,0 +1,82 @@
+from typing import Any, Callable, Union
+import pytest
+
+import torch
+import torch.nn as nn
+
+try:
+ from colossalai._analyzer._subclasses import MetaTensor
+except:
+ pass
+
+aten = torch.ops.aten
+
+registered_meta = {
+ ('aten.convolution.default', True): [ # (aten ops, requires_backward)
+ (nn.Conv1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
+ (nn.Conv2d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4)),
+ (nn.Conv3d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4, 4, 4)),
+ (nn.ConvTranspose1d(in_channels=3, out_channels=4, kernel_size=2, padding=1, dilation=2), torch.rand(2, 3, 4)),
+ (nn.ConvTranspose2d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
+ dilation=2), torch.rand(2, 3, 4, 4)),
+ (nn.ConvTranspose3d(in_channels=3, out_channels=4, kernel_size=2, padding=1,
+ dilation=2), torch.rand(2, 3, 4, 4, 4)),
+ ],
+ ('aten.native_batch_norm.default', True): [
+ (nn.BatchNorm1d(4), torch.rand(2, 4)),
+ (nn.BatchNorm2d(4), torch.rand(1, 4, 4, 4)),
+ (nn.BatchNorm3d(4), torch.rand(1, 4, 4, 4, 4)),
+ ],
+ ('aten.native_layer_norm.default', True): [(nn.LayerNorm(4), torch.rand(1, 2, 3, 4)),],
+ ('aten.avg_pool1d.default', True): [
+ (nn.MaxPool1d(3, stride=2), torch.rand(4, 5, 5)),
+ (nn.AvgPool1d(3, stride=2), torch.rand(4, 5, 5)),
+ (nn.AdaptiveMaxPool1d(3), torch.rand(4, 5, 5)),
+ (nn.AdaptiveAvgPool1d(3), torch.rand(4, 5, 5)),
+ ],
+ ('aten.avg_pool2d.default', True): [
+ (nn.MaxPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
+ (nn.AvgPool2d((3, 2), stride=(2, 1)), torch.rand(2, 4, 5, 5)),
+ (nn.AdaptiveMaxPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
+ (nn.AdaptiveAvgPool2d((3, 2)), torch.rand(2, 4, 5, 5)),
+ ],
+ ('aten.relu.default', True): [
+ (nn.ReLU(), torch.rand(4, 3, 1, 2)),
+ (nn.LeakyReLU(), torch.rand(4, 3, 1, 2)),
+ (nn.SiLU(), torch.rand(4, 3, 1, 2)),
+ (nn.GELU(), torch.rand(4, 3, 1, 2)),
+ (nn.ELU(), torch.rand(4, 3, 1, 2)),
+ (nn.Sigmoid(), torch.rand(4, 3, 1, 2)),
+ (nn.Tanh(), torch.rand(4, 3, 1, 2)),
+ (nn.Hardswish(), torch.rand(4, 3, 1, 2)),
+ ]
+}
+
+
+def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
+ assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
+ assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
+ assert tensor.stride() == meta_tensor.stride(
+ ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
+
+
+def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_backward=False) -> Any:
+ x.requires_grad = requires_backward
+ meta_x = MetaTensor(x)
+ x_out, meta_out = f(x), f(meta_x)
+ compare_all(x_out, meta_out)
+ if requires_backward:
+ x_out.sum().backward()
+ meta_out.sum().backward()
+ compare_all(x.grad, meta_x.grad)
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
+def test_meta_aten():
+ for (aten_op, requires_backward), v in registered_meta.items():
+ for f, x in v:
+ run_and_compare(f, x, requires_backward)
+
+
+if __name__ == '__main__':
+ test_meta_aten()
diff --git a/tests/test_analyzer/test_subclasses/test_flop_tensor.py b/tests/test_analyzer/test_subclasses/test_flop_tensor.py
new file mode 100644
index 000000000000..752836141fe7
--- /dev/null
+++ b/tests/test_analyzer/test_subclasses/test_flop_tensor.py
@@ -0,0 +1,51 @@
+import pytest
+import torch
+import torch.nn.functional as F
+import torchvision.models as tm
+from packaging import version
+
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
+
+try:
+ from colossalai._analyzer._subclasses import MetaTensorMode, flop_count
+except:
+ pass
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@pytest.mark.parametrize('m', tm_models + tmm_models)
+def test_flop_count_module(m):
+ x = torch.rand(2, 3, 224, 224)
+ with MetaTensorMode(): # save time for testing
+ module = m()
+ rs_fwd, rs_bwd = flop_count(module, x, verbose=True)
+ assert rs_fwd > 0, f'fwd flop count of {m.__name__} is {rs_fwd}'
+ assert rs_bwd > 0, f'bwd flop count of {m.__name__} is {rs_bwd}'
+
+
+odd_cases = [
+ (F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
+ 'inplace': True
+ }),
+ (F.max_pool2d, (torch.rand(2, 3, 224, 224, requires_grad=True),), {
+ 'kernel_size': 3,
+ 'stride': 2,
+ 'padding': 1,
+ 'dilation': 2
+ }),
+ (torch.where, (torch.rand(2, 3, 224, 224) > 0.5, torch.rand(2, 3, 224, 224, requires_grad=True),
+ torch.rand(2, 3, 224, 224, requires_grad=True)), {}),
+]
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@pytest.mark.parametrize('func, args, kwargs', odd_cases)
+def test_flop_count_function(func, args, kwargs):
+ rs_fwd, rs_bwd = flop_count(func, *args, **kwargs, verbose=True)
+ assert rs_fwd > 0, f'fwd flop count of {func.__name__} is {rs_fwd}'
+ assert rs_bwd > 0, f'bwd flop count of {func.__name__} is {rs_bwd}'
+
+
+if __name__ == '__main__':
+ test_flop_count_module(tm.resnet18)
+ test_flop_count_function(F.relu, (torch.rand(2, 3, 224, 224, requires_grad=True),), {'inplace': True})
diff --git a/tests/test_analyzer/test_subclasses/test_meta_mode.py b/tests/test_analyzer/test_subclasses/test_meta_mode.py
new file mode 100644
index 000000000000..160d411f6c39
--- /dev/null
+++ b/tests/test_analyzer/test_subclasses/test_meta_mode.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+import torchvision.models as tm
+from packaging import version
+
+try:
+ from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
+except:
+ pass
+from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
+
+
+def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
+ assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
+ assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
+ assert tensor.stride() == meta_tensor.stride(
+ ), f'the stride of tensor ({tensor.stride()}) and meta tensor ({meta_tensor.stride()}) does not match.'
+
+
+def run_and_compare(model):
+ x = torch.rand(2, 3, 224, 224, requires_grad=True)
+ x_out = model(x)
+ with MetaTensorMode():
+ meta_x = torch.rand(2, 3, 224, 224, requires_grad=True)
+ meta_out = model(meta_x)
+ compare_all(x_out, meta_out)
+ x_out.sum().backward()
+ meta_out.sum().backward()
+ compare_all(x.grad, meta_x.grad)
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+@pytest.mark.parametrize('m', tm_models + tmm_models)
+def test_meta_mode_shape(m):
+ run_and_compare(m())
+
+
+if __name__ == '__main__':
+ test_meta_mode_shape(tm.resnet18)
diff --git a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
similarity index 94%
rename from tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
rename to tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
index 773cf151d2e9..f8dd0b16b7f6 100644
--- a/tests/test_fx/test_ckpt_solvers/test_C_solver_consistency.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_C_solver_consistency.py
@@ -1,16 +1,17 @@
import copy
-import colossalai
import pytest
import torch
import torch.fx
import torch.multiprocessing as mp
import torchvision.models as tm
+
+import colossalai
from colossalai.core import global_context as gpc
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
-from colossalai.fx.passes.algorithms import solver_rotor
-from colossalai.fx.passes.algorithms.operation import Sequence
+# from colossalai.fx.passes.algorithms import solver_rotor
+# from colossalai.fx.passes.algorithms.operation import Sequence
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
@@ -67,6 +68,7 @@ def _run_C_solver_consistency_test(rank=0):
gpc.destroy()
+@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not withcodegen, reason="torch version is less than 1.12.0")
def test_C_solver_consistency():
mp.spawn(_run_C_solver_consistency_test, nprocs=1)
diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
similarity index 97%
rename from tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
rename to tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
index 9949d49c1e01..89600ea098a9 100644
--- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -13,7 +13,7 @@
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
+# from colossalai.fx.passes.algorithms import chen_greedy, solver_rotor
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.utils import free_port
@@ -28,7 +28,8 @@
from colossalai.fx.codegen import python_code_with_activation_checkpoint
with_codegen = False
-SOLVERS = [chen_greedy, solver_rotor]
+# SOLVERS = [chen_greedy, solver_rotor]
+SOLVERS = []
def _is_activation_checkpoint_available(gm: GraphModule):
diff --git a/tests/test_fx/test_ckpt_solvers/test_linearize.py b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py
similarity index 95%
rename from tests/test_fx/test_ckpt_solvers/test_linearize.py
rename to tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py
index a803f8c07277..0f90ba0b0989 100644
--- a/tests/test_fx/test_ckpt_solvers/test_linearize.py
+++ b/tests/test_auto_parallel/test_ckpt_solvers/test_linearize.py
@@ -1,11 +1,12 @@
import pytest
import torch
import torchvision.models as tm
+
from colossalai.fx import ColoTracer
from colossalai.fx._compatibility import is_compatible_with_meta
from colossalai.fx.graph_module import ColoGraphModule
-from colossalai.fx.passes.algorithms import linearize, solver_rotor
-from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
+# from colossalai.fx.passes.algorithms import linearize, solver_rotor
+# from colossalai.fx.passes.algorithms.operation import (ForwardCheck, ForwardEnable, ForwardNograd, Loss)
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
if is_compatible_with_meta():
@@ -21,6 +22,7 @@
@pytest.mark.skip(reason='TODO: modify the logger')
+@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}
@@ -79,6 +81,7 @@ def test_linearize():
del node_list
+@pytest.mark.skip("TODO(lyl): refactor all tests.")
@pytest.mark.skip(reason="torch11 meta tensor not implemented")
@pytest.mark.skipif(with_codegen, reason="torch version is equal to or higher than 1.12.0")
def test_linearize_torch11():
diff --git a/tests/test_auto_parallel/test_offload/model_utils.py b/tests/test_auto_parallel/test_offload/model_utils.py
new file mode 100644
index 000000000000..c22b17ae42ba
--- /dev/null
+++ b/tests/test_auto_parallel/test_offload/model_utils.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn as nn
+from transformers import GPT2Config, GPT2LMHeadModel
+from transformers import BertConfig, BertLMHeadModel
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+class GPTLMModel(nn.Module):
+
+ def __init__(self,
+ hidden_size=768,
+ num_layers=12,
+ num_attention_heads=12,
+ max_seq_len=1024,
+ vocab_size=50257):
+ super().__init__()
+ self.model = GPT2LMHeadModel(
+ GPT2Config(n_embd=hidden_size,
+ n_layer=num_layers,
+ n_head=num_attention_heads,
+ n_positions=max_seq_len,
+ n_ctx=max_seq_len,
+ vocab_size=vocab_size))
+
+ def forward(self, input_ids, attention_mask):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
+
+
+class LMLoss(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.loss_fn = nn.CrossEntropyLoss()
+
+ def forward(self, logits, labels):
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+class BertLMModel(nn.Module):
+ def __init__(self, hidden_size=768, num_layers=12, num_attention_heads=32, vocab_size=30522):
+ super().__init__()
+ self.model = BertLMHeadModel(BertConfig(n_embd=hidden_size, num_hidden_layers=num_layers, hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads, max_position_embeddings=hidden_size,
+ vocab_size=vocab_size))
+
+ def forward(self, input_ids, attention_mask):
+ # Only return lm_logits
+ return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
+
+@non_distributed_component_funcs.register(name='bert_')
+def get_bert_components():
+ vocab_size = 1024
+ seq_len = 64
+ batchSize = 64
+
+ def bert_model_builder():
+ model = BertLMModel(hidden_size=8192, num_layers=4, num_attention_heads=32, vocab_size=vocab_size)
+ return model
+
+ def bert_data_gen(device="meta"):
+ input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
+ attention_mask = torch.ones_like(input_ids, device=device)
+ kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+ return kwargs
+
+ return bert_model_builder, bert_data_gen
+
+@non_distributed_component_funcs.register(name='gpt2_')
+def get_gpt2_components():
+ vocab_size = 1024
+ seq_len = 8
+ batchSize = 64
+
+ def gpt2_model_builder():
+ model = GPTLMModel(hidden_size=8192, num_layers=2, num_attention_heads=32, vocab_size=vocab_size)
+ return model
+
+ def gpt2_data_gen(device="meta"):
+ input_ids = torch.randint(0, vocab_size, (batchSize, seq_len), device=device)
+ attention_mask = torch.ones_like(input_ids, device=device)
+ kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
+ return kwargs
+
+ return gpt2_model_builder, gpt2_data_gen
\ No newline at end of file
diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py
new file mode 100644
index 000000000000..d569570f4b7d
--- /dev/null
+++ b/tests/test_auto_parallel/test_offload/test_perf.py
@@ -0,0 +1,150 @@
+import time
+import pytest
+from functools import partial
+
+import torch
+from torch.utils._pytree import tree_map
+import torch.multiprocessing as mp
+
+import colossalai
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.fx.profiler import parameter_size
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.utils import free_port, get_current_device
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
+from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
+from colossalai.auto_parallel.offload.solver import NOT_NVML
+from colossalai.testing import parameterize
+
+from tests.test_tensor.common_utils import set_seed
+from tests.test_auto_parallel.test_offload.model_utils import *
+
+
+@parameterize('model_name', ['gpt2_'])
+@parameterize('memory_budget', [5000])
+@parameterize('solver_name', ['asyn'])
+def exam_fwd_bwd(
+ model_name: str,
+ memory_budget: float,
+ solver_name: str
+):
+
+ # build model
+ get_components_func = non_distributed_component_funcs.get_callable(model_name)
+ model_builder, data_gen = get_components_func()
+ label = torch.randint(low=0, high=128, size=(64, 8,), device=get_current_device())
+ criterion = LMLoss()
+
+ set_seed(42)
+ start_time = time.time()
+ model = model_builder()
+ model.train()
+ param_size = parameter_size(model) / 1024 ** 2 / 2
+ init_time = time.time() - start_time
+ print(f"init_param_size={param_size:.3f} MB | init_model_time={init_time:.3f} s")
+
+ data_args = data_gen(device="cpu")
+ wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
+ data_args = tree_map(wrap_fn, data_args)
+ start_time = time.time()
+ model = memory_optimize(model, data_args, memory_budget * 1024 * 1024, solver_name)
+ solver_time = time.time() - start_time
+ print(f"solver_time={solver_time:.3f} s")
+
+ hybrid_optimizer = HybridAdam(model.model.parameters(), lr=1e-3)
+ optim = AMPOptimizer(hybrid_optimizer, model)
+
+ with ColoInitContext(device=torch.device('cpu')):
+ gemini_model = model_builder()
+ gemini_model.train()
+
+ hybrid_optimizer = HybridAdam(gemini_model.parameters(), lr=1e-3)
+ gemini_config = dict(strict_ddp_mode=False,
+ device=torch.device('cpu'),
+ placement_policy='cpu',
+ pin_memory=True,
+ hidden_dim=8192,
+ search_range_mb=128)
+ gemini_model = zero_model_wrapper(gemini_model, 3, gemini_config)
+ optim_config = dict(reduce_bucket_size=12 * 1024 * 1024, overlap_communication=True, verbose=True)
+ gemini_optim = zero_optim_wrapper(gemini_model, hybrid_optimizer, optim_config=optim_config)
+
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ # test gemini
+ time_list = []
+ set_seed(42)
+ data_args = data_gen(device="cuda")
+ for step in range(10):
+ gemini_optim.zero_grad()
+ torch.cuda.synchronize()
+ start_time = time.time()
+ gemini_out = gemini_model(**data_args)
+ gemini_loss = criterion(gemini_out, label)
+ gemini_optim.backward(gemini_loss)
+ torch.cuda.synchronize()
+ time_list.append(time.time() - start_time)
+ gemini_optim.step()
+
+ torch.cuda.synchronize()
+
+ exec_time = sum(sorted(time_list)[:5]) / 5
+ runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
+ runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
+ print(f'gemini | model_name: {model_name}')
+ print(
+ f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
+ f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
+ )
+ print(time_list)
+
+ del data_args
+ del gemini_model
+ del gemini_optim
+ del gemini_out
+ del gemini_loss
+
+ # test asyn offload
+ torch.cuda.empty_cache()
+ torch.cuda.synchronize()
+ torch.cuda.reset_peak_memory_stats()
+
+ time_list = []
+ set_seed(42)
+ data_args = data_gen(device="cuda")
+ data_args = tree_map(wrap_fn, data_args)
+ for step in range(10):
+ optim.zero_grad()
+ torch.cuda.synchronize()
+ start_time = time.time()
+ loss = criterion(model(**data_args), label)
+ optim.backward(loss)
+ torch.cuda.synchronize()
+ time_list.append(time.time() - start_time)
+ optim.step()
+
+ torch.cuda.synchronize()
+
+ exec_time = sum(sorted(time_list)[:5]) / 5
+ runtime_peak_mem_alc = torch.cuda.max_memory_allocated() / 1024 ** 2
+ runtime_peak_mem_res = torch.cuda.max_memory_reserved() / 1024 ** 2
+ print(f'solver_name: {solver_name} | model_name: {model_name}')
+ print(
+ f'| exec_time={exec_time:.3f} s | param_size={param_size:.3f} MB '
+ f'| runtime_peak_mem_alc={runtime_peak_mem_alc:.3f} MB| runtime_peak_mem_res={runtime_peak_mem_res:.3f} MB|'
+ )
+ print(time_list)
+
+@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+def test_perf(rank, world_size, port):
+ config = {}
+ colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ exam_fwd_bwd()
+
+
+if __name__ == '__main__':
+ run_func = partial(test_perf, world_size=1, port=free_port())
+ mp.spawn(run_func, nprocs=1)
diff --git a/tests/test_auto_parallel/test_offload/test_solver.py b/tests/test_auto_parallel/test_offload/test_solver.py
new file mode 100644
index 000000000000..2efbb750f80d
--- /dev/null
+++ b/tests/test_auto_parallel/test_offload/test_solver.py
@@ -0,0 +1,62 @@
+import pytest
+import torch.fx
+from torch.fx import GraphModule
+from torch.utils._pytree import tree_map
+
+from colossalai.fx import ColoTracer, is_compatible_with_meta
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.auto_parallel.offload.region_manager import RegionManager
+from colossalai.auto_parallel.offload.solver import SolverFactory, NOT_NVML
+from colossalai.testing import parameterize
+from tests.test_auto_parallel.test_offload.model_utils import *
+
+@pytest.mark.skipif(NOT_NVML, reason='pynvml is not installed')
+@parameterize('model_name', ['gpt2_', 'bert_'])
+@parameterize('memory_budget', [4000])
+@parameterize('solver_name', ['syn', 'asyn'])
+def solver_test(model_name: str,
+ memory_budget: float,
+ solver_name: str):
+
+ get_components_func = non_distributed_component_funcs.get_callable(model_name)
+ model_builder, data_gen = get_components_func()
+ data_args = data_gen(device="cpu")
+ wrap_fn = lambda x: x.to(dtype=torch.half) if isinstance(x, torch.Tensor) and torch.is_floating_point(x) else x
+ data_args = tree_map(wrap_fn, data_args)
+ model = model_builder()
+ model.train()
+ model = model.cpu().half()
+
+ tracer = ColoTracer()
+ assert is_compatible_with_meta()
+ wrap_fn = lambda x: x.to("meta") if isinstance(x, torch.Tensor) else x
+ meta_args = tree_map(wrap_fn, data_args)
+ graph = tracer.trace(model, meta_args=meta_args)
+ gm = GraphModule(model, graph, model.__class__.__name__)
+
+ interp = MetaInfoProp(gm)
+ interp.propagate(*meta_args.values())
+
+ region_manager = RegionManager(graph, solver_name=solver_name)
+ region_manager._pre_process()
+ region_list = region_manager.region_list
+
+ solver_cls = SolverFactory.create(solver_name)
+ memory_budget = memory_budget * 1024 * 1024
+ solver = solver_cls(region_list, memory_budget)
+ solver._call_solver()
+
+ assert solver.best_ts.peak_mem < memory_budget
+
+ print("****************** execution plan *******************")
+ for region in region_list:
+ need_offload = region.need_offload
+ to_prefetch = region.fwd_prefetch_region.r_id if region.fwd_prefetch_region is not None else None
+ print(f'| {model_name} forward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+ for region in region_list.__reversed__():
+ need_offload = region.need_offload
+ to_prefetch = region.bwd_prefetch_region.r_id if region.bwd_prefetch_region is not None else None
+ print(f'| {model_name} backward | region id: {region.r_id} | need_offload: {need_offload} | to_prefetch: {to_prefetch}')
+
+if __name__ == '__main__':
+ solver_test()
\ No newline at end of file
diff --git a/tests/test_auto_parallel/test_pass/__init__.py b/tests/test_auto_parallel/test_pass/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/test_auto_parallel/test_pass/test_node_converting_pass.py b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py
new file mode 100644
index 000000000000..d0d107610f7a
--- /dev/null
+++ b/tests/test_auto_parallel/test_pass/test_node_converting_pass.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn.functional as F
+
+from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+class TestModule(torch.nn.Module):
+
+ def forward(self, x):
+ x = x.view(4, 4, 2)
+ return x
+
+
+def insert_narrow(gm, x_node):
+ graph = gm.graph
+ with graph.inserting_after(x_node):
+ shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
+ view_node = list(x_node.users.keys())[0]
+ new_args = list(view_node.args)
+ new_args[0] = shard_node
+ view_node.args = tuple(new_args)
+ return gm
+
+
+def test_node_args_converting_pass():
+ model = TestModule()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+ meta_args = {'x': torch.rand(4, 8).to('meta')}
+ input = torch.rand(4, 8)
+ tracer = ColoTracer()
+ graph = tracer.trace(root=model, meta_args=meta_args)
+
+ x_node = list(graph.nodes)[0]
+ view_node = list(graph.nodes)[1]
+ sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
+ setattr(x_node, 'sharding_spec', sharding_spec)
+ setattr(view_node, 'sharding_spec', sharding_spec)
+
+ gm = ColoGraphModule(model, graph)
+ gm = node_args_converting_pass(gm, device_mesh)
+ gm = insert_narrow(gm, x_node)
+ gm.recompile()
+ output = gm(input)
+ assert output.shape == torch.Size([2, 4, 2])
+
+
+if __name__ == '__main__':
+ test_node_args_converting_pass()
diff --git a/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py
new file mode 100644
index 000000000000..3494830080ff
--- /dev/null
+++ b/tests/test_auto_parallel/test_pass/test_size_value_converting_pass.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn.functional as F
+
+from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+class TestModule(torch.nn.Module):
+
+ def forward(self, x):
+ size = x.size()
+ return size
+
+
+def insert_narrow(gm, x_node):
+ graph = gm.graph
+ with graph.inserting_after(x_node):
+ shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
+ size_node = list(x_node.users.keys())[0]
+ size_node.args = (shard_node,)
+ return gm
+
+
+def recover_narrow(gm, narrow_node):
+ graph = gm.graph
+ size_node = list(graph.nodes)[2]
+ x_node = narrow_node.args[0]
+ size_node.args = (x_node,)
+ graph.erase_node(narrow_node)
+ return gm
+
+
+def test_size_value_converting_pass():
+ model = TestModule()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+ meta_args = {'x': torch.rand(4, 8).to('meta')}
+ input = torch.rand(4, 8)
+ tracer = ColoTracer()
+ graph = tracer.trace(root=model, meta_args=meta_args)
+
+ x_node = list(graph.nodes)[0]
+ x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
+ setattr(x_node, 'sharding_spec', x_sharding_spec)
+ gm = ColoGraphModule(model, graph)
+ gm = insert_narrow(gm, x_node)
+ gm.recompile()
+ size = gm(input)
+ assert size == torch.Size([2, 8])
+
+ narrow_node = list(gm.graph.nodes)[1]
+ gm = recover_narrow(gm, narrow_node)
+ gm = size_value_converting_pass(gm, device_mesh)
+ gm = insert_narrow(gm, x_node)
+ gm.recompile()
+ size = gm(input)
+ assert size == torch.Size([4, 8])
+
+
+if __name__ == '__main__':
+ test_size_value_converting_pass()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
index e666cb1753a7..f43885a6ac44 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_bias_addition_forward.py
@@ -4,21 +4,11 @@
import torch
import torch.multiprocessing as mp
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
@@ -63,42 +53,9 @@ def check_linear_module(rank, world_size, port):
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- tracer = ColoTracer()
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %linear_weight : [#users=1] = get_attr[target=linear.weight]
- # %linear_bias : [#users=1] = get_attr[target=linear.bias]
- # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%x, %linear_weight), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
- # return mul
- graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 4).to('meta')})
- # def forward(self, x : torch.Tensor):
- # linear_weight = self.linear.weight
- # linear_bias = self.linear.bias
- # linear = torch._C._nn.linear(x, linear_weight); x = linear_weight = None
- # add = linear + linear_bias; linear = linear_bias = None
- # mul = add * 2; add = None
- # return mul
- gm = ColoGraphModule(model, graph)
- gm.recompile()
- node_list = list(graph.nodes)
-
- solver_options = SolverOptions()
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
- linear_node = node_list[3]
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-
- gm = runtime_apply_pass(gm)
- gm.recompile()
- output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ meta_args = {'x': torch.rand(4, 4).to('meta')}
+ gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh)
+ output = gm(input)
assert_close(output, output_compare)
@@ -113,47 +70,9 @@ def check_conv_module(rank, world_size, port):
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- tracer = ColoTracer()
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv_weight : [#users=1] = get_attr[target=conv.weight]
- # %conv_bias : [#users=1] = get_attr[target=conv.bias]
- # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {})
- # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%add, 2), kwargs = {})
- # return mul
- graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 3, 64, 64).to('meta')})
- # def forward(self, x : torch.Tensor):
- # conv_weight = self.conv.weight
- # conv_bias = self.conv.bias
- # conv2d = torch.conv2d(x, conv_weight); x = conv_weight = None
- # view = conv_bias.view([1, -1, 1, 1]); conv_bias = None
- # add = conv2d + view; conv2d = view = None
- # mul = add * 2; add = None
- # return mul
- gm = ColoGraphModule(model, graph)
-
- gm.recompile()
-
- node_list = list(graph.nodes)
- conv_node = node_list[3]
- solver_options = SolverOptions()
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
-
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
-
- gm = runtime_apply_pass(gm)
- gm.recompile()
- output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ meta_args = {'x': torch.rand(4, 3, 64, 64).to('meta')}
+ gm = initialize_model(model, meta_args=meta_args, device_mesh=device_mesh)
+ output = gm(input)
assert_close(output, output_compare)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
new file mode 100644
index 000000000000..0b42722fec5f
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_checkpoint.py
@@ -0,0 +1,70 @@
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.tracer import ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+HIDDEN_SIZE = 16
+
+
+class GPT2MLPWithCkpt(nn.Module):
+
+ def __init__(self, intermediate_size, hidden_size):
+ super().__init__()
+ embed_dim = hidden_size
+ self.c_fc = Conv1D(intermediate_size, embed_dim)
+ self.c_proj = Conv1D(embed_dim, intermediate_size)
+ self.act = torch.nn.ReLU()
+
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = checkpoint(self.c_proj, hidden_states)
+ hidden_states = self.act(hidden_states)
+
+ return hidden_states
+
+
+def check_act_ckpt(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ model = GPT2MLPWithCkpt(intermediate_size=4 * HIDDEN_SIZE, hidden_size=HIDDEN_SIZE)
+ input_sample = {
+ 'hidden_states': torch.rand(1, 64, HIDDEN_SIZE).to('meta'),
+ }
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1]
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ gm = initialize_model(model, input_sample, device_mesh)
+ code = gm.module.graph.python_code('self').src
+ assert "runtime_comm_spec_apply_1 = colossalai_auto_parallel_passes_runtime_apply_pass_runtime_comm_spec_apply(linear_1, comm_actions_dict, 12, 'linear_1')" in code
+ assert "view_3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, False, view_1, comm_actions_dict, use_reentrant=True)" in code
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_mlp_layer():
+ world_size = 4
+ run_func = partial(check_act_ckpt, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_mlp_layer()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
new file mode 100644
index 000000000000..e4982a5d7f5a
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_ddp.py
@@ -0,0 +1,105 @@
+import copy
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, in_features):
+ super().__init__()
+ self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
+ self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+
+ return x
+
+
+def check_compatibility_with_ddp(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ model = MLP(4).cuda()
+ if rank in [0, 1]:
+ input = torch.arange(0, 16, dtype=torch.float).reshape(4, 4).cuda()
+ elif rank in [2, 3]:
+ input = torch.arange(16, 32, dtype=torch.float).reshape(4, 4).cuda()
+ input_compare = torch.arange(0, 32, dtype=torch.float).reshape(8, 4).cuda()
+ output_compare = model(input_compare)
+ loss_compare = output_compare.sum()
+ loss_compare.backward()
+ grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
+
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1]
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ meta_args = {'x': torch.rand(4, 4).to('meta')}
+ gm, solution = initialize_model(model,
+ meta_args=meta_args,
+ device_mesh=device_mesh,
+ return_solution=True,
+ solver_preference='tp',
+ shard_option='shard_last_axis')
+
+ msg = '| TP strategy combination chosen by auto-parallel solver |'
+ msg_length = len(msg)
+ if rank == 0:
+ print('=' * msg_length)
+ print(msg)
+ print('=' * msg_length)
+ for strategy in solution:
+ print(strategy)
+ print('=' * msg_length)
+
+ dp_process_group = None
+ for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]:
+ if rank in ranks:
+ dp_process_group = process_group_handle
+ assert dp_process_group is not None
+ gm = DDP(gm, process_group=dp_process_group)
+ output = gm(input)
+
+ if rank in (0, 1):
+ assert_close(output, output_compare.narrow(0, 0, 4))
+ else:
+ assert_close(output, output_compare.narrow(0, 4, 4))
+ print(f'output on rank{rank} is correct')
+ loss = output.sum()
+
+ loss.backward()
+
+ if rank in (0, 2):
+ assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8))
+
+ if rank in (1, 3):
+ assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8))
+
+ print(f'gradient on rank{rank} is correct')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_compatibility_with_ddp():
+ world_size = 4
+ run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_compatibility_with_ddp()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
new file mode 100644
index 000000000000..760401c3f2c2
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py
@@ -0,0 +1,114 @@
+import copy
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
+from colossalai.tensor.process_group import ProcessGroup
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
+
+
+class MLP(torch.nn.Module):
+
+ def __init__(self, in_features):
+ super().__init__()
+ self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
+ self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+
+ return x
+
+
+def check_auto_parallel_with_gemini(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ model = MLP(4).half().cuda()
+ if rank in [0, 1]:
+ input = torch.arange(0, 16).reshape(4, 4).half().cuda()
+ elif rank in [2, 3]:
+ input = torch.arange(16, 32).reshape(4, 4).half().cuda()
+ input_compare = torch.arange(0, 32).reshape(8, 4).half().cuda()
+ output_compare = model(input_compare)
+ loss_compare = output_compare.sum()
+ loss_compare.backward()
+ grad_compare = copy.deepcopy(model.linear_1.weight.grad / 2)
+
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ # [[0, 1]
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ meta_args = {'x': torch.rand(4, 4).half().to('meta')}
+ gm, solution = initialize_model(model,
+ meta_args=meta_args,
+ device_mesh=device_mesh,
+ return_solution=True,
+ solver_preference='tp',
+ shard_option='shard_last_axis')
+
+ if rank == 0:
+ msg = '| TP strategy combination chosen by auto-parallel solver |'
+ msg_length = len(msg)
+ print('=' * msg_length)
+ print(msg)
+ print('=' * msg_length)
+ for strategy in solution:
+ print(strategy)
+ print('=' * msg_length)
+
+ dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
+ gemini_config = dict(strict_ddp_mode=False,
+ device=get_current_device(),
+ placement_policy='cpu',
+ pin_memory=True,
+ search_range_mb=128)
+
+ post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
+ gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
+ optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
+ optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
+ output = gm(input)
+ if rank in (0, 1):
+ assert_close(output, output_compare.narrow(0, 0, 4))
+ else:
+ assert_close(output, output_compare.narrow(0, 4, 4))
+ print(f'output on rank{rank} is correct')
+ loss = output.sum()
+ optimizer.zero_grad()
+ optimizer.backward(loss)
+ optimizer.step()
+
+ if rank in (0, 2):
+ assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten())
+
+ if rank in (1, 3):
+ assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten())
+
+ print(f'gradient on rank{rank} is correct')
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_auto_parallel_with_gemini():
+ world_size = 4
+ run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_auto_parallel_with_gemini()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py
deleted file mode 100644
index 96d96a4594c3..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_cost_graph.py
+++ /dev/null
@@ -1,96 +0,0 @@
-from copy import deepcopy
-from pickletools import optimize
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
- self.relu = nn.ReLU()
-
- def forward(self, x):
- x = x * 2
- x = self.conv1(x)
- x = x / 2
- x = self.relu(x)
- return x
-
-
-def test_cost_graph():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 16, 64, 64))
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
-
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
- # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
- # %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
- # return relu
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
-
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- # (x, mul):{(0, 0): 0}
- # (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002}
- # (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0}
- # (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0}
- # (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-
- # construct all node pairs
- all_node_pairs = []
-
- for node in graph.nodes:
- if node.op == 'output':
- continue
- for child in node.users.keys():
- all_node_pairs.append((node, child))
-
- for node_pair in all_node_pairs:
- assert node_pair in cost_graph.edge_costs
-
- # construct merged node pairs
- merged_node_pairs = []
- node_list = list(graph.nodes)
- # add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs
- merged_node_pairs.append((node_list[0], node_list[4]))
- merged_node_pairs.append((node_list[2], node_list[4]))
- merged_node_pairs.append((node_list[3], node_list[5]))
- merged_node_pairs.append((node_list[5], node_list[6]))
- merged_node_pairs.append((node_list[4], node_list[6]))
- merged_node_pairs.append((node_list[6], node_list[-1]))
- cost_graph.simplify_graph()
- for node_pair in all_node_pairs:
- if node_pair in merged_node_pairs:
- assert node_pair in cost_graph.edge_costs
- else:
- assert node_pair not in cost_graph.edge_costs
-
-
-if __name__ == '__main__':
- test_cost_graph()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py
deleted file mode 100644
index 2d3e71551eb2..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_batch_norm_handler.py
+++ /dev/null
@@ -1,118 +0,0 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-
-from colossalai.fx.proxy import ColoProxy
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.device.device_mesh import DeviceMesh
-
-
-class BNModel(nn.Module):
-
- def __init__(self, c):
- super().__init__()
- self.bn = nn.BatchNorm2d(c)
-
- def forward(self, x):
- x = x * 2
- x = self.bn(x)
- return x
-
-
-def test_bn_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 16, 64, 64))
-
- tracer = ColoTracer()
- model = BNModel(16)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
- # return bn
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- # [x, mul, bn, output]
- nodes = [node for node in gm.graph.nodes]
-
- # find the sharding strategies for the input node of the bn node
- # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
- strategies_vector_for_input = StrategiesVector(nodes[1])
- sharding_option = (None, 0, 1)
- for first_sharding_index in sharding_option:
- for second_sharding_index in sharding_option:
- if first_sharding_index is not None and second_sharding_index == first_sharding_index:
- continue
- if first_sharding_index is None:
- first_dim_spec = _DimSpec([])
- else:
- first_dim_spec = _DimSpec([first_sharding_index])
-
- if second_sharding_index is None:
- second_dim_spec = _DimSpec([])
- else:
- second_dim_spec = _DimSpec([second_sharding_index])
-
- replica_dim_spec = _DimSpec([])
- sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
- sharding_spec = ShardingSpec(device_mesh=device_mesh,
- entire_shape=entire_shape,
- sharding_sequence=sharding_sequence)
- strategy_name = str(sharding_spec.sharding_sequence)
- sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
- strategies_vector_for_input.append(sharding_strategy)
- setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
-
- # generate bn strategy
- strategies_vector = StrategiesVector(node=nodes[2])
- bn_handler = BatchNormHandler(
- node=nodes[2],
- device_mesh=device_mesh,
- strategies_vector=strategies_vector,
- )
- bn_handler.register_strategy()
- # ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
- # 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
- strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
-
- # RS = RS x S and strategies based on it, such as
- # SS = RS x S
- assert 'RS0 = RS0 x S0' in strategy_name_list
- assert 'S1S0 = RS0 x S0' in strategy_name_list
- assert 'RS1 = RS1 x S1' in strategy_name_list
- assert 'S0S1 = RS1 x S1' in strategy_name_list
-
- # RR = RR x R and strategies based on it, such as
- # SR = SR x R
- assert 'RR = RR x R' in strategy_name_list
- assert 'S0R = RR x R' in strategy_name_list
- assert 'S1R = RR x R' in strategy_name_list
- assert 'S01R = RR x R' in strategy_name_list
-
- # RS01 = RS01 x S01
- assert 'RS01 = RS01 x S01' in strategy_name_list
-
- # SR = SR x R WITH SYNC_BN
- assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
- assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
-
- # SS = SS x S WITH SYNC_BN
- assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
- assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
-
- # S01R = S01R x R WITH SYNC_BN
- assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
-
-
-if __name__ == '__main__':
- test_bn_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py
deleted file mode 100644
index 7adc211cfc07..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_handler.py
+++ /dev/null
@@ -1,75 +0,0 @@
-from cProfile import run
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
- self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
-
- def forward(self, x):
- x1 = self.conv1(x)
- x2 = x1 + 1
- x1 = torch.reshape(x1, [1, -1, 64, 1])
- x3 = self.conv2(x1)
- x3 = torch.reshape(x3, [4, 1, 64, -1])
- x = x1 + x3
-
- return x
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
- # %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
- # %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
- # %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
- # %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
- # %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
- # return add_1
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- # [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
- nodes = [node for node in gm.graph.nodes]
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- strategy_map = strategies_constructor.strategy_map
- # check a tensor add with a scalar case
- conv1_strategies = strategy_map[nodes[1]]
- add_strategies = strategy_map[nodes[2]]
- add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
- for strategy in conv1_strategies:
- assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list
-
- # check two tensors element-wise add case
- add_1_strategies = strategy_map[nodes[6]]
- assert len(add_1_strategies) == 25
-
-
-if __name__ == '__main__':
- test_conv_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py
deleted file mode 100644
index 426d179f10d5..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_bcast_matmul.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class MatmulModel(nn.Module):
-
- def __init__(self):
- super().__init__()
-
- def forward(self, x1, x2):
- x = torch.matmul(x1, x2)
-
- return x
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_conv_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-
- tracer = ColoTracer()
- model = MatmulModel()
- input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')}
- # graph():
- # %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
- # %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
- # %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {})
- # return matmul
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- # [x1, x2, matmul, output]
- nodes = [node for node in gm.graph.nodes]
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- strategy_map = strategies_constructor.strategy_map
- matmul_strategies = strategy_map[nodes[2]]
- assert len(matmul_strategies) == 30
-
-
-if __name__ == '__main__':
- test_conv_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py
deleted file mode 100644
index 9342e06a040a..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_conv_handler.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.proxy import ColoProxy
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
-
- def forward(self, x):
- x = x * 2
- x = self.conv(x)
- return x
-
-
-def test_conv_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 16, 64, 64))
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %conv_weight : [#users=1] = get_attr[target=conv.weight]
- # %conv_bias : [#users=1] = get_attr[target=conv.bias]
- # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
- # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
- # return add
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- conv_node = list(graph.nodes)[4]
- # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
- strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector]
-
- # SS = SR x RS
- assert 'S0S1 = S0R x RS1' in strategy_name_list
- assert 'S1S0 = S1R x RS0' in strategy_name_list
-
- # SR = SS x SR
- assert 'S0R = S0S1 x S1R' in strategy_name_list
- assert 'S1R = S1S0 x S0R' in strategy_name_list
-
- # RS = RS x SS
- assert 'RS0 = RS1 x S1S0' in strategy_name_list
- assert 'RS1 = RS0 x S0S1' in strategy_name_list
-
- # RS = RR x RS
- assert 'RS0 = RR x RS0' in strategy_name_list
- assert 'RS1 = RR x RS1' in strategy_name_list
-
- # RR= RR x RR
- assert 'RR = RR x RR' in strategy_name_list
-
- # SR = SR x RR
- assert 'S0R = S0R x RR' in strategy_name_list
- assert 'S1R = S1R x RR' in strategy_name_list
- assert 'S01R = S01R x RR' in strategy_name_list
-
- # RR = RS x SR
- assert 'RR = RS0 x S0R' in strategy_name_list
- assert 'RR = RS1 x S1R' in strategy_name_list
- assert 'RR = RS01 x S01R' in strategy_name_list
-
-
-if __name__ == '__main__':
- test_conv_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py
deleted file mode 100644
index 0a2dba1611f0..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_dot_handler.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.proxy import ColoProxy
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-
-
-class LinearModel(nn.Module):
-
- def __init__(self, in_features, out_features):
- super().__init__()
- self.linear = nn.Linear(in_features, out_features)
-
- def forward(self, x):
- x = x * 2
- x = self.linear(x)
- return x
-
-
-@pytest.mark.skip('F.linear is not supported in deprecated handler')
-def test_dot_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 8))
-
- tracer = ColoTracer()
- model = LinearModel(8, 16)
- input_sample = {'x': torch.rand(4, 8).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %linear_weight : [#users=1] = get_attr[target=linear.weight]
- # %linear_bias : [#users=1] = get_attr[target=linear.bias]
- # %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
- # return add
- graph = tracer.trace(root=model, meta_args=input_sample)
-
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- linear_node = list(graph.nodes)[4]
-
- # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
- strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector]
-
- # SS = SR x RS
- assert 'S0S1 = S0R x RS1' in strategy_name_list
- assert 'S1S0 = S1R x RS0' in strategy_name_list
-
- # SR = SS x SR
- assert 'S0R = S0S1 x S1R' in strategy_name_list
- assert 'S1R = S1S0 x S0R' in strategy_name_list
-
- # RS = RS x SS
- assert 'RS0 = RS1 x S1S0' in strategy_name_list
- assert 'RS1 = RS0 x S0S1' in strategy_name_list
-
- # RR = RS x SR
- assert 'RR = RS0 x S0R' in strategy_name_list
- assert 'RR = RS1 x S1R' in strategy_name_list
-
- # RS= RR x RS
- assert 'RS0 = RR x RS0' in strategy_name_list
- assert 'RS1 = RR x RS1' in strategy_name_list
-
-
-if __name__ == '__main__':
- test_dot_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py
deleted file mode 100644
index 40e227cb53eb..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_layer_norm_handler.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy
-
-from colossalai.fx.proxy import ColoProxy
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.device.device_mesh import DeviceMesh
-
-
-class LNModel(nn.Module):
-
- def __init__(self, c):
- super().__init__()
- self.ln = nn.LayerNorm(c)
-
- def forward(self, x):
- x = x * 2
- x = self.ln(x)
- return x
-
-
-def test_bn_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 4, 128))
-
- tracer = ColoTracer()
- model = LNModel(128)
- input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
- # return ln
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- # [x, mul, ln, output]
- nodes = [node for node in gm.graph.nodes]
- sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
- sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
- strategies_vector_for_input = StrategiesVector(nodes[1])
- strategies_vector_for_input.append(sharding_strategy_for_input)
- setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
-
- # generate bn strategy
- strategies_vector = StrategiesVector(node=nodes[2])
- ln_handler = LayerNormHandler(
- node=nodes[2],
- device_mesh=device_mesh,
- strategies_vector=strategies_vector,
- )
- ln_handler.register_strategy()
- # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
- # '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
- strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
-
- assert len(strategy_name_list) == 9
-
-
-if __name__ == '__main__':
- test_bn_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py
deleted file mode 100644
index ac9df4cd825b..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_reshape_handler.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
-
- def forward(self, x):
- x = self.conv(x)
- x = torch.flatten(x)
- return x
-
-
-def test_conv_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv_weight : [#users=1] = get_attr[target=conv.weight]
- # %conv_bias : [#users=1] = get_attr[target=conv.bias]
- # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
- # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
- # %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {})
- # return flatten
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- # [x, conv, flatten, output]
- nodes = [node for node in gm.graph.nodes]
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- strategy_map = strategies_constructor.strategy_map
- add_strategies = strategy_map[nodes[5]]
- flatten_strategies = strategy_map[nodes[6]]
- flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
- for strategy in add_strategies:
- assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
-
-
-if __name__ == '__main__':
- test_conv_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
deleted file mode 100644
index 294a59fc8548..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_op_handler/test_deprecated_where_handler.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, dim_in, dim_out):
- super().__init__()
- self.dim_in = dim_in
- self.dim_out = dim_out
-
- def forward(self, condition, x, y):
- output = torch.where(condition, x, y)
-
- return output
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_where_handler():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {
- 'condition': torch.rand(16, 32).to('meta'),
- 'x': torch.rand(16, 32).to('meta'),
- 'y': torch.rand(16, 32).to('meta')
- }
- # graph():
- # %condition : torch.Tensor [#users=1] = placeholder[target=condition]
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %y : torch.Tensor [#users=1] = placeholder[target=y]
- # %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
- # return where
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
-
- # [condition, x, y, where, output]
- nodes = [node for node in gm.graph.nodes]
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- strategies_constructor.build_strategies_and_cost()
- strategy_map = strategies_constructor.strategy_map
- # check a tensor add with a scalar case
- where_node = strategy_map[nodes[3]]
- # ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]',
- # '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]',
- # '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]']
- assert len(where_node) == 9
-
-
-if __name__ == '__main__':
- test_where_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
deleted file mode 100644
index 3286b325c8ab..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_shape_consistency_pass.py
+++ /dev/null
@@ -1,86 +0,0 @@
-from functools import partial
-import pytest
-import torch
-import torch.multiprocessing as mp
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-from colossalai.initialize import launch
-from colossalai.utils import free_port
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.logging import disable_existing_loggers
-from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
-from colossalai.auto_parallel.tensor_shard.deprecated import Solver
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
-
- def forward(self, x):
- x = self.conv(x)
- return x
-
-
-def check_apply(rank, world_size, port):
- disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- input = torch.rand(4, 4, 4, 4).cuda()
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- entire_shape = torch.Size((4, 4, 8, 8))
-
- tracer = ColoTracer()
- model = ConvModel(4, 4).cuda()
- origin_output = model(input)
- input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
- # return conv
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
- shape_consistency_pass(gm)
- gm.recompile()
- nodes = [node for node in gm.graph.nodes]
- # TODO: wrap the gm to avoid the influence of the user training code
- output = gm(input, sharding_spec_dict, origin_spec_dict)
- assert output.equal(origin_output)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_apply():
- world_size = 4
- run_func = partial(check_apply, world_size=world_size, port=free_port())
- mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
- test_apply()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
deleted file mode 100644
index baa70727a2e5..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated import Solver
-from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
- self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
- self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
- self.relu = nn.ReLU()
-
- def forward(self, x):
- x = x * 2
- x = self.conv1(x)
- x = self.conv2(x)
- x = x / 2
- x = self.conv3(x)
- x = self.relu(x)
- return x
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_solver():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- shape_consistency_manager = ShapeConsistencyManager()
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
-
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
- # %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
- # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
- # %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
- # %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
- # return relu
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
-
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
-
- # [ 0 0 13 13 13 13 13 0]
- strategies_combination_list = ret[0]
- assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
-
-
-if __name__ == '__main__':
- test_solver()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
deleted file mode 100644
index e90d6b15308c..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_gpt.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.tensor_shard.deprecated import Solver
-import transformers
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
-from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-BATCH_SIZE = 8
-SEQ_LENGHT = 8
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_cost_graph():
- physical_mesh_id = torch.arange(0, 8)
- mesh_shape = (2, 4)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- shape_consistency_manager = ShapeConsistencyManager()
-
- tracer = ColoTracer()
- config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12)
- model = transformers.GPT2LMHeadModel(config=config)
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
- token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
- meta_args = {k: v.to('meta') for k, v in kwargs.items()}
-
- graph = tracer.trace(root=model, meta_args=meta_args)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- print(graph)
- strategies_constructor.build_strategies_and_cost()
- for check_node, strategies_vector in strategies_constructor.strategy_map.items():
- print(check_node, len(strategies_vector))
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
-
- ret = solver.call_solver_serialized_args()
- print(ret)
- strategies_list = list(ret[0])
- print(strategies_list)
- computation_cost = 0
- communication_cost = 0
- memory_cost = 0
- nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
- for index, node in enumerate(nodes):
- print(node.name, node.strategies_vector[strategies_list[index]].name)
- computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
- communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
- node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
- if isinstance(node_memory_cost, tuple):
- node_memory_cost = node_memory_cost[0]
- memory_cost += node_memory_cost
-
- print(f'computation cost is {computation_cost}')
- print(f'communication cost is {communication_cost}')
- print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
- test_cost_graph()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
deleted file mode 100644
index 415156ed6545..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_solver_with_mlp.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import torch
-from torch.fx import GraphModule
-import torch.nn as nn
-import pytest
-
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
-from copy import deepcopy
-from colossalai.auto_parallel.tensor_shard.deprecated import Solver
-from torchvision.models import resnet34, resnet50
-from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
-from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-
-
-class MLP(torch.nn.Module):
-
- def __init__(self, dim: int):
- super().__init__()
- self.linear1 = torch.nn.Linear(dim, dim * 4)
- self.linear2 = torch.nn.Linear(dim * 4, dim)
- self.dropout = torch.nn.Dropout(0)
- self.relu = torch.nn.ReLU()
-
- def forward(self, x):
- x = self.linear1(x)
- x = self.dropout(x)
- x = self.relu(x)
- x = self.linear2(x)
- return x
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-def test_cost_graph():
- physical_mesh_id = torch.arange(0, 8)
- mesh_shape = (2, 4)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- shape_consistency_manager = ShapeConsistencyManager()
-
- tracer = ColoTracer()
- model = MLP(32)
-
- input_sample = {'x': torch.rand(16, 32).to('meta')}
-
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
- # %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
- # %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
- # %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
- # return linear2
- graph = tracer.trace(root=model, meta_args=input_sample)
-
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- # # megatron mode if no memory constraints
- # solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- # all sharding on out feature dim if memory budget is not sufficient for megatron mode
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
-
- ret = solver.call_solver_serialized_args()
- strategies_list = list(ret[0])
- computation_cost = 0
- communication_cost = 0
- memory_cost = 0
- for index, node in enumerate(graph.nodes):
- print(node.name, node.strategies_vector[strategies_list[index]].name)
- computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
- communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
- node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
- if isinstance(node_memory_cost, tuple):
- node_memory_cost = node_memory_cost[0]
- memory_cost += node_memory_cost
-
- print(f'computation cost is {computation_cost}')
- print(f'communication cost is {communication_cost}')
- print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
- test_cost_graph()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py b/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py
deleted file mode 100644
index 9be1a5d963a9..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_deprecated/test_deprecated_strategies_constructor.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from copy import deepcopy
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST
-from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.proxy import ColoProxy
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
-
-
-class ConvModel(nn.Module):
-
- def __init__(self, c_in, c_out):
- super().__init__()
- self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
-
- def forward(self, x):
- x = x * 2
- x = self.conv(x)
- return x
-
-
-def test_strategies_constructor():
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
- entire_shape = torch.Size((4, 16, 64, 64))
-
- tracer = ColoTracer()
- model = ConvModel(16, 32)
- input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
- # %conv_weight : [#users=1] = get_attr[target=conv.weight]
- # %conv_bias : [#users=1] = get_attr[target=conv.bias]
- # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
- # %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
- # %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
- # return add
- graph = tracer.trace(root=model, meta_args=input_sample)
- print(graph)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
-
- solver_options = SolverOptions(fast=True)
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-
- assert strategies_constructor.leaf_strategies == []
- assert strategies_constructor.strategy_map == {}
- strategies_constructor.build_strategies_and_cost()
-
- # check leaf_strategies
-
- # In fast mode, placeholder node only has replica strategy.
- assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
-
- # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
- assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
-
- # Third node is conv.
- conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
- for strategy in strategies_constructor.leaf_strategies[4]:
- conv_check_list.remove(strategy.name)
- assert len(conv_check_list) == 0
-
- # In fast mode, output node only has replica strategy.
- assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output'
-
- # check strategy_map
-
- nodes = [node for node in graph.nodes]
- # In fast mode, placeholder node only has replica strategy.
- x = nodes[0]
- assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
-
- # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
- mul = nodes[1]
- assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
-
- # fifth node is conv.
- conv = nodes[4]
- conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
- for strategy in strategies_constructor.strategy_map[conv]:
- conv_check_list.remove(strategy.name)
- assert len(conv_check_list) == 0
-
- # In fast mode, output node only has replica strategy.
- output = nodes[-1]
- assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
-
-
-if __name__ == '__main__':
- test_strategies_constructor()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py
new file mode 100644
index 000000000000..90301521f207
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py
@@ -0,0 +1,110 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.fx import GraphModule
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+NUM_REPEAT_BLOCKS = 4
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 384
+
+
+class RepeatBlock(nn.Module):
+
+ def __init__(self, intermediate_size, hidden_size):
+ super().__init__()
+ self.c_fc = Conv1D(intermediate_size, hidden_size)
+ self.c_proj = Conv1D(hidden_size, intermediate_size)
+ self.act = torch.nn.ReLU()
+
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+
+ return hidden_states
+
+
+class RepeatModel(nn.Module):
+
+ def __init__(self, intermediate_size, hidden_size, num_layers):
+ super().__init__()
+ self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)])
+
+ def forward(self, x):
+
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class NonRepeatBlock(nn.Module):
+
+ def __init__(self, intermediate_size, hidden_size, layer_index):
+ super().__init__()
+ intermediate_size //= (layer_index + 1)
+ self.c_fc = Conv1D(intermediate_size, hidden_size)
+ self.c_proj = Conv1D(hidden_size, intermediate_size)
+ self.act = torch.nn.ReLU()
+
+ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+ hidden_states = self.c_fc(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.c_proj(hidden_states)
+
+ return hidden_states
+
+
+class NonRepeatModel(nn.Module):
+
+ def __init__(self, intermediate_size, hidden_size, num_layers):
+ super().__init__()
+ self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)])
+
+ def forward(self, x):
+
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('model_cls', [RepeatModel, NonRepeatModel])
+def test_repeat_blocks(model_cls):
+
+ model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS)
+
+ tracer = ColoTracer()
+ input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')}
+ graph = tracer.trace(root=model, meta_args=input_sample)
+
+ gm = GraphModule(model, graph, model.__class__.__name__)
+ gm.recompile()
+
+ node_list = list(graph.nodes)
+ root_module = graph.owning_module
+ common_blocks = find_repeat_blocks(node_list, root_module, common_length_threshold=10)
+
+ total_num_nodes = len(list(graph.nodes))
+ # remove the input placeholder node and the output node
+ num_repeat_nodes_per_block = (total_num_nodes - 2) // NUM_REPEAT_BLOCKS
+ for common_block in common_blocks:
+ print(common_block)
+ if model_cls == RepeatModel:
+ assert len(common_blocks) == NUM_REPEAT_BLOCKS
+ assert len(common_blocks[0]) == num_repeat_nodes_per_block
+ elif model_cls == NonRepeatModel:
+ assert len(common_blocks) == 0
+
+
+if __name__ == '__main__':
+ test_repeat_blocks()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
deleted file mode 100644
index 0979d8353ee7..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_gpt2_performance.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import copy
-import random
-from functools import partial
-from time import time
-from typing import Dict, Optional, Tuple, Union
-
-import numpy as np
-import psutil
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from torch.profiler import ProfilerActivity, profile, record_function, schedule, tensorboard_trace_handler
-
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.initialize import launch, launch_from_torch
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
-from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2LMHeadModel, GPTLMLoss
-
-BATCH_SIZE = 32
-SEQ_LENGTH = 256
-HIDDEN_DIM = 16384
-NUM_HEADS = 128
-NUM_LAYERS = 4
-VOCAB_SIZE = 50257
-NUM_STEPS = 10
-FP16 = True
-
-
-def get_cpu_mem():
- return psutil.Process().memory_info().rss / 1024**2
-
-
-def get_gpu_mem():
- return torch.cuda.memory_allocated() / 1024**2
-
-
-def get_mem_info(prefix=''):
- return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
-
-
-def get_tflops(model_numel, batch_size, seq_len, step_time):
- # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
- return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
-
-
-# Randomly Generated Data
-def get_data(batch_size, seq_len, vocab_size):
- input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
- attention_mask = torch.ones_like(input_ids)
- return input_ids, attention_mask
-
-
-def main():
- disable_existing_loggers()
- launch_from_torch(config={})
- logger = get_dist_logger()
- config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
- if FP16:
- model = GPT2LMHeadModel(config=config).half().to('cuda')
- else:
- model = GPT2LMHeadModel(config=config).to('cuda')
- global_numel = sum([p.numel() for p in model.parameters()])
-
- meta_input_sample = {
- 'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
- 'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
- }
-
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- gm = initialize_model(model, meta_input_sample, device_mesh)
-
- # build criterion
- criterion = GPTLMLoss()
-
- optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
- logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
- get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
- torch.cuda.synchronize()
- model.train()
- # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
- # schedule=schedule(wait=1, warmup=2, active=2),
- # on_trace_ready=tensorboard_trace_handler(f'log/dummy_data/bs128_seq128_new'),
- # record_shapes=True,
- # profile_memory=True) as prof:
- # with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as prof:
- for n in range(10):
- # we just use randomly generated data here
- input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
- optimizer.zero_grad()
- start = time()
- outputs = gm(input_ids, attn_mask)
- loss = criterion(outputs, input_ids)
- loss.backward()
- optimizer.step()
- # prof.step()
- torch.cuda.synchronize()
- step_time = time() - start
- logger.info(
- f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
- ranks=[0])
- # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
- torch.cuda.synchronize()
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
index c7f9988f1824..ebeef9870fe9 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
@@ -1,32 +1,27 @@
import copy
import random
from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Dict
import numpy as np
import pytest
import torch
import torch.multiprocessing as mp
-import torch.nn as nn
import transformers
from torch.fx import GraphModule
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
+from colossalai.auto_parallel.tensor_shard.initialize import (
+ ModuleWrapper,
+ build_strategy_constructor,
+ solve_solution,
+ transform_to_sharded_model,
)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_global
+from colossalai.tensor.shape_consistency import to_global
from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
@@ -49,6 +44,7 @@ def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, tor
best_sharding_spec_dict: Dict[str, ShardingSpec]):
for name, param in module.named_parameters():
param_grad = param.grad
+ name = name.replace('module.', '')
origin_param_grad = origin_param_dict[name].grad
atoms = name.split('.')
new_name = '_'.join(atoms)
@@ -115,30 +111,17 @@ def check_attention_layer(rank, model_cls, world_size, port):
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- shape_consistency_manager = ShapeConsistencyManager()
-
tracer = ColoTracer()
graph = tracer.trace(root=model, meta_args=meta_input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
- solver_options = SolverOptions()
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
- ret = solver.call_solver_serialized_args()
-
- solution = list(ret[0])
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
- gm, solution, device_mesh, strategies_constructor)
- gm = runtime_apply_pass(gm)
- gm.recompile()
+ strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
+ solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
+ gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
+ gm = ModuleWrapper(gm, *sharding_spec_dicts)
+
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
best_sharding_spec_dict = {}
for index, node in enumerate(nodes):
@@ -149,7 +132,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
origin_output = test_model(*test_input_sample)
torch.cuda.set_rng_state(cuda_rng_state)
torch.set_rng_state(cpu_rng_state)
- output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = gm(*input_sample)
assert_close(output, origin_output, rtol=1e-03, atol=1e-03)
#*******************backward starting*******************
@@ -174,16 +157,15 @@ def check_attention_layer(rank, model_cls, world_size, port):
#*******************strategy selected*******************
if rank == 0:
print("*******************strategy selected*******************")
- strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(nodes):
- print(node.name, node.strategies_vector[strategies_list[index]].name)
- computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
- communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
- node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+ print(node.name, node.strategies_vector[solution[index]].name)
+ computation_cost += node.strategies_vector[solution[index]].compute_cost.total
+ communication_cost += node.strategies_vector[solution[index]].communication_cost.total
+ node_memory_cost = node.strategies_vector[solution[index]].memory_cost.total
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost.activation + node_memory_cost.parameter
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
index 26ad0d3a08a7..4adb4fbaf047 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
@@ -4,13 +4,8 @@
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -20,13 +15,13 @@
BATCH_SIZE = 1
SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+HIDDEN_DIM = 384
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
def test_self_attention_block(model_cls):
- config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+ config = transformers.GPT2Config(n_position=64, n_layer=12, n_head=16, n_embd=HIDDEN_DIM)
if model_cls == GPT2MLP:
model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
else:
@@ -59,15 +54,13 @@ def test_self_attention_block(model_cls):
gm = GraphModule(model, graph, model.__class__.__name__)
print(gm.graph)
gm.recompile()
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+ solver = Solver(gm.graph, strategies_constructor, cost_graph, memory_budget=-1)
ret = solver.call_solver_serialized_args()
strategies_list = solver.last_s_val
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
index f468b1ab2113..e41ac4fa690b 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_activation_metainfo.py
@@ -5,6 +5,8 @@
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai.auto_parallel.meta_profiler import meta_register
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
@@ -12,50 +14,58 @@
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
-
-
-def _ReLU_module_mem_test(rank, world_size, port):
- """This function is for ReLU memory test
- Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
-
- Args:
- Args:
- rank: device rank
- bias: indicate whether conv module need bias
- world_size: number of devices
- port: port for initializing process group
- """
- disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- model = nn.Sequential(nn.ReLU()).cuda()
- input = torch.rand(4, 128, 64, 64).cuda()
- input.requires_grad = True
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- # index of target node in computation graph
- node_index = 1
- # total number of target node strategies
- strategy_number = 1
- mem_test_for_node_strategy(rank=rank,
- model=model,
- device_mesh=device_mesh,
- node_index=node_index,
- strategy_number=strategy_number,
- input_args=[input],
- meta_arg_names=['input'])
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_ReLU_meta_concrete_info_match():
- world_size = 4
- run_func_module = partial(_ReLU_module_mem_test, world_size=world_size, port=free_port())
- mp.spawn(run_func_module, nprocs=world_size)
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@parameterize('func', [
+ torch.nn.functional.softmax,
+ torch.nn.functional.relu,
+ torch.tanh,
+ torch.nn.functional.dropout,
+])
+def test_activation_meta_info(func):
+ meta_func = meta_register.get(func)
+ # construct meta tensors
+ input_tensor = torch.rand(256, 1024, device="meta")
+ output_tensor = torch.rand(256, 1024, device="meta")
+ softmax_dim = 0
+
+ # construct operation data
+ input_data = OperationData(name='input', type=OperationDataType.ARG, data=input_tensor)
+ output_data = OperationData(name='output', type=OperationDataType.OUTPUT, data=output_tensor)
+ softmax_dim_data = OperationData(name='softmax_dim', type=OperationDataType.ARG, data=softmax_dim)
+
+ # construct args and kwargs
+ args = [input_data, softmax_dim_data, output_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ input_real_tensor = torch.rand(256, 1024, device="cuda")
+
+ input_real_tensor.requires_grad = True
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = func(input_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = torch.rand_like(output_real_tensor)
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
+ bwd_allocated, bwd_peak)
if __name__ == '__main__':
- test_ReLU_meta_concrete_info_match()
+ test_activation_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py
deleted file mode 100644
index 826c746668da..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_batchnorm_metainfo.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from functools import partial
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy
-
-
-def _batchnorm_module_mem_test(rank, world_size, port):
- """This function is for batchnorm memory test
- Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
-
- Args:
- rank: device rank
- bias: indicate whether conv module need bias
- world_size: number of devices
- port: port for initializing process group
- """
- disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- model = nn.Sequential(nn.BatchNorm2d(128)).cuda()
- input = torch.rand(4, 128, 64, 64).cuda()
- input.requires_grad = True
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- # index of target node in computation graph
- node_index = 1
- # total number of target node strategies
- strategy_number = 9
- mem_test_for_node_strategy(rank=rank,
- model=model,
- device_mesh=device_mesh,
- node_index=node_index,
- strategy_number=strategy_number,
- input_args=[input],
- meta_arg_names=['input'])
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_batchnorm_meta_concrete_info_match():
- world_size = 4
- run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
- mp.spawn(run_func_module, nprocs=world_size)
-
-
-if __name__ == '__main__':
- test_batchnorm_meta_concrete_info_match()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py
new file mode 100644
index 000000000000..2fb1306546ca
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_embedding_metainfo.py
@@ -0,0 +1,77 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+ MemoryCost,
+ OperationData,
+ OperationDataType,
+ ShardingStrategy,
+ StrategiesVector,
+ TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
+
+if torch.__version__ >= '1.12.0':
+ from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+def test_embedding_meta_info():
+ meta_func = meta_register.get(torch.nn.Embedding)
+
+ # construct meta tensors
+ input_tensor = torch.randint(0, 50256, (8, 1024), device="meta")
+ weight_tensor = torch.rand(50257, 1024, device="meta")
+ output_tensor = torch.rand(8, 1024, 1024, device="meta")
+
+ # construct operation data
+ input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor)
+
+ weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor)
+
+ output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor)
+
+ # construct args and kwargs
+ args = [input_data, weight_data, output_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ input_real_tensor = torch.randint(0, 50256, (8, 1024), device="cuda")
+ embedding_module = torch.nn.Embedding(50257, 1024).cuda()
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = embedding_module(input_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = torch.rand_like(output_real_tensor)
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
+ bwd_allocated, bwd_peak)
+
+
+if __name__ == '__main__':
+ test_embedding_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py
new file mode 100644
index 000000000000..fd29c63fb522
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_matmul_metainfo.py
@@ -0,0 +1,110 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+ MemoryCost,
+ OperationData,
+ OperationDataType,
+ ShardingStrategy,
+ StrategiesVector,
+ TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
+
+if torch.__version__ >= '1.12.0':
+ from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+@parameterize(
+ 'tensor_shapes',
+ [
+ [[128], [128]], # dot product
+ [[64, 128], [128]], # mat-vec
+ [[128], [128, 64]], # vec-mat
+ [[64, 64, 128], [128]], # batched mat-vec
+ [[128], [64, 128, 64]], # vec-batched mat
+ [[64, 128], [128, 192]], # mat-mat
+ [[64, 64, 128], [128, 192]], # batched mat-mat
+ [[64, 128], [64, 128, 192]], # mat-batched mat
+ [[64, 64, 128], [64, 128, 192]], # batched mat-batched mat (matched batch dims)
+ [[64, 1, 64, 128], [64, 128, 192]], # batched mat-batched mat (unmatched batch dims)
+ ])
+def test_matmul_function_meta_info(tensor_shapes):
+ meta_func = meta_register.get(torch.matmul)
+
+ # construct meta tensors
+ input_tensor = torch.rand(*tensor_shapes[0], device="meta")
+ other_tensor = torch.rand(*tensor_shapes[1], device="meta")
+ output_tensor = torch.matmul(input_tensor, other_tensor)
+
+ # construct operation data
+ input_data = OperationData(
+ name="input",
+ data=input_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=input_tensor.shape,
+ )
+ other_data = OperationData(
+ name="other",
+ data=other_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=other_tensor.shape,
+ )
+ output_data = OperationData(
+ name="output",
+ data=output_tensor,
+ type=OperationDataType.OUTPUT,
+ logical_shape=output_tensor.shape,
+ )
+
+ # construct args and kwargs
+ args = [input_data, other_data, output_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ input_real_tensor = torch.rand(*tensor_shapes[0], device="cuda:0")
+ other_real_tensor = torch.rand(*tensor_shapes[1], device="cuda:0")
+
+ input_real_tensor.requires_grad = True
+ other_real_tensor.requires_grad = True
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = torch.matmul(input_real_tensor, other_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = torch.rand_like(output_real_tensor)
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ compute_cost: TrainCycleItem
+ memory_cost: TrainCycleItem
+
+ print_results([input_real_tensor, other_real_tensor], [output_real_tensor], compute_cost, memory_cost,
+ fwd_allocated, fwd_peak, bwd_allocated, bwd_peak)
+
+
+if __name__ == '__main__':
+ test_matmul_function_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
new file mode 100644
index 000000000000..9d3ab9c82670
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_norm_metainfo.py
@@ -0,0 +1,131 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+ MemoryCost,
+ OperationData,
+ OperationDataType,
+ ShardingStrategy,
+ StrategiesVector,
+ TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import mem_test_for_node_strategy, print_results
+
+if torch.__version__ >= '1.12.0':
+ from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+def _batchnorm_module_mem_test(rank, world_size, port):
+ """This function is for batchnorm memory test
+ Test and print real memory cost and estimated, this test will not be executed except with the tag AUTO_PARALLEL
+
+ Args:
+ rank: device rank
+ bias: indicate whether conv module need bias
+ world_size: number of devices
+ port: port for initializing process group
+ """
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ model = nn.Sequential(nn.BatchNorm2d(128)).cuda()
+ input = torch.rand(4, 128, 64, 64).cuda()
+ input.requires_grad = True
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+ # index of target node in computation graph
+ node_index = 1
+ # total number of target node strategies
+ strategy_number = 9
+ mem_test_for_node_strategy(rank=rank,
+ model=model,
+ device_mesh=device_mesh,
+ node_index=node_index,
+ strategy_number=strategy_number,
+ input_args=[input],
+ meta_arg_names=['input'])
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_batchnorm_meta_concrete_info_match():
+ world_size = 4
+ run_func_module = partial(_batchnorm_module_mem_test, world_size=world_size, port=free_port())
+ mp.spawn(run_func_module, nprocs=world_size)
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='need pytorch 1.12.0 or higher for aten level operations')
+@parameterize('tensor_shape', [
+ [256, 1024],
+ [1024, 256],
+])
+def test_layernorm_meta_info(tensor_shape):
+ meta_func = meta_register.get(torch.nn.LayerNorm)
+
+ # construct input
+ input_tensor = torch.rand(*tensor_shape, device="meta")
+ output_tensor = torch.rand(*tensor_shape, device="meta")
+ weight_tensor = torch.rand(tensor_shape[1], device="meta")
+ bias_tensor = torch.rand(tensor_shape[1], device="meta")
+
+ # construct operation data
+ input_data = OperationData(name="input", type=OperationDataType.ARG, data=input_tensor)
+
+ output_data = OperationData(name="output", type=OperationDataType.OUTPUT, data=output_tensor)
+
+ weight_data = OperationData(name="weight", type=OperationDataType.PARAM, data=weight_tensor)
+
+ bias_data = OperationData(name="bias", type=OperationDataType.PARAM, data=bias_tensor)
+
+ # construct args and kwargs
+ args = [input_data, output_data, weight_data, bias_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ input_real_tensor = torch.rand(*tensor_shape, device="cuda:0")
+
+ input_real_tensor.requires_grad = True
+
+ ln_module = torch.nn.LayerNorm(tensor_shape[1]).cuda()
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = ln_module(input_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = torch.rand_like(output_real_tensor)
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ compute_cost: TrainCycleItem
+ memory_cost: TrainCycleItem
+
+ print_results([input_real_tensor], [output_real_tensor], compute_cost, memory_cost, fwd_allocated, fwd_peak,
+ bwd_allocated, bwd_peak)
+
+
+if __name__ == '__main__':
+ test_batchnorm_meta_concrete_info_match()
+ test_layernorm_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py
new file mode 100644
index 000000000000..a0ab66fdc060
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_tensor_metainfo.py
@@ -0,0 +1,103 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+ MemoryCost,
+ OperationData,
+ OperationDataType,
+ ShardingStrategy,
+ StrategiesVector,
+ TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
+
+if torch.__version__ >= '1.12.0':
+ from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+class SplitModule(nn.Module):
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def forward(self, x):
+ return x.split(512, dim=0)
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+def test_tensor_meta_info():
+ """test tensor related meta information
+ We will just use torch.Tensor.split for the test
+ """
+ meta_func = meta_register.get(torch.Tensor.split)
+
+ # construct meta tensors
+ input_tensor = torch.rand(1024, 1024, device="meta")
+ output_tensor = input_tensor.split(512, dim=0)
+
+ # construct operation data
+ input_data = OperationData(
+ name="input",
+ data=input_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=input_tensor.shape,
+ )
+ output_data = OperationData(
+ name="output",
+ data=output_tensor,
+ type=OperationDataType.OUTPUT,
+ logical_shape=input_tensor.shape,
+ )
+ split_info_data = OperationData(
+ name='split_info',
+ type=OperationDataType.ARG,
+ data=0,
+ logical_shape=None,
+ )
+
+ # construct args
+ args = [input_data, output_data, split_info_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ model = SplitModule()
+ input_real_tensor = torch.rand(1024, 1024).cuda()
+
+ input_real_tensor.requires_grad = True
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = model(input_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = [torch.rand_like(tensor) for tensor in output_real_tensor]
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ print_results([input_real_tensor], output_real_tensor, compute_cost, memory_cost, fwd_allocated, fwd_peak,
+ bwd_allocated, bwd_peak)
+
+
+if __name__ == "__main__":
+ test_tensor_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py
new file mode 100644
index 000000000000..20156f9ab4d5
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/test_where_metainfo.py
@@ -0,0 +1,104 @@
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearModuleHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+ MemoryCost,
+ OperationData,
+ OperationDataType,
+ ShardingStrategy,
+ StrategiesVector,
+ TrainCycleItem,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_metainfo.utils import print_results
+
+if torch.__version__ >= '1.12.0':
+ from colossalai.auto_parallel.meta_profiler import MetaInfo, meta_register
+
+
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
+def test_where_meta_info():
+ meta_func = meta_register.get(torch.where)
+
+ # construct meta tensors
+ condition_tensor = torch.rand(1, 1, 1024, 1024) > 0.5
+ condition_tensor = condition_tensor.to(device="meta")
+ x_tensor = torch.rand(8, 16, 1024, 1024, device="meta")
+ y_tensor = torch.tensor(0, device="meta")
+ output_tensor = torch.rand(8, 16, 1024, 1024)
+
+ # construct operation data
+ condition_data = OperationData(
+ name="condition",
+ data=condition_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=condition_tensor.shape,
+ )
+ x_data = OperationData(
+ name="x",
+ data=x_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=x_tensor.shape,
+ )
+ y_data = OperationData(
+ name="y",
+ data=y_tensor,
+ type=OperationDataType.ARG,
+ logical_shape=y_tensor.shape,
+ )
+ output_data = OperationData(
+ name="output",
+ data=output_tensor,
+ type=OperationDataType.OUTPUT,
+ logical_shape=output_tensor.shape,
+ )
+
+ # construct args and kwargs
+ args = [condition_data, x_data, y_data, output_data]
+ kwargs = {'inplace': False}
+
+ # estimated results
+ compute_cost, memory_cost, fwd_in, fwd_buffer, fwd_out = meta_func(*args, **kwargs)
+
+ # actual results
+ condition_real_tensor = torch.rand(1, 1, 1024, 1024) > 0.5
+ condition_real_tensor = condition_real_tensor.to(device="cuda")
+ x_real_tensor = torch.rand(8, 16, 1024, 1024, device="cuda")
+ y_real_tensor = torch.tensor(0.0, device="cuda")
+
+ x_real_tensor.requires_grad = True
+ y_real_tensor.requires_grad = True
+
+ # fwd
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ output_real_tensor = torch.where(condition_real_tensor, x_real_tensor, y_real_tensor)
+ fwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ fwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ # bwd
+ upstream_grad = torch.rand_like(output_real_tensor)
+ torch.cuda.reset_peak_memory_stats()
+ mem_stamp0 = torch.cuda.memory_allocated()
+ torch.autograd.backward(output_real_tensor, upstream_grad)
+ bwd_allocated = torch.cuda.memory_allocated() - mem_stamp0
+ bwd_peak = torch.cuda.max_memory_allocated() - mem_stamp0
+
+ compute_cost: TrainCycleItem
+ memory_cost: TrainCycleItem
+
+ print_results([condition_real_tensor, x_real_tensor, y_real_tensor], [output_real_tensor], compute_cost,
+ memory_cost, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak)
+
+
+if __name__ == '__main__':
+ test_where_meta_info()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
index 7c06f2ee9e20..60ecd1dd9801 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_metainfo/utils.py
@@ -7,8 +7,9 @@
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
@@ -57,7 +58,7 @@ def mem_test_for_node_strategy(rank: int,
output_key]
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
- gm, solution, device_mesh)
+ gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
gm: GraphModule
@@ -126,3 +127,56 @@ def mem_test_for_node_strategy(rank: int,
f"backward temp: {metainfo.memory_cost.bwd.temp / 1024} kb, backward buffer: {metainfo.memory_cost.bwd.buffer / 1024} kb"
)
print("=======================")
+
+
+def print_results(input: List[torch.Tensor], output: List[torch.Tensor], compute_cost: TrainCycleItem,
+ memory_cost: TrainCycleItem, fwd_allocated, fwd_peak, bwd_allocated, bwd_peak):
+ """Print the results of the meta information test.
+
+ Args:
+ input (List[torch.Tensor]): input tensors
+ output (List[torch.Tensor]): output tensors
+ compute_cost (TrainCycleItem): compute cost estimated by meta_func
+ memory_cost (TrainCycleItem): memory cost estimated by meta_func
+ fwd_allocated: real forward memory allocated
+ fwd_peak: real forward peak memory stats
+ bwd_allocated: real backward memory allocated
+ bwd_peak: real backward peak memory stats
+ """
+ print("=====================")
+ print(f"input shapes: {[tensor.shape for tensor in input]}")
+ print(f"output shapes: {[tensor.shape for tensor in output]}")
+
+ # estimated results
+ print("Estimated Results")
+
+ # compute cost
+ print("compute_cost:")
+ print(f" fwd: {compute_cost.fwd}")
+ print(f" bwd: {compute_cost.bwd}")
+
+ # memory cost
+ print("memory_cost:")
+ # fwd
+ print(f" fwd activation: {memory_cost.fwd.activation / 1024} KB")
+ print(f" fwd buffer: {memory_cost.fwd.buffer / 1024} KB")
+ print(f" fwd temp: {memory_cost.fwd.temp / 1024} KB")
+ print(f" fwd parameter: {memory_cost.fwd.parameter / 1024} KB")
+
+ # bwd
+ print(f" bwd activation: {memory_cost.bwd.activation / 1024} KB")
+ print(f" bwd buffer: {memory_cost.bwd.buffer / 1024} KB")
+ print(f" bwd temp: {memory_cost.bwd.temp / 1024} KB")
+ print(f" bwd parameter: {memory_cost.bwd.parameter / 1024} KB")
+
+ # actual results
+ print("Actual Results")
+
+ print("memory_cost:")
+ # fwd
+ print(f" fwd allocated: {fwd_allocated / 1024} KB")
+ print(f" fwd peak: {fwd_peak / 1024} KB")
+
+ # bwd
+ print(f" bwd allocated: {bwd_allocated / 1024} KB")
+ print(f" bwd peak: {bwd_peak / 1024} KB")
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
index 42430d5a24cb..50385c0450a8 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_binary_elementwise_handler.py
@@ -122,25 +122,41 @@ def forward(self, x1, x2):
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
-def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
- disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+class BEOpModelWithNodeConst(nn.Module):
- class BinaryElementwiseOpModel(nn.Module):
+ def __init__(self, op):
+ super().__init__()
+ self.op = op
- def __init__(self, op, const):
- super().__init__()
- self.op = op
- self.const = const
+ def forward(self, x1):
+ const = x1.dim()
+ out = self.op(x1, const)
+ return out
- def forward(self, x1):
- out = self.op(x1, self.const)
- return out
+
+class BEOpModelWithIntConst(nn.Module):
+
+ def __init__(self, op, const):
+ super().__init__()
+ self.op = op
+ self.const = const
+
+ def forward(self, x1):
+ out = self.op(x1, self.const)
+ return out
+
+
+def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
- model = BinaryElementwiseOpModel(op, other_dim).cuda()
+ if model_cls == BEOpModelWithNodeConst:
+ model = model_cls(op).cuda()
+ else:
+ model = model_cls(op, other_dim).cuda()
x1 = torch.rand(4, 4).cuda()
# the index of binary-elementwise node in computation graph
node_index = 1
@@ -159,9 +175,14 @@ def forward(self, x1):
tracer = ColoTracer()
meta_args = {'x1': torch.rand(4, 4).to('meta')}
graph = tracer.trace(model, meta_args=meta_args)
+ print(graph)
+ # assert False
gm = ColoGraphModule(model, graph)
- op_node = list(graph.nodes)[1]
+ if model_cls == BEOpModelWithNodeConst:
+ op_node = list(graph.nodes)[2]
+ else:
+ op_node = list(graph.nodes)[1]
strategies_vector = StrategiesVector(op_node)
# build handler
@@ -212,7 +233,7 @@ def forward(self, x1):
@parameterize('other_dim', [1, 2])
@pytest.mark.dist
@rerun_if_address_is_in_use()
-def test_binary_elementwise_handler(op, other_dim):
+def test_binary_elementwise_handler_with_tensor(op, other_dim):
world_size = 4
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
op=op,
@@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
world_size=world_size,
port=free_port())
mp.spawn(run_func_tensor, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('op', [torch.add])
+@parameterize('other_dim', [1, 2])
+@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
+ world_size = 4
run_func_int = partial(check_binary_elementwise_handler_with_int,
op=op,
+ model_cls=model_cls,
other_dim=other_dim,
world_size=world_size,
port=free_port())
@@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
if __name__ == '__main__':
- test_binary_elementwise_handler()
+ test_binary_elementwise_handler_with_tensor()
+ test_binary_elementwise_handler_with_int()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py
similarity index 91%
rename from tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
rename to tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py
index de277002b75d..ea7c2b729635 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_reshape_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_default_reshape_handler.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -51,9 +51,9 @@ def test_reshape_handler():
strategies_vector=conv_strategies_vector)
conv_handler.register_strategy(compute_resharding_cost=False)
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
- reshape_handler = ReshapeHandler(node=reshape_node,
- device_mesh=device_mesh,
- strategies_vector=reshape_strategies_vector)
+ reshape_handler = DefaultReshapeHandler(node=reshape_node,
+ device_mesh=device_mesh,
+ strategies_vector=reshape_strategies_vector)
reshape_handler.register_strategy(compute_resharding_cost=False)
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
index 3c35da61b1c3..c72d2a6a80e8 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
@@ -5,10 +5,10 @@
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler.default_reshape_handler import DefaultReshapeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlaceholderHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
@@ -153,7 +153,9 @@ def test_getitem_from_tuple_handler():
)
input_handler.register_strategy(compute_resharding_cost=False)
setattr(input_node, 'strategies_vector', input_strategies_vector)
- split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
+ split_handler = DefaultReshapeHandler(node=split_node,
+ device_mesh=device_mesh,
+ strategies_vector=split_strategies_vector)
split_handler.register_strategy(compute_resharding_cost=False)
setattr(split_node, 'strategies_vector', split_strategies_vector)
getitem_handler = GetItemHandler(node=getitem_node,
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
index 3d268ea43fc3..18afacf56b8e 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py
@@ -1,12 +1,9 @@
-from faulthandler import disable
from functools import partial
-from xml.dom import WrongDocumentErr
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from typing_extensions import Self
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
index 306c45f56dbf..91b3ae27d599 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_matmul_handler.py
@@ -1,3 +1,4 @@
+import pytest
import torch
import torch.nn as nn
@@ -24,6 +25,7 @@ def forward(self, x1, x2):
return torch.matmul(x1, x2)
+@pytest.mark.skipif(torch.__version__ < '1.12.0', reason="need pytorch 1.12.0 or higher for aten level operations")
@parameterize(
'tensor_shapes',
[
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
index c695b8843a3c..af03481d830e 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py
@@ -5,8 +5,8 @@
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import PermuteHandler, TransposeHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@@ -243,79 +243,79 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
if model_cls.__name__ == 'LinearReshapeModel':
if reshape_dims == ((0, 2, 1, 3), (1, 2)):
- assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+ assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if reshape_dims == (2, 0, 1, 3):
- assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+ assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if reshape_dims == (1, 3):
- assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list
+ assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
new file mode 100644
index 000000000000..f6895d92ab03
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py
@@ -0,0 +1,122 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
+from colossalai.auto_parallel.tensor_shard.options import ShardOption
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.testing.utils import parameterize
+
+
+class LinearModel(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input, others, bias=None):
+ x = nn.functional.linear(input, others, bias=bias)
+ return x
+
+
+def check_shard_option(shard_option):
+ model = LinearModel().cuda()
+ physical_mesh_id = torch.arange(0, 4)
+ mesh_shape = (2, 2)
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+ tracer = ColoTracer()
+ graph = tracer.trace(model,
+ meta_args={
+ "input": torch.rand(4, 4, 4, 16).to('meta'),
+ 'others': torch.rand(32, 16).to('meta')
+ })
+ gm = ColoGraphModule(model, graph)
+ linear_func_node = list(graph.nodes)[2]
+ strategies_vector = StrategiesVector(linear_func_node)
+
+ # build handler
+ handler = LinearFunctionHandler(node=linear_func_node,
+ device_mesh=device_mesh,
+ strategies_vector=strategies_vector,
+ shard_option=shard_option)
+
+ strategies_vector = handler.register_strategy(compute_resharding_cost=False)
+ strategy_name_list = [val.name for val in strategies_vector]
+
+ if shard_option == ShardOption.SHARD_LAST_AXIS:
+ # RR = RS x SR
+ assert 'RR = RS1 x S1R' in strategy_name_list
+
+ # RS= RR x RS
+ assert 'RS1 = RR x RS1' in strategy_name_list
+
+ return
+
+ # SS = SR x RS
+ assert 'S1S0 = S1R x RS0_0' in strategy_name_list
+ assert 'S0S1 = S0R x RS1_1' in strategy_name_list
+ assert 'S0S1 = S0R x RS1_2' in strategy_name_list
+ assert 'S0S1 = S0R x RS1_0' in strategy_name_list
+ assert 'S1S0 = S1R x RS0_1' in strategy_name_list
+ assert 'S1S0 = S1R x RS0_2' in strategy_name_list
+
+ # SR = SS x SR
+ assert 'S0R = S0S1 x S1R_1' in strategy_name_list
+ assert 'S0R = S0S1 x S1R_2' in strategy_name_list
+ assert 'S1R = S1S0 x S0R_0' in strategy_name_list
+ assert 'S0R = S0S1 x S1R_0' in strategy_name_list
+ assert 'S1R = S1S0 x S0R_1' in strategy_name_list
+ assert 'S1R = S1S0 x S0R_2' in strategy_name_list
+
+ # RS = RS x SS
+ assert 'RS0 = RS1 x S1S0' in strategy_name_list
+ assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+ # S01R = S01R x RR
+ assert 'S01R = S01R x RR_0' in strategy_name_list
+ assert 'S01R = S01R x RR_1' in strategy_name_list
+ assert 'S01R = S01R x RR_2' in strategy_name_list
+
+ # RR = RS01 x S01R
+ assert 'RR = RS01 x S01R' in strategy_name_list
+
+ # RS01 = RR x RS01
+ assert 'RS01 = RR x RS01' in strategy_name_list
+
+ if shard_option == ShardOption.SHARD:
+ # RR = RS x SR
+ assert 'RR = RS0 x S0R' in strategy_name_list
+ assert 'RR = RS1 x S1R' in strategy_name_list
+
+ # RS= RR x RS
+ assert 'RS0 = RR x RS0' in strategy_name_list
+ assert 'RS1 = RR x RS1' in strategy_name_list
+
+ if shard_option == ShardOption.STANDARD:
+ # RR = RS x SR
+ assert 'RR = RS0 x S0R' in strategy_name_list
+ assert 'RR = RS1 x S1R' in strategy_name_list
+
+ # RS= RR x RS
+ assert 'RS0 = RR x RS0' in strategy_name_list
+ assert 'RS1 = RR x RS1' in strategy_name_list
+
+ # RR = RR x RR
+ assert 'RR = RR x RR' in strategy_name_list
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+def test_shard_option():
+ # for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD, ShardOption.SHARD_LAST_AXIS]:
+ for shard_option in [ShardOption.SHARD_LAST_AXIS]:
+ check_shard_option(shard_option)
+
+
+if __name__ == '__main__':
+ test_shard_option()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
index b5e8e32778be..c43ee292bedf 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py
@@ -117,54 +117,54 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
strategy_name_list = [strategy.name for strategy in split_strategies_vector]
if softmax_dim == 0:
- assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if softmax_dim == 1:
- assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
+ assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
index 9e8e905c54a2..044aef19d38d 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py
@@ -5,8 +5,8 @@
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import SplitHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@@ -156,8 +156,7 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert len(split_strategies_vector) == len(previous_strategies_vector)
strategy_name_list = [strategy.name for strategy in split_strategies_vector]
- for name in strategy_name_list:
- print(name)
+
if model_cls.__name__ == 'ConvSplitModel':
if split_dim == 0:
@@ -199,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
if model_cls.__name__ == 'LinearSplitModel':
if split_dim == 0:
- assert '[R, R, R, S1]_0' in strategy_name_list
- assert '[R, S0, R, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1]_2' in strategy_name_list
- assert '[R, R, R, S0]_3' in strategy_name_list
- assert '[R, S1, R, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0]_5' in strategy_name_list
- assert '[R, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R]_7' in strategy_name_list
- assert '[R, R, S0, R]_8' in strategy_name_list
- assert '[R, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R]_10' in strategy_name_list
- assert '[R, R, S1, R]_11' in strategy_name_list
- assert '[R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1]_17' in strategy_name_list
- assert '[R, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R]_19' in strategy_name_list
- assert '[R, R, S01, R]_20' in strategy_name_list
- assert '[R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01]_22' in strategy_name_list
+ assert '[R, R, R, S1]_11' in strategy_name_list
+ assert '[R, S0, R, S1]_12' in strategy_name_list
+ assert '[R, R, S0, S1]_13' in strategy_name_list
+ assert '[R, R, R, S0]_14' in strategy_name_list
+ assert '[R, S1, R, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0]_16' in strategy_name_list
+ assert '[R, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R]_18' in strategy_name_list
+ assert '[R, R, S0, R]_19' in strategy_name_list
+ assert '[R, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R]_21' in strategy_name_list
+ assert '[R, R, S1, R]_22' in strategy_name_list
+ assert '[R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1]_5' in strategy_name_list
+ assert '[R, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R]_1' in strategy_name_list
+ assert '[R, R, S01, R]_2' in strategy_name_list
+ assert '[R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01]_4' in strategy_name_list
if split_dim == 1:
- assert '[S0, R, R, S1]_0' in strategy_name_list
- assert '[R, R, R, S1]_1' in strategy_name_list
- assert '[R, R, S0, S1]_2' in strategy_name_list
- assert '[S1, R, R, S0]_3' in strategy_name_list
- assert '[R, R, R, S0]_4' in strategy_name_list
- assert '[R, R, S1, S0]_5' in strategy_name_list
- assert '[S0, R, R, R]_6' in strategy_name_list
- assert '[R, R, R, R]_7' in strategy_name_list
- assert '[R, R, S0, R]_8' in strategy_name_list
- assert '[S1, R, R, R]_9' in strategy_name_list
- assert '[R, R, R, R]_10' in strategy_name_list
- assert '[R, R, S1, R]_11' in strategy_name_list
+ assert '[S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list
- assert '[R, R, R, S0]_13' in strategy_name_list
- assert '[R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0]_16' in strategy_name_list
- assert '[R, R, R, S1]_17' in strategy_name_list
- assert '[S01, R, R, R]_18' in strategy_name_list
- assert '[R, R, R, R]_19' in strategy_name_list
- assert '[R, R, S01, R]_20' in strategy_name_list
+ assert '[R, R, S0, S1]_13' in strategy_name_list
+ assert '[S1, R, R, S0]_14' in strategy_name_list
+ assert '[R, R, R, S0]_15' in strategy_name_list
+ assert '[R, R, S1, S0]_16' in strategy_name_list
+ assert '[S0, R, R, R]_17' in strategy_name_list
+ assert '[R, R, R, R]_18' in strategy_name_list
+ assert '[R, R, S0, R]_19' in strategy_name_list
+ assert '[S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01]_22' in strategy_name_list
+ assert '[R, R, S1, R]_22' in strategy_name_list
+ assert '[R, R, R, S1]_10' in strategy_name_list
+ assert '[R, R, R, S0]_9' in strategy_name_list
+ assert '[R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0]_6' in strategy_name_list
+ assert '[R, R, R, S1]_5' in strategy_name_list
+ assert '[S01, R, R, R]_0' in strategy_name_list
+ assert '[R, R, R, R]_1' in strategy_name_list
+ assert '[R, R, S01, R]_2' in strategy_name_list
+ assert '[R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
index 08a702789f9f..8a96ac0d66f0 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
@@ -5,8 +5,8 @@
import torch.multiprocessing as mp
import torch.nn as nn
+from colossalai.auto_parallel.tensor_shard.node_handler import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
-from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
@@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
if model_cls.__name__ == 'LinearViewModel':
if tgt_shape == (32, 4, 64, 16, 4):
- assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list
+ for strategy in strategy_name_list:
+ print(strategy)
+ # print(strategy_name_list)
+ assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list
if tgt_shape == (8, 4, 4, 64, 16, 4):
- assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list
- assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list
- assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list
- assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list
- assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list
- assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list
- assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list
- assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
- assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list
- assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list
- assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list
- assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list
- assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list
- assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list
- assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list
- assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list
- assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list
- assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list
- assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list
+ assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list
+ assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list
+ assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list
+ assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list
+ assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list
+ assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list
+ assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list
+ assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list
+ assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list
+ assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list
+ assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list
+ assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list
+ assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list
+ assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list
+ assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list
+ assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list
+ assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list
+ assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list
+ assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL')
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
index d02e1e31eb40..0cdfdbc9d0cd 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -6,9 +6,9 @@
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
@@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
- target_node = list(graph.nodes)[node_index]
+ target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
+ ][node_index]
if node_type == 'normal':
solution_len = len(strategies_constructor.leaf_strategies)
solution = [0] * solution_len
@@ -107,12 +108,11 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
+ solver = Solver(gm.graph, strategies_constructor, cost_graph, verbose=False)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
- gm, solution, device_mesh)
+ gm, solution, device_mesh, strategies_constructor)
gm = runtime_apply_pass(gm)
gm.recompile()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
index b504d59c971f..92f011ba30d2 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_param_resharding_cost.py
@@ -1,13 +1,8 @@
import torch
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationDataType
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
deleted file mode 100644
index 814edd27948c..000000000000
--- a/tests/test_auto_parallel/test_tensor_shard/test_resnet_block_runtime.py
+++ /dev/null
@@ -1,270 +0,0 @@
-import copy
-from copy import deepcopy
-from functools import partial
-
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from torch.fx import GraphModule
-from torchvision.models import resnet34, resnet50
-
-from colossalai import device
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.constants import *
-from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
-from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
-from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
-from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.initialize import launch
-from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_close, assert_close_loose, rerun_if_address_is_in_use
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
-from colossalai.utils import free_port
-
-seed = 128
-cudnn_benchmark = False
-cudnn_deterministic = True
-
-
-def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
- """3x3 convolution with padding"""
- return nn.Conv2d(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=dilation,
- groups=groups,
- bias=False,
- dilation=dilation,
- )
-
-
-def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
- """1x1 convolution"""
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-
-class Bottleneck(nn.Module):
- # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
- # while original implementation places the stride at the first 1x1 convolution(self.conv1)
- # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
- # This variant is also known as ResNet V1.5 and improves accuracy according to
- # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
-
- expansion: int = 4
-
- def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample=None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer=None,
- ) -> None:
- super().__init__()
- if norm_layer is None:
- norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.0)) * groups
- # Both self.conv2 and self.downsample layers downsample the input when stride != 1
- self.conv1 = conv1x1(inplanes, width)
- self.bn1 = norm_layer(width)
- self.conv2 = conv3x3(width, width, stride, groups, dilation)
- self.bn2 = norm_layer(width)
- self.conv3 = conv1x1(width, planes * self.expansion)
- self.bn3 = norm_layer(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
- self.downsample = downsample
- self.stride = stride
-
- def forward(self, x):
- identity = x
-
- out = self.conv1(x)
- out = self.bn1(out)
- out = self.relu(out)
-
- out = self.conv2(out)
- out = self.bn2(out)
- out = self.relu(out)
-
- out = self.conv3(out)
- out = self.bn3(out)
-
- if self.downsample is not None:
- identity = self.downsample(x)
-
- out = self.relu(out)
-
- return out
-
-
-def check_apply_bottleneck(rank, world_size, port):
- disable_existing_loggers()
- launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
- input = torch.rand(4, 4, 4, 4).cuda()
- physical_mesh_id = torch.arange(0, 4)
- mesh_shape = (2, 2)
- # [[0, 1]
- # [2, 3]]
- device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-
- tracer = ColoTracer()
- model = Bottleneck(4, 4, 1, norm_layer=torch.nn.modules.batchnorm.BatchNorm2d).cuda()
- test_model = copy.deepcopy(model)
- test_input = copy.deepcopy(input)
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv1 : [#users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
- # %bn1 : [#users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
- # %relu : [#users=1] = call_module[target=relu](args = (%bn1,), kwargs = {})
- # %conv2 : [#users=1] = call_module[target=conv2](args = (%relu,), kwargs = {})
- # %bn2 : [#users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
- # %relu_1 : [#users=1] = call_module[target=relu](args = (%bn2,), kwargs = {})
- # %conv3 : [#users=1] = call_module[target=conv3](args = (%relu_1,), kwargs = {})
- # %bn3 : [#users=1] = call_module[target=bn3](args = (%conv3,), kwargs = {})
- # %relu_2 : [#users=1] = call_module[target=relu](args = (%bn3,), kwargs = {})
- # return relu_2
- input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
-
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- solver_options = SolverOptions()
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- print(solution)
- for index, node in enumerate(graph.nodes):
- print(node.name, node.strategies_vector[solution[index]].name)
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
- gm = runtime_apply_pass(gm)
- gm.recompile()
- nodes = [node for node in gm.graph.nodes]
- # TODO: wrap the gm to avoid the influence of the user training code
- cuda_rng_state = torch.cuda.get_rng_state()
- origin_output = test_model(test_input)
- torch.cuda.set_rng_state(cuda_rng_state)
- output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
-
- assert output.shape == origin_output.shape
- assert_close(output, origin_output, rtol=1e-03, atol=1e-05)
- print("*******************backward starting*******************")
- cuda_rng_state = torch.cuda.get_rng_state()
- output.sum().backward()
- torch.cuda.set_rng_state(cuda_rng_state)
- origin_output.sum().backward()
- if rank == 0:
- print(
- f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 0, 4)).abs().sum()}"
- )
- print(
- f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
- )
- print(
- f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
- )
- print(
- f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
- )
- print(
- f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 0, 1)).abs().sum()}"
- )
- print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-
- assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
- assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
- assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-
- if rank == 1:
- print(
- f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 4, 4)).abs().sum()}"
- )
- print(
- f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 0, 8)).abs().sum()}"
- )
- print(
- f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
- )
- print(
- f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
- )
- print(
- f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 1, 1)).abs().sum()}"
- )
- print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-
- assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 0, 8).sum())
- assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
- assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-
- if rank == 2:
- print(
- f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 8, 4)).abs().sum()}"
- )
- print(
- f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
- )
- print(
- f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
- )
- print(
- f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 0, 2)).abs().sum()}"
- )
- print(
- f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 2, 1)).abs().sum()}"
- )
- print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-
- assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
- assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 0, 2).sum())
- assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-
- if rank == 3:
- print(
- f"bn3 diff sum in rank {rank}: {(gm.bn3.weight.grad - test_model.bn3.weight.grad.narrow(0, 12, 4)).abs().sum()}"
- )
- print(
- f"conv3 diff sum in rank {rank}: {(gm.conv3.weight.grad - test_model.conv3.weight.grad.narrow(0, 8, 8)).abs().sum()}"
- )
- print(
- f"bn2 diff sum in rank {rank}: {(gm.bn2.weight.grad - test_model.bn2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
- )
- print(
- f"conv2 diff sum in rank {rank}: {(gm.conv2.weight.grad - test_model.conv2.weight.grad.narrow(0, 2, 2)).abs().sum()}"
- )
- print(
- f"bn1 diff sum in rank {rank}: {(gm.bn1.weight.grad - test_model.bn1.weight.grad.narrow(0, 3, 1)).abs().sum()}"
- )
- print(f"conv1 diff sum in rank {rank}: {(gm.conv1.weight.grad - test_model.conv1.weight.grad).sum()}")
-
- assert_close_loose(gm.conv3.weight.grad.sum(), test_model.conv3.weight.grad.narrow(0, 8, 8).sum())
- assert_close_loose(gm.conv2.weight.grad.sum(), test_model.conv2.weight.grad.narrow(0, 2, 2).sum())
- assert_close_loose(gm.conv1.weight.grad.sum(), test_model.conv1.weight.grad.sum())
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@pytest.mark.dist
-@rerun_if_address_is_in_use()
-def test_apply():
- world_size = 4
- run_func = partial(check_apply_bottleneck, world_size=world_size, port=free_port())
- mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
- test_apply()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
index 66cd3f3f7707..24a3ae5b42c3 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py
@@ -5,19 +5,9 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
-from torch.fx import GraphModule
-
-from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
-from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
@@ -41,41 +31,22 @@ def check_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
input = torch.rand(4, 4, 4, 4).cuda()
+ test_input = copy.deepcopy(input)
+ # graph():
+ # %x : torch.Tensor [#users=1] = placeholder[target=x]
+ # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
+ # return conv
+ model = ConvModel(4, 4).cuda()
+ test_model = copy.deepcopy(model)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ meta_args = {'x': torch.rand(4, 4, 4, 4).to('meta')}
+ gm = initialize_model(model, meta_args, device_mesh)
- tracer = ColoTracer()
- model = ConvModel(4, 4).cuda()
- test_model = copy.deepcopy(model)
- test_input = copy.deepcopy(input)
-
- input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
- # graph():
- # %x : torch.Tensor [#users=1] = placeholder[target=x]
- # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
- # return conv
- graph = tracer.trace(root=model, meta_args=input_sample)
- gm = GraphModule(model, graph, model.__class__.__name__)
- gm.recompile()
- solver_options = SolverOptions()
- strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
- strategies_constructor.build_strategies_and_cost()
-
- cost_graph = CostGraph(strategies_constructor.leaf_strategies)
- cost_graph.simplify_graph()
- graph_analyser = GraphAnalyser(gm)
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
- ret = solver.call_solver_serialized_args()
- solution = list(ret[0])
- gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
- gm = runtime_apply_pass(gm)
- gm.recompile()
- nodes = [node for node in gm.graph.nodes]
- # TODO: wrap the gm to avoid the influence of the user training code
- output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+ output = gm(input)
origin_output = test_model(test_input)
assert output.equal(origin_output)
origin_loss = origin_output.sum()
@@ -84,13 +55,21 @@ def check_apply(rank, world_size, port):
origin_loss.backward()
loss.backward()
- grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
- grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
-
- if rank in (0, 1):
- assert_close(gm.conv.weight.grad.data, grad_0.data)
- elif rank in (2, 3):
- assert_close(gm.conv.weight.grad.data, grad_1.data)
+ grad_0 = test_model.conv.weight.grad.narrow(0, 0, 1)
+ grad_1 = test_model.conv.weight.grad.narrow(0, 1, 1)
+ grad_2 = test_model.conv.weight.grad.narrow(0, 2, 1)
+ grad_3 = test_model.conv.weight.grad.narrow(0, 3, 1)
+
+ if rank == 0:
+ assert_close(gm.module.conv.weight.grad.data, grad_0.data)
+ elif rank == 1:
+ assert_close(gm.module.conv.weight.grad.data, grad_1.data)
+ elif rank == 2:
+ assert_close(gm.module.conv.weight.grad.data, grad_2.data)
+ elif rank == 3:
+ assert_close(gm.module.conv.weight.grad.data, grad_3.data)
+ else:
+ raise ValueError(f'rank {rank} does not exist.')
# skip this test due to pulp not installed in CI environment
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
index f4a5ae7ac1c0..bbfc3e1fcc14 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
@@ -3,13 +3,8 @@
from torchvision.models import resnet50
from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
- CostGraph,
- GraphAnalyser,
- Solver,
- SolverOptions,
- StrategiesConstructor,
-)
+from colossalai.auto_parallel.tensor_shard.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
@@ -56,15 +51,14 @@ def test_cost_graph():
# return fc
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
- graph_analyser = GraphAnalyser(gm)
- liveness_list = graph_analyser.liveness_analysis()
+
solver_options = SolverOptions()
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
- solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+ solver = Solver(gm.graph, strategies_constructor, cost_graph)
ret = solver.call_solver_serialized_args()
print(ret[0])
diff --git a/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
new file mode 100644
index 000000000000..9a2240d62de4
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_alphafold/benchmark_autochunk_alphafold.py
@@ -0,0 +1,140 @@
+import time
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _benchmark_evoformer_stack_gm(
+ data_args: tuple,
+ max_memory: int,
+ get_model: Any,
+ get_data: Any,
+) -> None:
+ # build model and input
+ model = get_model().cpu().eval()
+ meta_args, concrete_args = get_data(*data_args)
+ if concrete_args is None:
+ concrete_args = []
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [MetaTensor(i[1], fake_device="cpu") for i in meta_args] + [i[1] for i in concrete_args]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ )
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # init inputs
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda()
+
+ # bench
+ mem = _benchmark_memory(gm, inputs)
+ speed = _benchmark_speed(gm, inputs)
+ print("evoformer stack gm, mem: %.2fMB, time: %.4fs" % (mem, speed))
+
+
+def _benchmark_evoformer_stack_origin(
+ data_args: tuple,
+ get_model: Any,
+ get_data: Any,
+) -> None:
+ # build model and input
+ model = get_model()
+ meta_args, concrete_args = get_data(*data_args)
+ if concrete_args is None:
+ concrete_args = []
+
+ # init inputs
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda()
+
+ # bench
+ mem = _benchmark_memory(model, inputs)
+ speed = _benchmark_speed(model, inputs)
+ print("evoformer stack origin, mem: %.2fMB, time: %.4fs" % (mem, speed))
+ return mem
+
+
+def _benchmark_memory(model, inputs):
+ with torch.no_grad():
+ torch.cuda.reset_peak_memory_stats()
+ now_mem = torch.cuda.memory_allocated() / 1024**2
+ model(*inputs)
+ new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+ return new_max_mem - now_mem
+
+
+def _benchmark_speed(model, inputs, loop=5):
+ with torch.no_grad():
+ for _ in range(loop // 2 + 1):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time1 = time.time()
+ for _ in range(loop):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time2 = time.time()
+ return (time2 - time1) / loop
+
+
+def benchmark_evoformer_stack(data_args):
+ from test_autochunk_evoformer_stack import get_data, get_model
+ print("\nmsa len: %d, pair len: %d" % (data_args[0], data_args[1]))
+ max_mem = _benchmark_evoformer_stack_origin(data_args, get_model, get_data)
+ for ratio in [0.5, 0.4, 0.3, 0.2, 0.1]:
+ try:
+ _benchmark_evoformer_stack_gm(data_args, max_mem * ratio, get_model, get_data)
+ except RuntimeError as e:
+ if e.args[0] == 'Search failed. Try a larger memory threshold.':
+ break
+ except Exception as e:
+ raise e
+ _benchmark_evoformer_stack_gm(data_args, None, get_model, get_data)
+
+
+if __name__ == "__main__":
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=0,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+ benchmark_evoformer_stack((256, 256))
+ benchmark_evoformer_stack((256, 512))
+ benchmark_evoformer_stack((256, 1024))
+ benchmark_evoformer_stack((256, 1280))
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
new file mode 100644
index 000000000000..cb250d6402e2
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_alphafold_utils.py
@@ -0,0 +1,132 @@
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.autochunk.utils import flat_list
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def assert_codegen_run(
+ model: Any,
+ meta_args: List,
+ concrete_args: List = None,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ print_code: bool = False,
+) -> List[Dict]:
+ if concrete_args is None:
+ concrete_args = []
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ print_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+ chunks = codegen.chunk_infos
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # assert chunk in code
+ code = graph.python_code("self").src
+ if print_code:
+ print(code)
+ assert "chunk_size = None; " in code
+
+ # assert result
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda()
+ with torch.no_grad():
+ if print_mem:
+ torch.cuda.reset_peak_memory_stats()
+ now_mem = torch.cuda.memory_allocated() / 1024**2
+ out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+ print("mem: %.2fMB" % (new_max_mem - now_mem))
+ out_model = model(*inputs)
+ out_gm = flat_list(out_gm)
+ out_model = flat_list(out_model)
+ for out_gm_i, out_model_i in zip(out_gm, out_model):
+ assert torch.allclose(out_gm_i, out_model_i,
+ atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+ torch.abs(out_gm_i - out_model_i))
+
+ return chunks
+
+
+def run_test(
+ rank: int,
+ data_args: tuple,
+ max_memory: int,
+ get_model: Any,
+ get_data: Any,
+ print_code: bool = False,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ get_chunk_target: Any = None,
+) -> None:
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=rank,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+
+ # build model and input
+ model = get_model()
+ meta_args, concrete_args = get_data(*data_args)
+ chunks = assert_codegen_run(
+ model,
+ meta_args=meta_args,
+ concrete_args=concrete_args,
+ max_memory=max_memory,
+ print_code=print_code,
+ print_mem=print_mem,
+ print_est_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+
+ if get_chunk_target is not None:
+ chunk_found = [i["region"] for i in chunks]
+ chunk_target = get_chunk_target()[max_memory]
+ assert chunk_found == chunk_target, "found regions %s doesn't equal target regions %s" % (
+ str(chunk_found),
+ str(chunk_target),
+ )
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
new file mode 100644
index 000000000000..17a5abf4cab8
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_block.py
@@ -0,0 +1,95 @@
+from functools import partial
+from typing import Dict, List, Tuple
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+ from fastfold.model.nn.evoformer import EvoformerBlock
+ HAS_REPO = True
+except:
+ HAS_REPO = False
+
+from test_autochunk_alphafold_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+
+def get_model():
+ model = EvoformerBlock(
+ c_m=256,
+ c_z=128,
+ c_hidden_msa_att=32,
+ c_hidden_opm=32,
+ c_hidden_mul=128,
+ c_hidden_pair_att=32,
+ no_heads_msa=8,
+ no_heads_pair=4,
+ transition_n=4,
+ msa_dropout=0.15,
+ pair_dropout=0.15,
+ inf=1e4,
+ eps=1e-4,
+ is_multimer=False,
+ ).eval().cuda()
+ return model
+
+
+def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
+ node = torch.randn(1, msa_len, pair_len, 256).cuda()
+ node_mask = torch.randn(1, msa_len, pair_len).cuda()
+ pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+ pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+ meta_args = [
+ ("m", node),
+ ("z", pair),
+ ("msa_mask", node_mask),
+ ("pair_mask", pair_mask),
+ ]
+ concrete_args = [("chunk_size", None), ("_mask_trans", True)]
+ return meta_args, concrete_args
+
+
+def get_chunk_target() -> Dict:
+ return {
+ None: [(120, 126), (225, 244), (270, 289), (306, 311), (70, 106), (23, 46), (146, 152), (187, 193), (181, 184),
+ (140, 145), (162, 163), (203, 204)],
+ 20: [(120, 123), (232, 237), (277, 282), (305, 306)],
+ 24: [(122, 123)],
+ }
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 20, 24])
+@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
+def test_evoformer_block(data_args, max_memory):
+ run_func = partial(
+ run_test,
+ data_args=data_args,
+ max_memory=max_memory,
+ get_model=get_model,
+ get_data=get_data,
+ get_chunk_target=get_chunk_target,
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(
+ rank=0,
+ data_args=(32, 64),
+ max_memory=24,
+ get_model=get_model,
+ get_data=get_data,
+ get_chunk_target=get_chunk_target,
+ print_code=False,
+ print_mem=False,
+ print_est_mem=False,
+ print_progress=False,
+ )
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
new file mode 100644
index 000000000000..5210c1c8d48e
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_evoformer_stack.py
@@ -0,0 +1,87 @@
+from functools import partial
+from typing import List, Tuple
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+ from fastfold.model.nn.evoformer import EvoformerStack
+ HAS_REPO = True
+except:
+ HAS_REPO = False
+
+from test_autochunk_alphafold_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+
+def get_model():
+ model = EvoformerStack(
+ c_m=256,
+ c_z=128,
+ c_hidden_msa_att=32,
+ c_hidden_opm=32,
+ c_hidden_mul=128,
+ c_hidden_pair_att=32,
+ c_s=384,
+ no_heads_msa=8,
+ no_heads_pair=4,
+ no_blocks=2, # 48
+ transition_n=4,
+ msa_dropout=0.15,
+ pair_dropout=0.25,
+ blocks_per_ckpt=None,
+ inf=1000000000.0,
+ eps=1e-08,
+ clear_cache_between_blocks=False,
+ is_multimer=False,
+ ).eval().cuda()
+ return model
+
+
+def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
+ node = torch.randn(1, msa_len, pair_len, 256).cuda()
+ node_mask = torch.randn(1, msa_len, pair_len).cuda()
+ pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+ pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+ meta_args = [
+ ("m", node),
+ ("z", pair),
+ ("msa_mask", node_mask),
+ ("pair_mask", pair_mask),
+ ]
+ concrete_args = [("chunk_size", None), ("_mask_trans", True)]
+ return meta_args, concrete_args
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 20, 24])
+@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
+def test_evoformer_stack(data_args, max_memory):
+ run_func = partial(
+ run_test,
+ data_args=data_args,
+ max_memory=max_memory,
+ get_model=get_model,
+ get_data=get_data,
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(
+ rank=0,
+ data_args=(32, 64),
+ max_memory=None,
+ get_model=get_model,
+ get_data=get_data,
+ print_code=False,
+ print_mem=False,
+ print_progress=False,
+ )
diff --git a/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
new file mode 100644
index 000000000000..ad955479e617
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_alphafold/test_autochunk_extramsa_block.py
@@ -0,0 +1,83 @@
+from functools import partial
+from typing import Dict, List, Tuple
+
+import pytest
+import torch
+import torch.fx
+import torch.multiprocessing as mp
+
+try:
+ from fastfold.model.nn.evoformer import ExtraMSABlock
+ HAS_REPO = True
+except:
+ HAS_REPO = False
+from test_autochunk_alphafold_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+
+def get_model():
+ model = ExtraMSABlock(
+ c_m=256,
+ c_z=128,
+ c_hidden_msa_att=32,
+ c_hidden_opm=32,
+ c_hidden_mul=128,
+ c_hidden_pair_att=32,
+ no_heads_msa=8,
+ no_heads_pair=4,
+ transition_n=4,
+ msa_dropout=0.15,
+ pair_dropout=0.15,
+ inf=1e4,
+ eps=1e-4,
+ ckpt=False,
+ is_multimer=False,
+ ).eval().cuda()
+ return model
+
+
+def get_data(msa_len: int, pair_len: int) -> Tuple[List, List]:
+ node = torch.randn(1, msa_len, pair_len, 256).cuda()
+ node_mask = torch.randn(1, msa_len, pair_len).cuda()
+ pair = torch.randn(1, pair_len, pair_len, 128).cuda()
+ pair_mask = torch.randn(1, pair_len, pair_len).cuda()
+
+ meta_args = [
+ ("m", node),
+ ("z", pair),
+ ("msa_mask", node_mask),
+ ("pair_mask", pair_mask),
+ ]
+ concrete_args = [("chunk_size", None), ("_chunk_logits", 1024)]
+ return meta_args, concrete_args
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("max_memory", [None, 20, 24])
+@pytest.mark.parametrize("data_args", [(32, 64)]) # (msa_len, pair_len)
+def test_extramsa_block(data_args, max_memory):
+ run_func = partial(
+ run_test,
+ data_args=data_args,
+ max_memory=max_memory,
+ get_model=get_model,
+ get_data=get_data,
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(
+ rank=0,
+ data_args=(32, 64),
+ max_memory=None,
+ get_model=get_model,
+ get_data=get_data,
+ print_code=False,
+ print_mem=False,
+ print_progress=False,
+ )
diff --git a/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py
new file mode 100644
index 000000000000..6fb7efa7a8fc
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_diffuser/benchmark_autochunk_diffuser.py
@@ -0,0 +1,147 @@
+import time
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import parameter_size
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _benchmark_autochunk_unet_gm(
+ model: Any,
+ data: tuple,
+ max_memory: int = None,
+) -> None:
+ model = model.cuda().eval()
+
+ # build model and input
+ meta_args, concrete_args = data
+ if concrete_args is None:
+ concrete_args = {}
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ meta_tensors = [MetaTensor(i, fake_device="cpu") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ )
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model.cuda().eval(),
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # init inputs
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda().eval()
+
+ # bench
+ para_mem = float(parameter_size(model)) / 1024**2
+ act_mem = _benchmark_memory(gm, inputs)
+ speed = _benchmark_speed(gm, inputs)
+ print("unet autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+ (speed, act_mem, para_mem, act_mem + para_mem))
+
+
+def _benchmark_autochunk_unet_origin(
+ model: Any,
+ data: tuple,
+) -> None:
+ # build model and input
+ meta_args, concrete_args = data
+ if concrete_args is None:
+ concrete_args = {}
+
+ # init inputs
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda().eval()
+
+ # bench
+ para_mem = float(parameter_size(model)) / 1024**2
+ act_mem = _benchmark_memory(model, inputs)
+ speed = _benchmark_speed(model, inputs)
+ print("unet origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+ (speed, act_mem, para_mem, act_mem + para_mem))
+ return act_mem
+
+
+def _benchmark_memory(model, inputs):
+ with torch.no_grad():
+ torch.cuda.reset_peak_memory_stats()
+ now_mem = float(torch.cuda.memory_allocated()) / 1024**2
+ model(*inputs)
+ new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
+ return new_max_mem - now_mem
+
+
+def _benchmark_speed(model, inputs, loop=5):
+ with torch.no_grad():
+ for _ in range(loop // 2 + 1):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time1 = time.time()
+ for _ in range(loop):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time2 = time.time()
+ return (time2 - time1) / loop
+
+
+def benchmark_autochunk_unet(batch=1, height=448, width=448):
+ from test_autochunk_unet import UNet2DModel, get_data
+ model = UNet2DModel()
+ latent_shape = (batch, 3, height // 7, width // 7)
+
+ print("\nbatch: %d, height: %d, width: %d" % (batch, height, width))
+ max_mem = _benchmark_autochunk_unet_origin(model, get_data(latent_shape))
+ for ratio in [0.5, 0.4, 0.3, 0.2]:
+ try:
+ _benchmark_autochunk_unet_gm(model, get_data(latent_shape), max_mem * ratio)
+ except RuntimeError as e:
+ if e.args[0] == 'Search failed. Try a larger memory threshold.':
+ break
+ except Exception as e:
+ raise e
+ _benchmark_autochunk_unet_gm(model, get_data(latent_shape), None)
+
+
+if __name__ == "__main__":
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=0,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+ benchmark_autochunk_unet(batch=1, height=224 * 3, width=224 * 3)
+ benchmark_autochunk_unet(batch=1, height=224 * 4, width=224 * 4)
+ benchmark_autochunk_unet(batch=1, height=224 * 5, width=224 * 5)
+ benchmark_autochunk_unet(batch=1, height=224 * 6, width=224 * 6)
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
new file mode 100644
index 000000000000..529250fe8f51
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_diffuser_utils.py
@@ -0,0 +1,136 @@
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def assert_codegen_run(
+ model: Any,
+ meta_args: List,
+ concrete_args: List = None,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ print_code: bool = False,
+) -> List[Dict]:
+ if concrete_args is None:
+ concrete_args = []
+ model = model()
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ model = model.cuda().eval()
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args] + [i[1] for i in concrete_args]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ print_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+ chunks = codegen.chunk_infos
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model.cuda(),
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args},
+ concrete_args={k: v for k, v in concrete_args},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # assert chunk in code
+ code = graph.python_code("self").src
+ if print_code:
+ print(code)
+ assert "chunk_size = None; " in code
+
+ # assert result
+ inputs = [i[1] for i in meta_args] + [i[1] for i in concrete_args]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda().eval()
+ gm.eval()
+ with torch.no_grad():
+ if print_mem:
+ torch.cuda.reset_peak_memory_stats()
+ now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+ out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+ torch.cuda.reset_peak_memory_stats()
+ now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+ out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+ print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
+ assert torch.allclose(out_gm["sample"], out_model["sample"],
+ atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+ torch.abs(out_gm["sample"] - out_model["sample"]))
+
+ return chunks
+
+
+def run_test(
+ rank: int,
+ model: Any,
+ data: tuple,
+ max_memory: int,
+ print_code: bool = False,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ get_chunk_target: Any = None,
+) -> None:
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=rank,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+
+ # build model and input
+ meta_args, concrete_args = data
+ chunks = assert_codegen_run(
+ model,
+ meta_args=meta_args,
+ concrete_args=concrete_args,
+ max_memory=max_memory,
+ print_code=print_code,
+ print_mem=print_mem,
+ print_est_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+
+ if get_chunk_target is not None:
+ chunk_found = [i["region"] for i in chunks]
+ chunk_target = get_chunk_target()[max_memory]
+ assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+ str(chunk_found),
+ str(chunk_target),
+ )
+
+ gpc.destroy()
diff --git a/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
new file mode 100644
index 000000000000..16c5b10ff4ae
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_diffuser/test_autochunk_unet.py
@@ -0,0 +1,63 @@
+from functools import partial
+from typing import List, Tuple
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+try:
+ from diffusers import UNet2DModel
+ MODELS = [UNet2DModel]
+ HAS_REPO = True
+except:
+ MODELS = []
+ HAS_REPO = False
+
+from test_autochunk_diffuser_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+BATCH_SIZE = 1
+HEIGHT = 448
+WIDTH = 448
+IN_CHANNELS = 3
+LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 7, WIDTH // 7)
+
+
+def get_data(shape: tuple) -> Tuple[List, List]:
+ sample = torch.randn(shape)
+ meta_args = [
+ ("sample", sample),
+ ]
+ concrete_args = [("timestep", 50)]
+ return meta_args, concrete_args
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("shape", [LATENTS_SHAPE])
+@pytest.mark.parametrize("max_memory", [None, 150, 300])
+def test_evoformer_block(model, shape, max_memory):
+ run_func = partial(
+ run_test,
+ max_memory=max_memory,
+ model=model,
+ data=get_data(shape),
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(
+ rank=0,
+ data=get_data(LATENTS_SHAPE),
+ max_memory=None,
+ model=UNet2DModel,
+ print_code=False,
+ print_mem=True,
+ print_est_mem=False,
+ print_progress=False,
+ )
diff --git a/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py
new file mode 100644
index 000000000000..63490aaee7ff
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_transformer/benchmark_autochunk_transformer.py
@@ -0,0 +1,149 @@
+import time
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.profiler import parameter_size
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def _benchmark_autochunk_gpt_gm(
+ model: Any,
+ data: tuple,
+ max_memory: int = None,
+) -> None:
+ model = model.eval().cpu()
+
+ # build model and input
+ meta_args, concrete_args, sequence = data
+ if concrete_args is None:
+ concrete_args = {}
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+ concrete_args={k: v for k, v in concrete_args.items()},
+ )
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
+ meta_tensors = [MetaTensor(i, fake_device="cpu") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ )
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model.cuda().eval(),
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+ concrete_args={k: v for k, v in concrete_args.items()},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # init inputs
+ inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda()
+
+ # bench
+ para_mem = float(parameter_size(model)) / 1024**2 * 6
+ act_mem = _benchmark_memory(gm, inputs)
+ speed = _benchmark_speed(gm, inputs)
+ print("gpt autochunk, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+ (speed, act_mem, para_mem, act_mem + para_mem))
+
+
+def _benchmark_autochunk_gpt_origin(
+ model: Any,
+ data: tuple,
+) -> None:
+ # build model and input
+ meta_args, concrete_args, sequence = data
+ if concrete_args is None:
+ concrete_args = {}
+
+ # init inputs
+ inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda().eval()
+
+ # bench
+ para_mem = float(parameter_size(model)) / 1024**2 * 6
+ act_mem = _benchmark_memory(model, inputs)
+ speed = _benchmark_speed(model, inputs)
+ print("gpt origin, time: %.4fs, act mem: %.2fMB, para mem: %.2fMB, all mem: %.2fMB" %
+ (speed, act_mem, para_mem, act_mem + para_mem))
+ return act_mem
+
+
+def _benchmark_memory(model, inputs):
+ with torch.no_grad():
+ torch.cuda.reset_peak_memory_stats()
+ now_mem = float(torch.cuda.memory_allocated()) / 1024**2
+ model(*inputs)
+ new_max_mem = float(torch.cuda.max_memory_allocated()) / 1024**2
+ return new_max_mem - now_mem
+
+
+def _benchmark_speed(model, inputs, loop=5):
+ with torch.no_grad():
+ for _ in range(loop // 2 + 1):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time1 = time.time()
+ for _ in range(loop):
+ model(*inputs)
+ torch.cuda.synchronize()
+ time2 = time.time()
+ return (time2 - time1) / loop
+
+
+def benchmark_autochunk_gpt(batch=1, seq=512, n_embd=768, n_head=12):
+ from test_autochunk_gpt import GPT2Config, GPT2Model, get_data
+ model = GPT2Model
+ config = GPT2Config(n_embd=n_embd, n_positions=seq, n_layer=2, n_head=n_head)
+ model = model(config=config)
+ shape = [batch, seq]
+ print("\nbatch: %d, seq: %d, n_embd: %d, n_head: %d" % (batch, seq, n_embd, n_head))
+ max_mem = _benchmark_autochunk_gpt_origin(model, get_data(shape))
+ for ratio in [0.5, 0.4, 0.3, 0.2]:
+ try:
+ _benchmark_autochunk_gpt_gm(model, get_data(shape), max_mem * ratio)
+ except RuntimeError as e:
+ if e.args[0] == 'Search failed. Try a larger memory threshold.':
+ break
+ except Exception as e:
+ raise e
+ _benchmark_autochunk_gpt_gm(model, get_data(shape), None)
+
+
+if __name__ == "__main__":
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=0,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+ benchmark_autochunk_gpt(batch=1, seq=1024, n_embd=768, n_head=12)
+ benchmark_autochunk_gpt(batch=1, seq=2048, n_embd=768, n_head=12)
+ benchmark_autochunk_gpt(batch=1, seq=4096, n_embd=768, n_head=12)
+ benchmark_autochunk_gpt(batch=1, seq=6144, n_embd=768, n_head=12)
+ benchmark_autochunk_gpt(batch=1, seq=8192, n_embd=768, n_head=12)
diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py
new file mode 100644
index 000000000000..018a2557a974
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_gpt.py
@@ -0,0 +1,62 @@
+from functools import partial
+from typing import List, Tuple
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+try:
+ from transformers import GPT2Config, GPT2Model
+ MODELS = [GPT2Model]
+ HAS_REPO = True
+except:
+ MODELS = []
+ HAS_REPO = False
+
+from test_autochunk_transformer_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 512
+
+
+def get_data(shape: tuple) -> Tuple[List, List]:
+ input_ids = torch.zeros(shape, dtype=torch.int64)
+ token_type_ids = torch.zeros(shape, dtype=torch.int64)
+ attention_mask = torch.ones(shape, dtype=torch.int64)
+ meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+ concrete_args = {"past_key_values": None}
+ sequence = ["input_ids", "past_key_values", "attention_mask", "token_type_ids"]
+ return meta_args, concrete_args, sequence
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("shape", [(BATCH_SIZE, SEQ_LENGTH)])
+@pytest.mark.parametrize("max_memory", [None, 6, 8])
+def test_autochunk_gpt(model, shape, max_memory):
+ run_func = partial(
+ run_test,
+ data=get_data(shape),
+ max_memory=max_memory,
+ model=model,
+ config=GPT2Config(n_embd=96, n_positions=shape[1], n_layer=2, n_head=4),
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(rank=0,
+ data=get_data((BATCH_SIZE, SEQ_LENGTH)),
+ max_memory=None,
+ model=GPT2Model,
+ config=GPT2Config(n_embd=96, n_position=SEQ_LENGTH, n_layer=2, n_head=4),
+ print_code=False,
+ print_est_mem=False,
+ print_mem=False,
+ print_progress=False,
+ eval_mem=False)
diff --git a/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py
new file mode 100644
index 000000000000..bc5eda7edf91
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_transformer/test_autochunk_transformer_utils.py
@@ -0,0 +1,141 @@
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def assert_codegen_run(
+ model: Any,
+ data: tuple,
+ max_memory: int = None,
+ print_est_mem: bool = False,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ print_code: bool = False,
+ eval_mem: bool = False,
+) -> List[Dict]:
+ meta_args, concrete_args, sequence = data
+ if concrete_args is None:
+ concrete_args = {}
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(
+ model,
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+ concrete_args={k: v for k, v in concrete_args.items()},
+ )
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
+ meta_tensors = [MetaTensor(i, fake_device="cuda:0") if isinstance(i, torch.Tensor) else i for i in meta_tensors]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(meta_graph,
+ max_memory=max_memory,
+ print_mem=print_est_mem,
+ print_progress=print_progress,
+ eval_mem=eval_mem)
+ chunks = codegen.chunk_infos
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model.cuda(),
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+ concrete_args={k: v for k, v in concrete_args.items()},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # assert chunk in code
+ code = graph.python_code("self").src
+ if print_code:
+ print(code)
+ assert "chunk_size = None; " in code
+
+ # assert result
+ inputs = [meta_args[i] if i in meta_args else concrete_args[i] for i in sequence]
+ inputs = [i.cuda() if isinstance(i, torch.Tensor) else i for i in inputs]
+ model.cuda().eval()
+ gm.eval()
+ with torch.no_grad():
+ if print_mem:
+ torch.cuda.reset_peak_memory_stats()
+ now_mem = torch.cuda.memory_allocated() / 1024**2
+ out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
+ print("mem: %.2fMB" % (new_max_mem - now_mem))
+ out_model = model(*inputs)
+ assert_allclose(out_model, out_gm)
+ return chunks
+
+
+def assert_allclose(out_model: Any, out_gm: Any) -> None:
+ """
+ assert allclose for out
+ """
+ if isinstance(out_model, torch.Tensor):
+ assert torch.allclose(out_model, out_gm,
+ atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+ torch.abs(out_model - out_gm))
+ elif isinstance(out_model, dict):
+ for k in out_model.keys():
+ assert_allclose(out_model[k], out_gm[k])
+ elif isinstance(out_model, tuple) or isinstance(out_model, list) or isinstance(out_model, set):
+ for i, j in zip(out_model, out_gm):
+ assert_allclose(i, j)
+
+
+def run_test(
+ rank: int,
+ model: Any,
+ config: Any,
+ data: tuple,
+ max_memory: int,
+ print_code: bool = False,
+ print_est_mem: bool = False,
+ print_mem: bool = False,
+ print_progress: bool = False,
+ eval_mem: bool = False,
+ get_chunk_target: Any = None,
+) -> None:
+ model = model(config=config)
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=rank,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+
+ # build model and input
+ chunks = assert_codegen_run(model,
+ data=data,
+ max_memory=max_memory,
+ print_code=print_code,
+ print_est_mem=print_est_mem,
+ print_mem=print_mem,
+ print_progress=print_progress,
+ eval_mem=eval_mem)
+
+ if get_chunk_target is not None:
+ chunk_found = [i["region"] for i in chunks]
+ chunk_target = get_chunk_target()[max_memory]
+ assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+ str(chunk_found),
+ str(chunk_target),
+ )
diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
new file mode 100644
index 000000000000..2b7cbf1390d2
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit.py
@@ -0,0 +1,53 @@
+from functools import partial
+from typing import List, Tuple
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+try:
+ from timm.models.vision_transformer import vit_large_patch16_384 as vit
+ MODELS = [vit]
+ HAS_REPO = True
+except:
+ MODELS = []
+ HAS_REPO = False
+
+from test_autochunk_vit_utils import run_test
+
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+
+
+def get_data() -> Tuple[List, List]:
+ data = torch.rand(1, 3, 384, 384)
+ meta_args = {'x': data}
+ return data, meta_args
+
+
+@pytest.mark.skipif(
+ not (AUTOCHUNK_AVAILABLE and HAS_REPO),
+ reason="torch version is lower than 1.12.0",
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_memory", [None, 32, 40])
+def test_evoformer_block(model, max_memory):
+ run_func = partial(
+ run_test,
+ max_memory=max_memory,
+ model=model,
+ data=get_data(),
+ )
+ mp.spawn(run_func, nprocs=1)
+
+
+if __name__ == "__main__":
+ run_test(
+ rank=0,
+ data=get_data(),
+ max_memory=None,
+ model=vit,
+ print_code=False,
+ print_mem=False,
+ print_est_mem=False,
+ print_progress=False,
+ )
diff --git a/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
new file mode 100644
index 000000000000..035dd59799b4
--- /dev/null
+++ b/tests/test_autochunk/test_autochunk_vit/test_autochunk_vit_utils.py
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List
+
+import torch
+import torch.fx
+
+import colossalai
+from colossalai.autochunk.autochunk_codegen import AUTOCHUNK_AVAILABLE
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.utils import free_port
+
+if AUTOCHUNK_AVAILABLE:
+ from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
+ from colossalai.fx.profiler import MetaTensor
+ from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
+
+
+def assert_codegen_run(
+ model: Any,
+ meta_args: Dict,
+ data: Any,
+ max_memory: int = None,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ print_code: bool = False,
+) -> List[Dict]:
+ model = model()
+
+ # trace the meta graph and setup codegen
+ meta_graph = symbolic_trace(model, meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()})
+ model = model.cuda().eval()
+ interp = MetaInfoProp(meta_graph)
+ meta_tensors = [MetaTensor(i[1], fake_device="cuda:0") for i in meta_args.items()]
+ interp.propagate(*meta_tensors)
+ codegen = AutoChunkCodeGen(
+ meta_graph,
+ max_memory=max_memory,
+ print_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+ chunks = codegen.chunk_infos
+
+ # trace and recompile
+ # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
+ graph = ColoTracer().trace(
+ model.cuda(),
+ meta_args={k: v.to(torch.device("meta")) for k, v in meta_args.items()},
+ )
+ graph.set_codegen(codegen)
+ gm = ColoGraphModule(model, graph, ckpt_codegen=False)
+ gm.recompile()
+
+ # assert chunk in code
+ code = graph.python_code("self").src
+ if print_code:
+ print(code)
+ assert "chunk_size = None; " in code
+
+ # assert result
+ inputs = [data.cuda()]
+ model.cuda().eval()
+ gm.eval()
+ with torch.no_grad():
+ if print_mem:
+ torch.cuda.reset_peak_memory_stats()
+ now_mem_gm = torch.cuda.memory_allocated() / 1024**2
+ out_gm = gm(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ max_mem_gm = torch.cuda.max_memory_allocated() / 1024**2
+ torch.cuda.reset_peak_memory_stats()
+ now_mem_ori = torch.cuda.memory_allocated() / 1024**2
+ out_model = model(*[i.clone() if isinstance(i, torch.Tensor) else i for i in inputs])
+ if print_mem:
+ max_mem_ori = torch.cuda.max_memory_allocated() / 1024**2
+ print("origin mem: %.2fMB, autochunk mem: %.2fMB" % (max_mem_ori - now_mem_ori, max_mem_gm - now_mem_gm))
+
+ assert torch.allclose(out_gm, out_model,
+ atol=1e-3), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+ torch.abs(out_gm - out_model))
+
+ return chunks
+
+
+def run_test(
+ rank: int,
+ model: Any,
+ data: tuple,
+ max_memory: int,
+ print_code: bool = False,
+ print_mem: bool = False,
+ print_est_mem: bool = False,
+ print_progress: bool = False,
+ get_chunk_target: Any = None,
+) -> None:
+ # launch colossalai
+ colossalai.launch(
+ config={},
+ rank=rank,
+ world_size=1,
+ host="localhost",
+ port=free_port(),
+ backend="nccl",
+ )
+
+ # build model and input
+ data, meta_args = data
+ chunks = assert_codegen_run(
+ model,
+ meta_args=meta_args,
+ data=data,
+ max_memory=max_memory,
+ print_code=print_code,
+ print_mem=print_mem,
+ print_est_mem=print_est_mem,
+ print_progress=print_progress,
+ )
+
+ if get_chunk_target is not None:
+ chunk_found = [i["region"] for i in chunks]
+ chunk_target = get_chunk_target()[max_memory]
+ assert (chunk_found == chunk_target), "found regions %s doesn't equal target regions %s" % (
+ str(chunk_found),
+ str(chunk_target),
+ )
+
+ gpc.destroy()
diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py
new file mode 100644
index 000000000000..6958a87e2a08
--- /dev/null
+++ b/tests/test_booster/test_accelerator.py
@@ -0,0 +1,27 @@
+from functools import partial
+
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+from colossalai.booster.accelerator import Accelerator
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+
+
+@parameterize('device', ['cpu', 'cuda'])
+def run_accelerator(device):
+ acceleartor = Accelerator(device)
+ model = nn.Linear(8, 8)
+ model = acceleartor.configure_model(model)
+ assert next(model.parameters()).device.type == device
+ del model, acceleartor
+
+
+def run_dist(rank):
+ run_accelerator()
+
+
+@rerun_if_address_is_in_use()
+def test_accelerator():
+ world_size = 1
+ run_func = partial(run_dist)
+ mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
new file mode 100644
index 000000000000..bacf29014193
--- /dev/null
+++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py
@@ -0,0 +1,46 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+from torch.optim import Adam
+
+import colossalai
+from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.kit.model_zoo import model_zoo
+
+
+def run_torch_amp(rank, world_size, port):
+ # init dist env
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ sub_model_zoo = model_zoo.get_sub_registry('timm')
+ for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
+ # dlrm_interactionarch has not parameters, so skip
+ if name == 'dlrm_interactionarch':
+ continue
+
+ model = model_fn().cuda()
+ optimizer = Adam(model.parameters(), lr=1e-3)
+ criterion = lambda x: x.mean()
+ data = data_gen_fn()
+ data = {
+ k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+ }
+ mixed_precision = FP16TorchMixedPrecision()
+ model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)
+ output = model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+ optimizer.backward(loss)
+ optimizer.clip_grad_by_norm(1.0)
+ optimizer.step()
+ del model, optimizer, criterion, data, output, mixed_precision
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_plugin():
+ world_size = 1
+ run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
new file mode 100644
index 000000000000..58aef54c4967
--- /dev/null
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -0,0 +1,85 @@
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import SGD
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.interface import OptimizerWrapper
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.kit.model_zoo import model_zoo
+
+
+def check_torch_ddp_plugin():
+ plugin = TorchDDPPlugin()
+ booster = Booster(plugin=plugin)
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+ if name == 'dlrm_interactionarch':
+ continue
+
+ model = model_fn()
+ optimizer = SGD(model.parameters(), lr=1e-3)
+ criterion = lambda x: x.mean()
+ data = data_gen_fn()
+
+ data = {
+ k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+ }
+
+ model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+ assert isinstance(model, DDP)
+ assert isinstance(optimizer, OptimizerWrapper)
+
+ output = model(**data)
+ output = output_transform_fn(output)
+ output_key = list(output.keys())[0]
+ loss = criterion(output[output_key])
+
+ booster.backward(loss, optimizer)
+ optimizer.clip_grad_by_norm(1.0)
+ optimizer.step()
+
+
+def check_dataloader_sharding():
+ plugin = TorchDDPPlugin()
+
+ # create a custom dasetset with 0 to 10
+ dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
+ train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+
+ # get the first batch of data
+ batch = next(iter(train_dataloader))[0].cuda()
+ is_rank_0 = dist.get_rank() == 0
+
+ if is_rank_0:
+ batch_to_compare = batch.clone()
+ else:
+ batch_to_compare = batch
+ # pass to the rank 1 value to rank 0
+ dist.broadcast(batch_to_compare, src=1)
+
+ # compare on rank 0
+ if is_rank_0:
+ assert not torch.equal(batch,
+ batch_to_compare), 'Same number was found across ranks but expected it to be different'
+
+
+def run_dist(rank, world_size, port):
+ # init dist env
+ colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+ check_dataloader_sharding()
+ check_torch_ddp_plugin()
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_plugin():
+ world_size = 2
+ run_func = partial(run_dist, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
new file mode 100644
index 000000000000..48376aaa88bf
--- /dev/null
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -0,0 +1,70 @@
+import tempfile
+
+import torch
+from torch.optim import Adam
+from torchvision.models import resnet18
+
+from colossalai.checkpoint_io import GeneralCheckpointIO
+
+# ========
+# Note:
+# 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
+# 2. we will test on both sharded and unsharded checkpoints
+# 3. TODO(FrankLeeeee): implement sharded checkpoint and test it
+# ========
+
+
+def test_unsharded_checkpoint():
+ # create a model and optimizer
+ model = resnet18()
+ optimizer = Adam(model.parameters(), lr=0.001)
+
+ # create test data sample
+ x = torch.randn(1, 3, 224, 224)
+
+ # run fwd and bwd
+ y = model(x)
+ loss = y.sum()
+ loss.backward()
+ optimizer.step()
+
+ # create a temp file for checkpoint
+ model_ckpt_tempfile = tempfile.NamedTemporaryFile()
+ optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+
+ # save the model and optimizer
+ ckpt_io = GeneralCheckpointIO()
+ ckpt_io.save_model(model, model_ckpt_tempfile.name)
+ ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
+
+ # create new model
+ new_model = resnet18()
+ new_optimizer = Adam(new_model.parameters(), lr=0.001)
+
+ # load the model and optimizer
+ new_model = ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
+ new_optimizer = ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
+
+ # do recursive check for the optimizer state dict
+ # if the value is a dict, compare its values
+ # if the value is a list, comapre all elements one-by-one
+ # if the value is a torch.Tensor, use torch.equal
+ # otherwise use assertEqual
+ def recursive_check(d1, d2):
+ for k, v in d1.items():
+ if isinstance(v, dict):
+ recursive_check(v, d2[k])
+ elif isinstance(v, list):
+ for i in range(len(v)):
+ if isinstance(v[i], torch.Tensor):
+ assert torch.equal(v[i], d2[k][i])
+ else:
+ assert v[i] == d2[k][i]
+ elif isinstance(v, torch.Tensor):
+ assert torch.equal(v, d2[k])
+ else:
+ assert v == d2[k]
+
+ # check for model and optimizer state dict recursively
+ recursive_check(model.state_dict(), new_model.state_dict())
+ recursive_check(optimizer.state_dict(), new_optimizer.state_dict())
diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py
index 2be962e1a2e5..679c8b0f6afe 100644
--- a/tests/test_ddp/test_ddp_ignore_params.py
+++ b/tests/test_ddp/test_ddp_ignore_params.py
@@ -35,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
- chunk_config, _ = search_chunk_configuration(module, 4, 1024)
+ chunk_config, *_ = search_chunk_configuration(module, 4, 1024)
chunk_manager = ChunkManager(chunk_config)
gemini_manager = GeminiManager('cuda', chunk_manager)
return ZeroDDP(module, gemini_manager)
diff --git a/tests/test_device/test_extract_alpha_beta.py b/tests/test_device/test_extract_alpha_beta.py
new file mode 100644
index 000000000000..e32bebdd908e
--- /dev/null
+++ b/tests/test_device/test_extract_alpha_beta.py
@@ -0,0 +1,39 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_extract_alpha_beta(rank, physical_devices, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ profiler = AlphaBetaProfiler(physical_devices)
+
+ mesh_alpha, mesh_beta = profiler.extract_alpha_beta_for_device_mesh()
+ for alpha in mesh_alpha:
+ assert alpha > 0 and alpha < 1e-3
+ for beta in mesh_beta:
+ assert beta > 0 and beta < 1e-10
+
+
+@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+ world_size = 4
+ run_func = partial(check_extract_alpha_beta,
+ physical_devices=physical_devices,
+ world_size=world_size,
+ port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_profile_alpha_beta()
diff --git a/tests/test_device/test_search_logical_device_mesh.py b/tests/test_device/test_search_logical_device_mesh.py
new file mode 100644
index 000000000000..591eafb2a50d
--- /dev/null
+++ b/tests/test_device/test_search_logical_device_mesh.py
@@ -0,0 +1,36 @@
+from functools import partial
+
+import pytest
+import torch.multiprocessing as mp
+
+from colossalai.device import AlphaBetaProfiler
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_alpha_beta(rank, physical_devices, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ profiler = AlphaBetaProfiler(physical_devices)
+ best_logical_mesh = profiler.search_best_logical_mesh()
+
+ if physical_devices == [0, 1, 2, 3]:
+ assert best_logical_mesh == [[0, 1], [2, 3]]
+ elif physical_devices == [0, 3]:
+ assert best_logical_mesh == [[0, 3]]
+
+
+@pytest.mark.skip(reason="Skip because assertion may fail for CI devices")
+@pytest.mark.dist
+@parameterize('physical_devices', [[0, 1, 2, 3], [0, 3]])
+@rerun_if_address_is_in_use()
+def test_profile_alpha_beta(physical_devices):
+ world_size = 4
+ run_func = partial(check_alpha_beta, physical_devices=physical_devices, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_profile_alpha_beta()
diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
index 6d93fe0408d7..7a4bf131ae36 100644
--- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
+++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py
@@ -3,7 +3,8 @@
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
-from colossalai.fx import symbolic_trace
+# from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
index 9c36b0c9cc96..31ba2290ed99 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_albert.py
@@ -1,66 +1,22 @@
import pytest
import torch
-import transformers
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
+
+from tests.kit.model_zoo import model_zoo
BATCH_SIZE = 2
SEQ_LENGTH = 16
-def test_single_sentence_albert():
- MODEL_LIST = [
- transformers.AlbertModel,
- transformers.AlbertForPreTraining,
- transformers.AlbertForMaskedLM,
- transformers.AlbertForSequenceClassification,
- transformers.AlbertForTokenClassification,
- ]
-
- config = transformers.AlbertConfig(embedding_size=128,
- hidden_size=128,
- num_hidden_layers=2,
- num_attention_heads=4,
- intermediate_size=256)
-
- def data_gen():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
- return meta_args
-
- for model_cls in MODEL_LIST:
- model = model_cls(config=config)
- trace_model_and_compare_output(model, data_gen)
-
-
-def test_multi_sentence_albert():
- config = transformers.AlbertConfig(hidden_size=128,
- num_hidden_layers=2,
- num_attention_heads=4,
- intermediate_size=256)
- tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
-
- def data_gen_for_qa():
- question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
- inputs = tokenizer(question, text, return_tensors="pt")
- return inputs
-
- model = transformers.AlbertForQuestionAnswering(config)
- trace_model_and_compare_output(model, data_gen_for_qa)
-
- def data_gen_for_mcq():
- prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
- choice0 = "It is eaten with a fork and a knife."
- choice1 = "It is eaten while held in the hand."
- encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
- encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
- return encoding
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+def test_albert():
+ sub_registry = model_zoo.get_sub_registry('transformers_albert')
- model = transformers.AlbertForMultipleChoice(config)
- trace_model_and_compare_output(model, data_gen_for_mcq)
+ for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+ model = model_fn()
+ trace_model_and_compare_output(model, data_gen_fn)
if __name__ == '__main__':
- test_single_sentence_albert()
- test_multi_sentence_albert()
+ test_albert()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
index 62273e2d51c9..8db6817c66dc 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py
@@ -1,69 +1,19 @@
import pytest
import torch
-import transformers
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
-def test_single_sentence_bert():
- MODEL_LIST = [
- transformers.BertModel,
- transformers.BertForPreTraining,
- transformers.BertLMHeadModel,
- transformers.BertForMaskedLM,
- transformers.BertForSequenceClassification,
- transformers.BertForTokenClassification,
- ]
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+def test_bert():
+ sub_registry = model_zoo.get_sub_registry('transformers_bert')
- config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
-
- def data_gen():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
- return meta_args
-
- for model_cls in MODEL_LIST:
- model = model_cls(config=config)
- trace_model_and_compare_output(model, data_gen)
-
-
-def test_multi_sentence_bert():
- config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
- tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
-
- def data_gen_for_next_sentence():
- prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
- next_sentence = "The sky is blue due to the shorter wavelength of blue light."
- encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
- return encoding
-
- model = transformers.BertForNextSentencePrediction(config)
- trace_model_and_compare_output(model, data_gen_for_next_sentence)
-
- def data_gen_for_qa():
- question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
- inputs = tokenizer(question, text, return_tensors="pt")
- return inputs
-
- model = transformers.BertForQuestionAnswering(config)
- trace_model_and_compare_output(model, data_gen_for_qa)
-
- def data_gen_for_mcq():
- prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
- choice0 = "It is eaten with a fork and a knife."
- choice1 = "It is eaten while held in the hand."
- encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
- encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
- return encoding
-
- model = transformers.BertForMultipleChoice(config)
- trace_model_and_compare_output(model, data_gen_for_mcq)
+ for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+ model = model_fn()
+ trace_model_and_compare_output(model, data_gen_fn)
if __name__ == '__main__':
- test_single_sentence_bert()
- test_multi_sentence_bert()
+ test_bert()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
index 04e874becd00..92ece357bfed 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
@@ -1,114 +1,69 @@
import pytest
import torch
-import transformers
-from hf_tracer_utils import trace_model_and_compare_output
from colossalai.fx import symbolic_trace
+from colossalai.testing.random import seed_all
+from tests.kit.model_zoo import model_zoo
-try:
- import diffusers
- HAS_DIFFUSERS = True
-except ImportError:
- HAS_DIFFUSERS = False
-
-BATCH_SIZE = 2
-SEQ_LENGTH = 5
-HEIGHT = 224
-WIDTH = 224
-IN_CHANNELS = 3
-LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS, HEIGHT // 8, WIDTH // 8)
-TIME_STEP = 2
-
-
-@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
-def test_vae():
- MODEL_LIST = [
- diffusers.AutoencoderKL,
- diffusers.VQModel,
- ]
-
- for model_cls in MODEL_LIST:
- model = model_cls()
- sample = torch.zeros(LATENTS_SHAPE)
-
- gm = symbolic_trace(model)
-
- model.eval()
- gm.eval()
-
- with torch.no_grad():
- fx_out = gm(sample)
- non_fx_out = model(sample)
- assert torch.allclose(
- fx_out['sample'],
- non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-
-
-def test_clip():
- MODEL_LIST = [
- transformers.CLIPModel,
- transformers.CLIPTextModel,
- transformers.CLIPVisionModel,
- ]
-
- CONFIG_LIST = [
- transformers.CLIPConfig,
- transformers.CLIPTextConfig,
- transformers.CLIPVisionConfig,
- ]
-
- def data_gen():
- if isinstance(model, transformers.CLIPModel):
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- position_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
- kwargs = dict(input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- pixel_values=pixel_values)
- elif isinstance(model, transformers.CLIPTextModel):
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
- elif isinstance(model, transformers.CLIPVisionModel):
- pixel_values = torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)
- kwargs = dict(pixel_values=pixel_values)
- return kwargs
-
- for model_cls, config in zip(MODEL_LIST, CONFIG_LIST):
- model = model_cls(config=config())
- trace_model_and_compare_output(model, data_gen)
-
-
-@pytest.mark.skipif(not HAS_DIFFUSERS, reason="diffusers has not been installed")
-@pytest.mark.skip(reason='cannot pass the test yet')
-def test_unet():
- MODEL_LIST = [
- diffusers.UNet2DModel,
- diffusers.UNet2DConditionModel,
- ]
-
- for model_cls in MODEL_LIST:
- model = model_cls()
- sample = torch.zeros(LATENTS_SHAPE)
-
- gm = symbolic_trace(model)
-
- model.eval()
- gm.eval()
-
- with torch.no_grad():
- fx_out = gm(sample, TIME_STEP)
- non_fx_out = model(sample, TIME_STEP)
- assert torch.allclose(
- fx_out['sample'],
- non_fx_out['sample']), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+def assert_dict(da, db, assert_fn):
+ assert len(da) == len(db)
+ for k, v in da.items():
+ assert k in db
+ if not torch.is_tensor(v):
+ continue
+ u = db.get(k)
+ assert_fn(u, v)
-if __name__ == "__main__":
- test_vae()
- test_clip()
- # skip because of failure
- # test_unet()
+def trace_and_compare(model_cls, data, output_fn):
+ model = model_cls()
+ model.eval()
+
+ concrete_args = {k: v for k, v in data.items() if not torch.is_tensor(v)}
+ meta_args = {k: v.to('meta') for k, v in data.items() if torch.is_tensor(v)}
+ gm = symbolic_trace(model, concrete_args=concrete_args, meta_args=meta_args)
+
+ # run forward
+ with torch.no_grad():
+ fx_out = gm(**data)
+ non_fx_out = model(**data)
+
+ # compare output
+ transformed_fx_out = output_fn(fx_out)
+ transformed_non_fx_out = output_fn(non_fx_out)
+
+ def assert_fn(ta, tb):
+ assert torch.equal(ta, tb)
+
+ assert_dict(transformed_fx_out, transformed_non_fx_out, assert_fn)
+
+
+@pytest.mark.skip(reason='cannot pass this test yet')
+def test_diffusers():
+ seed_all(9091, cuda_deterministic=True)
+
+ sub_model_zoo = model_zoo.get_sub_registry('diffusers')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+ data = data_gen_fn()
+ trace_and_compare(model_fn, data, output_transform_fn)
+ torch.cuda.synchronize()
+ print(f"{name:40s} √")
+
+
+def test_torch_diffusers():
+ seed_all(65535, cuda_deterministic=True)
+
+ sub_model_zoo = model_zoo.get_sub_registry('diffusers')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+ data = data_gen_fn()
+ model = model_fn()
+ output = model(**data)
+ torch.cuda.synchronize()
+ print(f"{name:40s} √")
+
+
+if __name__ == "__main__":
+ test_torch_diffusers()
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
index ad4c9684dc42..796c17e398d5 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py
@@ -1,35 +1,25 @@
import pytest
import torch
-import transformers
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
-# TODO: remove this skip once we handle the latest gpt model
-@pytest.mark.skip
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_gpt():
- MODEL_LIST = [
- transformers.GPT2Model,
- transformers.GPT2LMHeadModel,
- transformers.GPT2DoubleHeadsModel,
- transformers.GPT2ForTokenClassification,
- # transformers.GPT2ForSequenceClassification, # not supported yet
- ]
-
- config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
-
- def data_gen():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
- return kwargs
-
- for model_cls in MODEL_LIST:
- model = model_cls(config=config)
- trace_model_and_compare_output(model, data_gen)
+ sub_registry = model_zoo.get_sub_registry('transformers_gpt')
+
+ for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+ model = model_fn()
+
+ # TODO: support the following models
+ # 1. GPT2DoubleHeadsModel
+ # as they are not supported, let's skip them
+ if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
+ continue
+
+ trace_model_and_compare_output(model, data_gen_fn)
if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
index 06260176ec6f..e7bfa607082e 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_opt.py
@@ -1,29 +1,18 @@
import pytest
import torch
-import transformers
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_opt():
- MODEL_LIST = [
- transformers.OPTModel,
- transformers.OPTForCausalLM,
- ]
+ sub_registry = model_zoo.get_sub_registry('transformers_opt')
- config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
-
- def data_gen():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
- return kwargs
-
- for model_cls in MODEL_LIST:
- model = model_cls(config=config)
- trace_model_and_compare_output(model, data_gen)
+ for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+ model = model_fn()
+ trace_model_and_compare_output(model, data_gen_fn)
if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
index 71e782fddc76..5f7e4f81c44e 100644
--- a/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
+++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_t5.py
@@ -1,41 +1,18 @@
import pytest
import torch
-import transformers
from hf_tracer_utils import trace_model_and_compare_output
+from packaging import version
-BATCH_SIZE = 1
-SEQ_LENGTH = 16
+from tests.kit.model_zoo import model_zoo
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_t5():
- MODEL_LIST = [
- transformers.T5Model,
- transformers.T5ForConditionalGeneration,
- transformers.T5EncoderModel,
- ]
+ sub_registry = model_zoo.get_sub_registry('transformers_t5')
- config = transformers.T5Config(d_model=128, num_layers=2)
-
- def data_gen():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
- return kwargs
-
- def data_gen_for_encoder_only():
- input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
- kwargs = dict(input_ids=input_ids)
- return kwargs
-
- for model_cls in MODEL_LIST:
- model = model_cls(config=config)
-
- if isinstance(model, transformers.T5EncoderModel):
- data_gen_func = data_gen_for_encoder_only
- else:
- data_gen_func = data_gen
-
- trace_model_and_compare_output(model, data_gen_func)
+ for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+ model = model_fn()
+ trace_model_and_compare_output(model, data_gen_fn)
if __name__ == '__main__':
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
index 28ec3d82556c..b175d8b10c67 100644
--- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -1,11 +1,12 @@
import pytest
-import timm.models as tm
import torch
+from packaging import version
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
+from tests.kit.model_zoo import model_zoo
-def trace_and_compare(model_cls, data, meta_args=None):
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
# trace
model = model_cls()
@@ -14,60 +15,48 @@ def trace_and_compare(model_cls, data, meta_args=None):
# without this statement, the torch.nn.functional.batch_norm will always be in training mode
model.eval()
+ # TODO: support the following models
+ # 1. ConViT
+ # 2. NormFreeNet
+ # as they are not supported, let's skip them
+ if model.__class__.__name__ in ['ConViT', 'NormFreeNet']:
+ return
+
gm = symbolic_trace(model, meta_args=meta_args)
# run forward
with torch.no_grad():
- fx_out = gm(data)
- non_fx_out = model(data)
+ fx_out = gm(**data)
+ non_fx_out = model(**data)
# compare output
- if isinstance(fx_out, tuple):
- # some models produce tuple as output
- for v1, v2 in zip(fx_out, non_fx_out):
- assert torch.allclose(v1, v2), f'{model.__class__.__name__} has inconsistent outputs, {v1} vs {v2}'
- else:
- assert torch.allclose(
- fx_out, non_fx_out,
- atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
-
-
-def test_timm_models_without_control_flow():
- torch.backends.cudnn.deterministic = True
-
- MODEL_LIST = [
- tm.resnest.resnest50d,
- tm.beit.beit_base_patch16_224,
- tm.cait.cait_s24_224,
- tm.convmixer.convmixer_768_32,
- tm.efficientnet.efficientnetv2_m,
- tm.resmlp_12_224,
- tm.vision_transformer.vit_base_patch16_224,
- tm.deit_base_distilled_patch16_224,
- ]
+ transformed_fx_out = output_transform_fn(fx_out)
+ transformed_non_fx_out = output_transform_fn(non_fx_out)
- data = torch.rand(2, 3, 224, 224)
+ assert len(transformed_fx_out) == len(transformed_non_fx_out)
- for model_cls in MODEL_LIST:
- trace_and_compare(model_cls, data)
+ for key in transformed_fx_out.keys():
+ fx_output_val = transformed_fx_out[key]
+ non_fx_output_val = transformed_non_fx_out[key]
+ assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+ f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
-def test_timm_models_with_control_flow():
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+def test_timm_models():
torch.backends.cudnn.deterministic = True
- MODEL_LIST_WITH_CONTROL_FLOW = [
- tm.convnext.convnext_base, tm.vgg.vgg11, tm.dpn.dpn68, tm.densenet.densenet121, tm.rexnet.rexnet_100,
- tm.swin_transformer.swin_base_patch4_window7_224
- ]
-
- data = torch.rand(2, 3, 224, 224)
+ sub_model_zoo = model_zoo.get_sub_registry('timm')
- meta_args = {'x': data.to('meta')}
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+ data = data_gen_fn()
+ if attribute is not None and attribute.has_control_flow:
+ meta_args = {k: v.to('meta') for k, v in data.items()}
+ else:
+ meta_args = None
- for model_cls in MODEL_LIST_WITH_CONTROL_FLOW:
- trace_and_compare(model_cls, data, meta_args)
+ trace_and_compare(model_fn, data, output_transform_fn, meta_args)
if __name__ == '__main__':
- test_timm_models_with_control_flow()
- test_timm_models_without_control_flow()
+ test_timm_models()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py
deleted file mode 100644
index b2fa8c6c0bbb..000000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_general.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import torch
-from torchaudio_utils import trace_and_compare
-from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
-from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
-import pytest
-
-
-def test_wave2letter_waveform():
- batch_size = 2
- num_features = 1
- num_classes = 40
- input_length = 320
-
- model = Wav2Letter(num_classes=num_classes, num_features=num_features)
-
- def data_gen():
- x = torch.rand(batch_size, num_features, input_length)
- return dict(x=x)
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_wave2letter_mfcc():
- batch_size = 2
- num_features = 13
- num_classes = 40
- input_length = 2
-
- model = Wav2Letter(num_classes=num_classes, input_type="mfcc", num_features=num_features)
-
- def data_gen():
- x = torch.rand(batch_size, num_features, input_length)
- return dict(x=x)
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_melresnet_waveform():
- n_batch = 2
- n_time = 200
- n_freq = 100
- n_output = 128
- n_res_block = 10
- n_hidden = 128
- kernel_size = 5
-
- model = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
-
- def data_gen():
- x = torch.rand(n_batch, n_freq, n_time)
- return dict(specgram=x)
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_upsample_network_waveform():
- upsample_scales = [5, 5, 8]
- n_batch = 2
- n_time = 200
- n_freq = 100
- n_output = 64
- n_res_block = 10
- n_hidden = 32
- kernel_size = 5
-
- total_scale = 1
- for upsample_scale in upsample_scales:
- total_scale *= upsample_scale
-
- model = UpsampleNetwork(upsample_scales, n_res_block, n_freq, n_hidden, n_output, kernel_size)
-
- def data_gen():
- x = torch.rand(n_batch, n_freq, n_time)
- return dict(specgram=x)
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-def test_wavernn_waveform():
- upsample_scales = [2, 2, 5]
- n_rnn = 16
- n_fc = 16
- n_classes = 10
- hop_length = 20
- n_batch = 2
- n_time = 20
- n_freq = 10
- n_output = 16
- n_res_block = 3
- n_hidden = 16
- kernel_size = 5
-
- model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block, n_rnn, n_fc, kernel_size, n_freq, n_hidden,
- n_output)
-
- def data_gen():
- x = torch.rand(n_batch, 1, hop_length * (n_time - kernel_size + 1))
- mels = torch.rand(n_batch, 1, n_freq, n_time)
- return dict(waveform=x, specgram=mels)
-
- trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-def test_convtasnet_config():
- batch_size = 32
- num_frames = 800
-
- model = ConvTasNet()
-
- def data_gen():
- tensor = torch.rand(batch_size, 1, num_frames)
- return dict(input=tensor)
-
- trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-def test_deepspeech():
- n_batch = 2
- n_feature = 1
- n_channel = 1
- n_class = 40
- n_time = 32
-
- model = DeepSpeech(n_feature=n_feature, n_class=n_class)
-
- def data_gen():
- x = torch.rand(n_batch, n_channel, n_time, n_feature)
- return dict(x=x)
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=False)
-
-
-if __name__ == '__main__':
- TEST_LIST = [
- test_wave2letter_waveform,
- test_wave2letter_mfcc,
- test_melresnet_waveform,
- test_upsample_network_waveform,
- test_wavernn_waveform,
- test_convtasnet_config,
- test_deepspeech,
- ]
-
- for test_fn in TEST_LIST:
- test_fn()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
new file mode 100644
index 000000000000..65f9f5149dda
--- /dev/null
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+from packaging import version
+from torchaudio_utils import trace_and_compare
+
+from tests.kit.model_zoo import model_zoo
+
+
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
+def test_torchaudio_models():
+ torch.backends.cudnn.deterministic = True
+
+ sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+ model = model_fn()
+ trace_and_compare(model,
+ data_gen_fn,
+ output_transform_fn,
+ need_meta=(attribute is not None and attribute.has_control_flow))
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
deleted file mode 100644
index 2073c46897f4..000000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_tacotron.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import torch
-from torchaudio.models import Tacotron2
-from torchaudio_utils import trace_and_compare
-import pytest
-
-
-def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
- return Tacotron2(
- mask_padding=False,
- n_mels=n_mels,
- n_symbol=20,
- n_frames_per_step=1,
- symbol_embedding_dim=32,
- encoder_embedding_dim=32,
- encoder_n_convolution=3,
- encoder_kernel_size=5,
- decoder_rnn_dim=32,
- decoder_max_step=decoder_max_step,
- decoder_dropout=0.1,
- decoder_early_stopping=True,
- attention_rnn_dim=32,
- attention_hidden_dim=32,
- attention_location_n_filter=32,
- attention_location_kernel_size=31,
- attention_dropout=0.1,
- prenet_dim=32,
- postnet_n_convolution=5,
- postnet_kernel_size=5,
- postnet_embedding_dim=512,
- gate_threshold=gate_threshold,
- )
-
-
-@pytest.mark.skip("Tracing failed")
-def test_tacotron_model():
- n_mels = 80
- n_batch = 3
- max_mel_specgram_length = 300
- max_text_length = 100
-
- model = _get_tacotron2_model(n_mels)
-
- def data_gen():
- text = torch.randint(0, 148, (n_batch, max_text_length))
- text_lengths = max_text_length * torch.ones((n_batch,))
- mel_specgram = torch.rand(n_batch, n_mels, max_mel_specgram_length)
- mel_specgram_lengths = max_mel_specgram_length * torch.ones((n_batch,))
- return dict(tokens=text,
- token_lengths=text_lengths,
- mel_specgram=mel_specgram,
- mel_specgram_lengths=mel_specgram_lengths)
-
- trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-if __name__ == "__main__":
- test_tacotron_model()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
deleted file mode 100644
index fbe24a8cd91f..000000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_transformer.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import torch
-from torchaudio_utils import trace_and_compare
-from torchaudio.models import Emformer, Conformer
-import pytest
-
-
-def test_conformer():
- input_dim = 80
- batch_size = 10
- num_frames = 400
- num_heads = 4
- ffn_dim = 128
- num_layers = 4
- depthwise_conv_kernel_size = 31
-
- model = Conformer(
- input_dim=input_dim,
- num_heads=num_heads,
- ffn_dim=ffn_dim,
- num_layers=num_layers,
- depthwise_conv_kernel_size=depthwise_conv_kernel_size,
- )
-
- def data_gen():
- lengths = torch.randint(1, num_frames, (batch_size,))
- input = torch.rand(batch_size, int(lengths.max()), input_dim)
- return dict(input=input, lengths=lengths)
-
- def kwargs_transform(data):
- new_data = {}
-
- for k, v in data.items():
- new_data[f'{k}_1'] = v
- return new_data
-
- trace_and_compare(model, data_gen, need_meta=False, need_concrete=True, kwargs_transform=kwargs_transform)
-
-
-@pytest.mark.skip("Tracing failed")
-def test_emformer():
- input_dim = 128
- batch_size = 10
- num_heads = 8
- ffn_dim = 256
- num_layers = 3
- segment_length = 4
- num_frames = 400
- right_context_length = 1
-
- model = Emformer(input_dim, num_heads, ffn_dim, num_layers, segment_length, right_context_length)
-
- def data_gen():
- lengths = torch.randint(1, num_frames, (batch_size,))
- input = torch.rand(batch_size, num_frames, input_dim)
- return dict(input=input, lengths=lengths)
-
- trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-@pytest.mark.skip
-def test_torchaudio_transformers():
- test_conformer()
- test_emformer()
-
-
-if __name__ == "__main__":
- test_torchaudio_transformers()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
deleted file mode 100644
index e8729b83fba0..000000000000
--- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_wave2vec.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-from torchaudio.models.wav2vec2 import (
- hubert_base,
- hubert_large,
- hubert_xlarge,
- wav2vec2_base,
- wav2vec2_large,
- wav2vec2_large_lv60k,
-)
-from torchaudio_utils import trace_and_compare
-import pytest
-
-MODEL_LIST = [
- hubert_base,
- hubert_large,
- hubert_xlarge,
- wav2vec2_base,
- wav2vec2_large,
- wav2vec2_large_lv60k,
-]
-
-
-def _smoke_test(model, device):
- model = model.to(device=device)
-
- batch_size, num_frames = 3, 1024
-
- def data_gen():
- waveforms = torch.randn(batch_size, num_frames, device=device)
- lengths = torch.randint(
- low=0,
- high=num_frames,
- size=[
- batch_size,
- ],
- device=device,
- )
- return dict(waveforms=waveforms, lengths=lengths)
-
- trace_and_compare(model, data_gen, need_meta=True, need_concrete=False)
-
-
-@pytest.mark.skip("Tracing failed")
-def test_wav2vec():
- for model_fn in MODEL_LIST:
- _smoke_test(model_fn(), 'cpu')
-
-
-if __name__ == "__main__":
- test_wav2vec()
diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
index 702c5f8f6a24..239f38680cec 100644
--- a/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
+++ b/tests/test_fx/test_tracer/test_torchaudio_model/torchaudio_utils.py
@@ -1,9 +1,9 @@
import torch
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
-def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwargs_transform=False):
+def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
data = data_gen()
concrete_args = data if need_concrete else {}
meta_args = {k: v.to('meta') for k, v in data.items()} if need_meta else {}
@@ -14,16 +14,15 @@ def trace_and_compare(model, data_gen, need_meta=False, need_concrete=False, kwa
with torch.no_grad():
non_fx_out = model(**data)
+ fx_out = gm(**data)
- if kwargs_transform:
- data = kwargs_transform(data)
+ # compare output
+ transformed_fx_out = output_transform_fn(fx_out)
+ transformed_non_fx_out = output_transform_fn(non_fx_out)
- fx_out = gm(**data)
- if isinstance(fx_out, tuple):
- for non_fx, fx in zip(non_fx_out, fx_out):
- assert torch.allclose(
- non_fx, fx, atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
- else:
- assert torch.allclose(
- fx_out, non_fx_out,
- atol=1e-5), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ assert len(transformed_fx_out) == len(transformed_non_fx_out)
+
+ for key, fx_output_val in transformed_fx_out.items():
+ non_fx_output_val = transformed_non_fx_out[key]
+ assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+ f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
index dbe8a62e7c59..40f83d47a7cc 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
@@ -1,85 +1,64 @@
import pytest
import torch
-from colossalai.fx import symbolic_trace
-
-try:
- from torchrec.models import deepfm
- from torchrec.modules.embedding_configs import EmbeddingBagConfig
- from torchrec.modules.embedding_modules import EmbeddingBagCollection
- from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
- NOT_TORCHREC = False
-except ImportError:
- NOT_TORCHREC = True
+from colossalai._analyzer.fx import symbolic_trace
+from tests.kit.model_zoo import model_zoo
BATCH = 2
SHAPE = 10
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_deepfm_models():
- MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
-
- # Data Preparation
- # EmbeddingBagCollection
- eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
- eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-
- ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
- keys = ["f1", "f2"]
-
- # KeyedTensor
- KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
- # KeyedJaggedTensor
- KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
- values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
- offsets=torch.tensor([0, 2, 4, 6, 8]))
-
- # Dense Features
- features = torch.rand((BATCH, SHAPE))
-
- for model_cls in MODEL_LIST:
- # Initializing model
- if model_cls == deepfm.DenseArch:
- model = model_cls(SHAPE, SHAPE, SHAPE)
- elif model_cls == deepfm.FMInteractionArch:
- model = model_cls(SHAPE * 3, keys, SHAPE)
- elif model_cls == deepfm.OverArch:
- model = model_cls(SHAPE)
- elif model_cls == deepfm.SimpleDeepFMNN:
- model = model_cls(SHAPE, ebc, SHAPE, SHAPE)
- elif model_cls == deepfm.SparseArch:
- model = model_cls(ebc)
-
- # Setup GraphModule
- gm = symbolic_trace(model)
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+ # trace
+ model = model_cls()
+
+ # convert to eval for inference
+ # it is important to set it to eval mode before tracing
+ # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+ model.eval()
+
+ gm = symbolic_trace(model, meta_args=meta_args)
+ gm.eval()
+ # run forward
+ with torch.no_grad():
+ fx_out = gm(**data)
+ non_fx_out = model(**data)
+
+ # compare output
+ transformed_fx_out = output_transform_fn(fx_out)
+ transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+ assert len(transformed_fx_out) == len(transformed_non_fx_out)
+ if torch.is_tensor(fx_out):
+ assert torch.allclose(
+ fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ else:
+ assert torch.allclose(
+ fx_out.values(),
+ non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ for key in transformed_fx_out.keys():
+ fx_output_val = transformed_fx_out[key]
+ non_fx_output_val = transformed_non_fx_out[key]
+ if torch.is_tensor(fx_output_val):
+ assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+ f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
+ else:
+ assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+ ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
- model.eval()
- gm.eval()
- # Aligned Test
- with torch.no_grad():
- if model_cls == deepfm.DenseArch or model_cls == deepfm.OverArch:
- fx_out = gm(features)
- non_fx_out = model(features)
- elif model_cls == deepfm.FMInteractionArch:
- fx_out = gm(features, KT)
- non_fx_out = model(features, KT)
- elif model_cls == deepfm.SimpleDeepFMNN:
- fx_out = gm(features, KJT)
- non_fx_out = model(features, KJT)
- elif model_cls == deepfm.SparseArch:
- fx_out = gm(KJT)
- non_fx_out = model(KJT)
+def test_torchrec_deepfm_models():
+ deepfm_models = model_zoo.get_sub_registry('deepfm')
+ torch.backends.cudnn.deterministic = True
- if torch.is_tensor(fx_out):
- assert torch.allclose(
- fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
+ data = data_gen_fn()
+ if attribute is not None and attribute.has_control_flow:
+ meta_args = {k: v.to('meta') for k, v in data.items()}
else:
- assert torch.allclose(
- fx_out.values(),
- non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ meta_args = None
+
+ trace_and_compare(model_fn, data, output_transform_fn, meta_args)
if __name__ == "__main__":
diff --git a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
index 2f9fd8fe5982..6d4b6ab81b12 100644
--- a/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+++ b/tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
@@ -1,111 +1,70 @@
+import pytest
import torch
-from colossalai.fx import symbolic_trace
-
-try:
- from torchrec.models import dlrm
- from torchrec.modules.embedding_configs import EmbeddingBagConfig
- from torchrec.modules.embedding_modules import EmbeddingBagCollection
- from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
- NOT_TORCHREC = False
-except ImportError:
- NOT_TORCHREC = True
-
-import pytest
+from colossalai._analyzer.fx import symbolic_trace
+from tests.kit.model_zoo import model_zoo
BATCH = 2
SHAPE = 10
-@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
-def test_torchrec_dlrm_models():
- MODEL_LIST = [
- dlrm.DLRM,
- dlrm.DenseArch,
- dlrm.InteractionArch,
- dlrm.InteractionV2Arch,
- dlrm.OverArch,
- dlrm.SparseArch,
- # dlrm.DLRMV2
- ]
-
- # Data Preparation
- # EmbeddingBagCollection
- eb1_config = EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"])
- eb2_config = EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"])
-
- ebc = EmbeddingBagCollection(tables=[eb1_config, eb2_config])
- keys = ["f1", "f2"]
-
- # KeyedTensor
- KT = KeyedTensor(keys=keys, length_per_key=[SHAPE, SHAPE], values=torch.rand((BATCH, 2 * SHAPE)))
-
- # KeyedJaggedTensor
- KJT = KeyedJaggedTensor.from_offsets_sync(keys=keys,
- values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
- offsets=torch.tensor([0, 2, 4, 6, 8]))
-
- # Dense Features
- dense_features = torch.rand((BATCH, SHAPE))
-
- # Sparse Features
- sparse_features = torch.rand((BATCH, len(keys), SHAPE))
-
- for model_cls in MODEL_LIST:
- # Initializing model
- if model_cls == dlrm.DLRM:
- model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
- elif model_cls == dlrm.DenseArch:
- model = model_cls(SHAPE, [SHAPE, SHAPE])
- elif model_cls == dlrm.InteractionArch:
- model = model_cls(len(keys))
- elif model_cls == dlrm.InteractionV2Arch:
- I1 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
- I2 = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
- model = model_cls(len(keys), I1, I2)
- elif model_cls == dlrm.OverArch:
- model = model_cls(SHAPE, [5, 1])
- elif model_cls == dlrm.SparseArch:
- model = model_cls(ebc)
- elif model_cls == dlrm.DLRMV2:
- # Currently DLRMV2 cannot be traced
- model = model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
-
- # Setup GraphModule
- if model_cls == dlrm.InteractionV2Arch:
- concrete_args = {"dense_features": dense_features, "sparse_features": sparse_features}
- gm = symbolic_trace(model, concrete_args=concrete_args)
+def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
+ # trace
+ model = model_cls()
+
+ # convert to eval for inference
+ # it is important to set it to eval mode before tracing
+ # without this statement, the torch.nn.functional.batch_norm will always be in training mode
+ model.eval()
+
+ gm = symbolic_trace(model, meta_args=meta_args)
+ gm.eval()
+ # run forward
+ with torch.no_grad():
+ fx_out = gm(**data)
+ non_fx_out = model(**data)
+
+ # compare output
+ transformed_fx_out = output_transform_fn(fx_out)
+ transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+ assert len(transformed_fx_out) == len(transformed_non_fx_out)
+ if torch.is_tensor(fx_out):
+ assert torch.allclose(
+ fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ else:
+ assert torch.allclose(
+ fx_out.values(),
+ non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ for key in transformed_fx_out.keys():
+ fx_output_val = transformed_fx_out[key]
+ non_fx_output_val = transformed_non_fx_out[key]
+ if torch.is_tensor(fx_output_val):
+ assert torch.allclose(fx_output_val, non_fx_output_val, atol=1e-5), \
+ f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
else:
- gm = symbolic_trace(model)
-
- model.eval()
- gm.eval()
-
- # Aligned Test
- with torch.no_grad():
- if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
- fx_out = gm(dense_features, KJT)
- non_fx_out = model(dense_features, KJT)
- elif model_cls == dlrm.DenseArch:
- fx_out = gm(dense_features)
- non_fx_out = model(dense_features)
- elif model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
- fx_out = gm(dense_features, sparse_features)
- non_fx_out = model(dense_features, sparse_features)
- elif model_cls == dlrm.OverArch:
- fx_out = gm(dense_features)
- non_fx_out = model(dense_features)
- elif model_cls == dlrm.SparseArch:
- fx_out = gm(KJT)
- non_fx_out = model(KJT)
-
- if torch.is_tensor(fx_out):
- assert torch.allclose(
- fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ assert torch.allclose(fx_output_val.values(), non_fx_output_val.values()
+ ), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+
+
+def test_torchrec_dlrm_models():
+ torch.backends.cudnn.deterministic = True
+ dlrm_models = model_zoo.get_sub_registry('dlrm')
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+ data = data_gen_fn()
+
+ # dlrm_interactionarch is not supported
+ # TODO(FrankLeeeee): support this model
+ if name == 'dlrm_interactionarch':
+ continue
+
+ if attribute is not None and attribute.has_control_flow:
+ meta_args = {k: v.to('meta') for k, v in data.items()}
else:
- assert torch.allclose(
- fx_out.values(),
- non_fx_out.values()), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ meta_args = None
+
+ trace_and_compare(model_fn, data, output_transform_fn, meta_args)
if __name__ == "__main__":
diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
index 2a6c6ae1674b..8dbbf9f5aab7 100644
--- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
+++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
@@ -1,44 +1,43 @@
import torch
-import torchvision
-import torchvision.models as tm
-from packaging import version
-from colossalai.fx import symbolic_trace
+from colossalai._analyzer.fx import symbolic_trace
+from tests.kit.model_zoo import model_zoo
def test_torchvision_models():
- MODEL_LIST = [
- tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
- tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
- ]
-
- RANDOMIZED_MODELS = [tm.efficientnet_b0]
-
- if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
- MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
- RANDOMIZED_MODELS.append(tm.convnext_small)
-
torch.backends.cudnn.deterministic = True
+ tv_sub_registry = model_zoo.get_sub_registry('torchvision')
- data = torch.rand(2, 3, 224, 224)
+ for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
+ data = data_gen_fn()
- for model_cls in MODEL_LIST:
- if model_cls in RANDOMIZED_MODELS:
- # remove the impact of randomicity
- model = model_cls(stochastic_depth_prob=0)
+ if model_attribute is not None and model_attribute.has_stochastic_depth_prob:
+ model = model_fn(stochastic_depth_prob=0)
else:
- model = model_cls()
+ model = model_fn()
gm = symbolic_trace(model)
model.eval()
gm.eval()
- with torch.no_grad():
- fx_out = gm(data)
- non_fx_out = model(data)
- assert torch.allclose(
- fx_out, non_fx_out), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
+ try:
+ with torch.no_grad():
+ fx_out = gm(**data)
+ non_fx_out = model(**data)
+ transformed_out = output_transform_fn(fx_out)
+ transformed_non_fx_out = output_transform_fn(non_fx_out)
+
+ assert len(transformed_out) == len(transformed_non_fx_out)
+
+ for key in transformed_out.keys():
+ fx_val = transformed_out[key]
+ non_fx_val = transformed_non_fx_out[key]
+ assert torch.allclose(
+ fx_val,
+ non_fx_val), f'{model.__class__.__name__} has inconsistent outputs, {fx_val} vs {non_fx_val}'
+ except Exception as e:
+ print(name, e)
if __name__ == '__main__':
diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py
index af98878e9e70..2821dc78d984 100644
--- a/tests/test_gemini/update/test_fwd_bwd.py
+++ b/tests/test_gemini/update/test_fwd_bwd.py
@@ -34,17 +34,17 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
-@parameterize('init_device', [get_current_device()])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('keep_gather', [False, True])
@parameterize('model_name', ['gpt2', 'bert', 'albert'])
@parameterize('use_grad_checkpoint', [False, True])
-def exam_gpt_fwd_bwd(placement_policy,
- keep_gather,
- model_name: str,
- use_grad_checkpoint: bool = False,
- init_device=get_current_device()):
-
+def exam_gpt_fwd_bwd(
+ placement_policy,
+ keep_gather,
+ model_name: str,
+ use_grad_checkpoint: bool = False,
+):
+ init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -58,7 +58,7 @@ def exam_gpt_fwd_bwd(placement_policy,
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py
index 7fce84a5099a..8cf17a0a726e 100644
--- a/tests/test_gemini/update/test_gemini_use_rmt.py
+++ b/tests/test_gemini/update/test_gemini_use_rmt.py
@@ -62,7 +62,7 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
assert len(step_list) == 4
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gather
chunk_manager = ChunkManager(config_dict)
diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py
index 185521edb357..d97ba94399c0 100644
--- a/tests/test_gemini/update/test_grad_clip.py
+++ b/tests/test_gemini/update/test_grad_clip.py
@@ -31,8 +31,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
- if key == 'model.lm_head.weight':
- continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@@ -60,7 +58,7 @@ def exam_grad_clipping(placement_policy, model_name: str):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py
new file mode 100644
index 000000000000..b057448ad378
--- /dev/null
+++ b/tests/test_gemini/update/test_inference.py
@@ -0,0 +1,138 @@
+from functools import partial
+from typing import Callable
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.amp import convert_to_apex_amp
+from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration
+from colossalai.gemini.gemini_mgr import GeminiManager
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
+from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
+from tests.components_to_test import run_fwd_bwd
+from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.test_tensor.common_utils import debug_print, set_seed
+
+
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
+ zero_dict = model.state_dict(only_rank_0=False)
+ torch_dict = torch_model.state_dict()
+
+ for key, value in torch_dict.items():
+ # key is 'module.model.PARAMETER', so we truncate it
+ key = key[7:]
+ assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
+ temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+ # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
+ assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
+
+
+def multi_chunk_init(model: torch.nn.Module, placement_policy: str):
+ world_size = dist.get_world_size()
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict[world_size]['chunk_size'] = 5000
+ config_dict[world_size]['keep_gathered'] = False
+ if placement_policy != 'cuda':
+ init_device = torch.device('cpu')
+ else:
+ init_device = None
+ chunk_manager = ChunkManager(config_dict, init_device=init_device)
+ gemini_manager = GeminiManager(placement_policy, chunk_manager)
+ model = ZeroDDP(model, gemini_manager, pin_memory=True)
+ return model
+
+
+def single_chunk_init(model: torch.nn.Module, placement_policy: str):
+ gemini_config = dict(
+ device=get_current_device(),
+ placement_policy=placement_policy,
+ pin_memory=True,
+ )
+ model = zero_model_wrapper(model=model, zero_stage=3, gemini_config=gemini_config)
+ return model
+
+
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
+@parameterize('model_name', ['gpt2'])
+@parameterize('model_init_func', [single_chunk_init, multi_chunk_init])
+def exam_inference(placement_policy: str, model_name: str, model_init_func: Callable):
+ set_seed(19360226)
+ get_components_func = non_distributed_component_funcs.get_callable(model_name)
+ model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+ torch_model = model_builder().cuda()
+ amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=128)
+ torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
+ torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
+ torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
+
+ init_dev = get_current_device()
+ with ColoInitContext(device=init_dev):
+ model = model_builder()
+
+ for torch_p, p in zip(torch_model.parameters(), model.parameters()):
+ p.data.copy_(torch_p.data)
+
+ model = model_init_func(model, placement_policy)
+ optimizer = HybridAdam(model.parameters(), lr=1e-3)
+ zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
+
+ model.eval()
+ torch_model.eval()
+
+ set_seed(dist.get_rank() * 3 + 128)
+ train_dataloader = iter(train_dataloader)
+
+ def train_iter():
+ input_ids, label = next(train_dataloader)
+ input_ids, label = input_ids.cuda(), label.cuda()
+ zero_optim.zero_grad()
+ torch_optim.zero_grad()
+ torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
+ loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
+ assert_close(torch_loss, loss)
+ zero_optim.step()
+ torch_optim.step()
+ check_param(model, torch_model)
+
+ def inference_iter():
+ input_ids, label = next(train_dataloader)
+ input_ids, label = input_ids.cuda(), label.cuda()
+ with torch.no_grad():
+ torch_output = torch_model(input_ids)
+ torch_loss = criterion(torch_output.float(), label)
+ zero_output = model(input_ids)
+ zero_loss = criterion(zero_output.float(), label)
+ assert_close(torch_loss, zero_loss)
+
+ train_iter()
+ inference_iter()
+ train_iter()
+
+
+def run_dist(rank, world_size, port):
+ config = {}
+ colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ exam_inference()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_inference(world_size):
+ run_func = partial(run_dist, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_inference(1)
diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py
index 34509cc0cf00..cd3aa6051d78 100644
--- a/tests/test_gemini/update/test_optim.py
+++ b/tests/test_gemini/update/test_optim.py
@@ -36,8 +36,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
- if key == 'model.lm_head.weight':
- continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@@ -65,7 +63,7 @@ def exam_model_step(placement_policy, model_name: str):
p.data.copy_(torch_p.data)
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py
index e0b4e207f16f..2fcdd5380906 100644
--- a/tests/test_gemini/update/test_search.py
+++ b/tests/test_gemini/update/test_search.py
@@ -6,7 +6,7 @@
import torch.multiprocessing as mp
import colossalai
-from colossalai.gemini.chunk import search_chunk_configuration
+from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
@@ -23,7 +23,6 @@ def init_1d_row_spec(model, pg: ProcessGroup):
def exam_search_chunk_size():
-
world_size = torch.distributed.get_world_size()
pg_tp = ProcessGroup(tp_degree=world_size)
@@ -34,11 +33,11 @@ def exam_search_chunk_size():
with ColoInitContext(device=get_current_device()):
model = model_builder()
init_1d_row_spec(model, pg_tp)
- config_dict, _ = search_chunk_configuration(model,
- search_range_mb=1,
- search_interval_byte=16,
- min_chunk_size_mb=0,
- filter_exlarge_params=True)
+ config_dict, *_ = search_chunk_configuration(model,
+ search_range_mb=1,
+ search_interval_byte=16,
+ min_chunk_size_mb=0,
+ filter_exlarge_params=True)
for key in config_dict:
chunk_size = config_dict[key]['chunk_size']
@@ -48,9 +47,68 @@ def exam_search_chunk_size():
assert chunk_size == 1024
+def exam_search_strict_ddp():
+ world_size = torch.distributed.get_world_size()
+ default_shard_pg = ProcessGroup(tp_degree=world_size)
+ default_shard_spec = ShardSpec([-1], [world_size])
+
+ get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+ model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+ # get the chunk configuration over replicated models
+ with ColoInitContext(device=get_current_device()):
+ ddp_model = model_builder()
+ re_dict, re_total, re_wasted = search_chunk_configuration(ddp_model,
+ search_range_mb=1,
+ search_interval_byte=16,
+ min_chunk_size_mb=0,
+ filter_exlarge_params=True,
+ strict_ddp_flag=False)
+ # get the chunk configuration over sharded ddp models
+ with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
+ default_dist_spec=default_shard_spec):
+ sharded_ddp_model = model_builder()
+ sh_dict, sh_total, sh_wasted = search_chunk_configuration(sharded_ddp_model,
+ search_range_mb=1,
+ search_interval_byte=16,
+ min_chunk_size_mb=0,
+ filter_exlarge_params=True,
+ strict_ddp_flag=True)
+ assert re_dict == sh_dict
+ for key in re_dict:
+ assert re_dict[key] == sh_dict[key]
+
+ assert re_total == sh_total
+ assert re_wasted == sh_wasted
+
+
+def exam_chunk_manager():
+ world_size = torch.distributed.get_world_size()
+ default_shard_pg = ProcessGroup(tp_degree=world_size)
+ default_shard_spec = ShardSpec([-1], [world_size])
+
+ get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+ model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+ with ColoInitContext(device=get_current_device(), default_pg=default_shard_pg,
+ default_dist_spec=default_shard_spec):
+ sharded_ddp_model = model_builder()
+ chunk_manager = init_chunk_manager(sharded_ddp_model,
+ get_current_device(),
+ hidden_dim=16,
+ search_range_mb=1,
+ min_chunk_size_mb=0,
+ filter_exlarge_params=True,
+ strict_ddp_flag=True)
+ config_dict = chunk_manager.dp_degree_chunk_size_dict
+ assert len(config_dict) == 1
+ assert config_dict[world_size] == 31616
+
+
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_search_chunk_size()
+ exam_search_strict_ddp()
+ exam_chunk_manager()
@pytest.mark.dist
diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py
index 7b0c6e37a7e8..00d835842f79 100644
--- a/tests/test_gemini/update/test_zeroddp_state_dict.py
+++ b/tests/test_gemini/update/test_zeroddp_state_dict.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
+from torch.testing import assert_close
import colossalai
from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
@@ -17,6 +18,13 @@
from tests.test_tensor.common_utils import debug_print, set_seed
+def ignore_the_first_parameter(model: torch.nn.Module):
+ for name, param in model.named_parameters():
+ print(f"parameter `{name}` is set ignored")
+ ZeroDDP.set_params_to_ignore([param])
+ return
+
+
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@parameterize('keep_gathered', [True, False])
@parameterize('model_name', ['gpt2', 'bert'])
@@ -33,7 +41,7 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
torch_p.data.copy_(p.data)
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
chunk_manager = ChunkManager(config_dict)
@@ -45,11 +53,9 @@ def exam_state_dict(placement_policy, keep_gathered, model_name: str):
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
- if key == 'model.lm_head.weight':
- continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
- assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+ assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
@@ -67,7 +73,7 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
@@ -84,11 +90,9 @@ def exam_load_state_dict(placement_policy, keep_gathered, model_name: str):
zero_dict = model.state_dict(only_rank_0=False)
for key, value in torch_dict.items():
- if key == 'model.lm_head.weight':
- continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
- assert torch.equal(value, temp_zero_value), "parameter '{}' has problem.".format(key)
+ assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
def run_dist(rank, world_size, port):
diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py
index 7f53415bf22c..fd13af6b2b0a 100644
--- a/tests/test_gemini/update/test_zerooptim_state_dict.py
+++ b/tests/test_gemini/update/test_zerooptim_state_dict.py
@@ -33,7 +33,7 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[world_size]['chunk_size'] = 5000
config_dict[world_size]['keep_gathered'] = keep_gathered
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
for n, m in v.items():
if isinstance(m, torch.Tensor):
o = w[n]
- if m.device != o.device:
- o = o.to(m.device)
assert torch.equal(m, o)
else:
assert m == w[n]
diff --git a/tests/test_tensor/common_utils/_utils.py b/tests/test_tensor/common_utils/_utils.py
index 6b58aa801d15..b405f8cd2108 100644
--- a/tests/test_tensor/common_utils/_utils.py
+++ b/tests/test_tensor/common_utils/_utils.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
import torch.distributed as dist
+from torch.testing import assert_close
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
@@ -41,14 +42,20 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
return tensor_chunk.clone()
-def tensor_equal(A, B):
- return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor, rtol: float = 1e-3, atol: float = 1e-1):
+ assert_close(t_a, t_b, rtol=rtol, atol=atol)
+ return True
-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
+def tensor_shard_equal(tensor: torch.Tensor,
+ shard: torch.Tensor,
+ rank: int,
+ world_size: int,
+ rtol: float = 1e-3,
+ atol: float = 1e-1):
assert tensor.ndim == shard.ndim
if tensor.shape == shard.shape:
- return tensor_equal(tensor, shard)
+ return tensor_equal(tensor, shard, rtol, atol)
else:
dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
if dims_not_eq.numel() == 1:
@@ -58,7 +65,7 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
if rank is None:
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
- return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+ return tensor_equal(tensor.chunk(world_size, dim)[rank], shard, rtol, atol)
else:
raise NotImplementedError
diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py
new file mode 100644
index 000000000000..547a96b264dc
--- /dev/null
+++ b/tests/test_tensor/test_dtensor/test_comm_spec.py
@@ -0,0 +1,190 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.distributed import ReduceOp
+
+from colossalai.core import global_context as gpc
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+
+def check_all_gather(process_groups_dict, rank):
+ # tensor to comm
+ if rank in (0, 2):
+ sharded_tensor_to_comm = torch.ones(2, 2).cuda()
+ else:
+ sharded_tensor_to_comm = torch.zeros(2, 2).cuda()
+
+ # tensor to check
+ tensor_to_check = torch.cat((torch.ones(2, 2), torch.zeros(2, 2)), 1).cuda()
+
+ # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
+ comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+ process_groups_dict,
+ gather_dim=1,
+ logical_process_axis=1)
+ sharded_tensor_to_comm = sharded_tensor_to_comm = comm_spec.covert_spec_to_action(sharded_tensor_to_comm)
+
+ assert sharded_tensor_to_comm.equal(tensor_to_check)
+
+
+def check_shard(process_groups_dict, rank):
+ # tensor to comm
+ sharded_tensor_to_comm_0 = torch.zeros(2, 2).cuda()
+ sharded_tensor_to_comm_1 = torch.ones(2, 2).cuda()
+ # tensor([[0., 0., 1., 1.],
+ # [0., 0., 1., 1.]])
+ tensor_to_shard = torch.cat((sharded_tensor_to_comm_0, sharded_tensor_to_comm_1), 1)
+
+ # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
+ comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD,
+ process_groups_dict,
+ shard_dim=1,
+ logical_process_axis=1)
+ tensor_to_shard = comm_spec.covert_spec_to_action(tensor_to_shard)
+
+ if rank in (0, 2):
+ assert tensor_to_shard.equal(sharded_tensor_to_comm_0)
+ if rank in (1, 3):
+ assert tensor_to_shard.equal(sharded_tensor_to_comm_1)
+
+
+def check_all_to_all(process_groups_dict, rank):
+ # tensor to comm
+ if rank in (0, 1):
+ sharded_tensor_0 = torch.zeros(2, 1)
+ sharded_tensor_1 = torch.ones(2, 1)
+ # tensor([[0., 1.],
+ # [0., 1.]])
+ tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+ if rank in (2, 3):
+ sharded_tensor_0 = torch.ones(2, 1) * 2
+ sharded_tensor_1 = torch.ones(2, 1) * 3
+ # tensor([[2., 3.],
+ # [2., 3.]])
+ tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
+
+ if rank in (0, 1):
+ # tensor([[0.],
+ # [0.],
+ # [2.],
+ # [2.]])
+ tensor_to_check = torch.tensor([[0], [0], [2], [2]], dtype=tensor_to_comm.dtype).cuda()
+ if rank in (2, 3):
+ # tensor([[1.],
+ # [1.],
+ # [3.],
+ # [3.]])
+ tensor_to_check = torch.tensor([[1], [1], [3], [3]], dtype=tensor_to_comm.dtype).cuda()
+
+ # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
+ comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
+ process_groups_dict,
+ gather_dim=0,
+ shard_dim=1,
+ logical_process_axis=0)
+ tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+ assert tensor_to_comm.equal(tensor_to_check)
+
+
+def check_all_reduce_fwd(process_groups_dict, rank):
+ # tensor to comm
+ tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+ # reduce through logical process axis 0
+ # tensor to check
+ if rank in (0, 2):
+ # tensor([[2., 2.],
+ # [2., 2.]])
+ tensor_to_check = torch.tensor([[2, 2], [2, 2]], dtype=tensor_to_comm.dtype).cuda()
+ if rank in (1, 3):
+ # tensor([[4., 4.],
+ # [4., 4.]])
+ tensor_to_check = torch.tensor([[4, 4], [4, 4]], dtype=tensor_to_comm.dtype).cuda()
+
+ comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+ tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+ assert tensor_to_comm.equal(tensor_to_check)
+
+
+def check_all_reduce_bwd(process_groups_dict, rank):
+ # tensor to comm
+ tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+ tensor_to_check = torch.ones(2, 2).cuda() * rank
+
+ comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, process_groups_dict, logical_process_axis=0)
+ tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+ assert tensor_to_comm.equal(tensor_to_check)
+
+
+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+ # tensor to comm
+ tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+ # reduce through logical process axis 0 at flatten device mesh
+ # tensor to check
+ # tensor([[6., 6.],
+ # [6., 6.]])
+ tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+ # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+ comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+ tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+ assert tensor_to_comm.equal(tensor_to_check)
+
+
+def check_comm(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+ physical_mesh_id = torch.arange(0, 4)
+ assert rank == gpc.get_global_rank()
+
+ mesh_shape = (2, 2)
+ # [[0, 1,
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+ process_groups_dict = device_mesh.process_groups_dict
+
+ # test all gather
+ check_all_gather(process_groups_dict, rank)
+
+ # test shard
+ check_shard(process_groups_dict, rank)
+
+ # test all to all
+ check_all_to_all(process_groups_dict, rank)
+
+ # test all reduce
+ check_all_reduce_fwd(process_groups_dict, rank)
+ check_all_reduce_bwd(process_groups_dict, rank)
+
+ flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+ # test all reduce in 1D flatten device mesh
+ check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
+ gpc.destroy()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_comm_spec():
+ world_size = 4
+ run_func = partial(check_comm, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_comm_spec()
diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py
new file mode 100644
index 000000000000..a99ac6e41c5e
--- /dev/null
+++ b/tests/test_tensor/test_dtensor/test_dtensor.py
@@ -0,0 +1,102 @@
+from functools import partial
+
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.utils import free_port
+
+
+class TestModel(torch.nn.Module):
+
+ def __init__(self, in_features, out_features):
+ super().__init__()
+ self.linear_1 = torch.nn.Linear(in_features, out_features)
+ self.linear_2 = torch.nn.Linear(out_features, in_features)
+
+ def forward(self, x):
+ x = self.linear_1(x)
+ x = self.linear_2(x)
+ return x
+
+
+def check_dtensor(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ test_model = TestModel(8, 8).to('cuda')
+ original_tensor = torch.rand(4, 8).to('cuda')
+ compare_output = test_model(original_tensor)
+
+ device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+ target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=target_sharding_spec,
+ entire_shape=original_tensor.shape)
+ d_tensor = DTensor(original_tensor, layout)
+
+ assert d_tensor.entire_shape == original_tensor.shape
+ assert d_tensor.data_type == original_tensor.dtype
+
+ if rank in (0, 1):
+ assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
+ elif rank in (2, 3):
+ assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
+ else:
+ raise ValueError(f'rank {rank} is not in the device mesh')
+ assert d_tensor.to_global().equal(original_tensor)
+ output = test_model(d_tensor)
+
+ if rank in (0, 1):
+ assert output.equal(compare_output.narrow(0, 0, 2))
+ elif rank in (2, 3):
+ assert output.equal(compare_output.narrow(0, 2, 2))
+ else:
+ raise ValueError(f'rank {rank} is not in the device mesh')
+
+ new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
+ new_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=new_sharding_spec,
+ entire_shape=original_tensor.shape)
+
+ d_tensor.layout_convert(new_layout)
+
+ if rank == 0:
+ assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+ elif rank == 1:
+ assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+ elif rank == 2:
+ assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+ elif rank == 3:
+ assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+ else:
+ raise ValueError(f'rank {rank} is not in the device mesh')
+
+ dtensor_from_local = distribute_tensor(original_tensor, new_layout)
+
+ if rank == 0:
+ assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+ elif rank == 1:
+ assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+ elif rank == 2:
+ assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+ elif rank == 3:
+ assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+ else:
+ raise ValueError(f'rank {rank} is not in the device mesh')
+
+
+def test_dtensor():
+ world_size = 4
+ run_func = partial(check_dtensor, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_dtensor()
diff --git a/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py
new file mode 100644
index 000000000000..7fd1c3d90fc4
--- /dev/null
+++ b/tests/test_tensor/test_dtensor/test_dtensor_sharding_spec.py
@@ -0,0 +1,34 @@
+import operator
+from functools import reduce
+
+from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec
+
+
+def test_dtensor_sharding_spec():
+ dims = 4
+ dim_partition_dict_0 = {0: [0, 1]}
+ # DistSpec:
+ # shard_sequence: S01,R,R,R
+ sharding_spec_0 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_0)
+ assert str(sharding_spec_0.sharding_sequence) == "[S01, R, R, R]"
+
+ dim_partition_dict_1 = {1: [0, 1]}
+ # DistSpec:
+ # shard_sequence: R,S01,R,R
+ sharding_spec_1 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_1)
+ assert str(sharding_spec_1.sharding_sequence) == "[R, S01, R, R]"
+
+ dim_spec_list_0 = [dim_spec for dim_spec in sharding_spec_0.sharding_sequence]
+ dim_spec_list_1 = [dim_spec for dim_spec in sharding_spec_1.sharding_sequence]
+
+ assert dim_spec_list_0[0].dim_diff(dim_spec_list_1[0]) == ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
+ assert dim_spec_list_0[1].dim_diff(dim_spec_list_1[1]) == SHARD_COST + STEP_PENALTY + SHARD_COST
+ assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0
+ assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0
+
+ assert sharding_spec_0.spec_diff(sharding_spec_1) == \
+ reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0)
+
+
+if __name__ == '__main__':
+ test_dtensor_sharding_spec()
diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py
new file mode 100644
index 000000000000..70cf8726dbd0
--- /dev/null
+++ b/tests/test_tensor/test_dtensor/test_layout_converter.py
@@ -0,0 +1,206 @@
+import math
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
+from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+
+entire_shape = torch.Size((64, 32, 16))
+layout_converter = LayoutConverter()
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+mesh_shape = (2, 2)
+
+
+def check_one_step_transform(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ # [[0, 1],
+ # [2, 3]]
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+ dim_partition_dict = {0: [0], 1: [1]}
+ # DistSpec:
+ # shard_sequence: S0,S1,R
+ # device_mesh_shape: (2, 2)
+ sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec,
+ entire_shape=entire_shape)
+
+ rst_dict = layout_converter.all_gather_transform_layouts(layout)
+
+ assert '[R, S1, R]' in [
+ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
+ ]
+ assert '[S0, R, R]' in [
+ str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
+ ]
+
+ dim_partition_dict_all2all = {0: [0], 1: [1]}
+ # DistSpec:
+ # shard_sequence: S0,S1,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
+ layout_all2all = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_all2all,
+ entire_shape=entire_shape)
+
+ rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
+
+ assert '[S01, R, R]' in [
+ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
+ ]
+ assert '[R, S1, S0]' in [
+ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
+ ]
+ assert '[S0, R, S1]' in [
+ str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
+ ]
+
+ dim_partition_shard = {0: [0]}
+ # DistSpec:
+ # shard_sequence: S0,R,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
+ shard_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_shard,
+ entire_shape=entire_shape)
+
+ rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
+
+ assert '[S01, R, R]' in [
+ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
+ ]
+ assert '[S0, S1, R]' in [
+ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
+ ]
+ assert '[S0, R, S1]' in [
+ str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
+ ]
+
+
+def check_layout_converting(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+ dim_partition_source = {1: [0, 1]}
+ dim_partition_target = {0: [0, 1]}
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+ # DistSpec:
+ # shard_sequence: R,S01,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
+
+ # DistSpec:
+ # shard_sequence: S01,R,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
+
+ transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
+
+ # check transform path
+ transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
+ assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
+
+ # check comm action sequence
+ # all-gather(S01) -> S0
+ assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
+ assert comm_action_sequence[0].gather_dim == 1
+ assert comm_action_sequence[0].logical_process_axis == 1
+
+ # all-to-all(R, S0) -> [S0, R]
+ assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
+ assert comm_action_sequence[1].gather_dim == 1
+ assert comm_action_sequence[1].shard_dim == 0
+ assert comm_action_sequence[1].logical_process_axis == 0
+
+ # shard(S0) -> [S01]
+ assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
+ assert comm_action_sequence[2].shard_dim == 0
+ assert comm_action_sequence[2].logical_process_axis == 1
+
+ # checkout chached_spec_pairs_transform_path
+ assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
+ assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
+
+ comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
+
+ assert comm_cost['forward'] == comm_cost['backward']
+ assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])
+
+
+def check_layout_converting_apply(rank, world_size, port):
+ disable_existing_loggers()
+ launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+ dim_partition_source = {1: [0, 1]}
+ dim_partition_target = {0: [0, 1]}
+ device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+ # DistSpec:
+ # shard_sequence: R,S01,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
+ source_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_source,
+ entire_shape=entire_shape)
+
+ # DistSpec:
+ # shard_sequence: S01,R,R
+ # device_mesh_shape: (4, 4)
+ sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
+ target_layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=sharding_spec_target,
+ entire_shape=entire_shape)
+
+ original_tensor = torch.rand(entire_shape).cuda()
+
+ # tensor_to_apply: [R, S01, R]
+ tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
+
+ # tensor_to_check: [S01, R, R]
+ tensor_to_check = original_tensor.narrow(0, rank * 16, 16)
+
+ converted_tensor = layout_converter.apply(tensor_to_apply, source_layout, target_layout)
+ assert converted_tensor.equal(tensor_to_check)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_layout_converter():
+ world_size = 4
+ run_func = partial(check_one_step_transform, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+ run_func = partial(check_layout_converting, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+ run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_layout_converter()
diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py
index 33db676cb85f..1a6d23f6a2eb 100644
--- a/tests/test_tensor/test_tp_with_zero.py
+++ b/tests/test_tensor/test_tp_with_zero.py
@@ -27,8 +27,6 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
- if key == 'model.lm_head.weight':
- continue
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@@ -87,7 +85,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
tp_init_spec_func(model, pg)
dp_world_size = pg.dp_world_size()
- config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+ config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
config_dict[dp_world_size]['chunk_size'] = 5000
config_dict[dp_world_size]['keep_gathered'] = False
if placement_policy != 'cuda':
@@ -95,7 +93,7 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
else:
init_device = None
- model = GeminiDDP(model, init_device, placement_policy, True, False, 32)
+ model = GeminiDDP(model, init_device, placement_policy, True, False)
# The same as the following 3 lines
# chunk_manager = ChunkManager(config_dict, init_device=init_device)
# gemini_manager = GeminiManager(placement_policy, chunk_manager)
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 58e3b21d97eb..441cbbb22ce7 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -1,22 +1,13 @@
+import random
+
import pytest
import torch
from einops import rearrange
-from colossalai.kernel.cuda_native.flash_attention import HAS_FLASH_ATTN, HAS_MEM_EFF_ATTN, HAS_TRITON
-
-if HAS_FLASH_ATTN:
- from colossalai.kernel.cuda_native.flash_attention import (
- MaskedFlashAttention,
- flash_attention_q_k_v,
- flash_attention_q_kv,
- flash_attention_qkv,
- )
-
-if HAS_TRITON:
- from colossalai.kernel.cuda_native.flash_attention import triton_flash_attention
+from colossalai.kernel.cuda_native.flash_attention import HAS_MEM_EFF_ATTN
if HAS_MEM_EFF_ATTN:
- from colossalai.kernel.cuda_native.flash_attention import LowerTriangularMask, MemoryEfficientAttention
+ from colossalai.kernel.cuda_native.flash_attention import AttnMaskType, ColoAttention
def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
@@ -30,117 +21,88 @@ def baseline_attention(Z, N_CTX, H, q, k, v, sm_scale):
return ref_out
-@pytest.mark.skipif(HAS_TRITON == False, reason="triton is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
-def test_triton_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
- torch.manual_seed(20)
- q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- sm_scale = 0.3
- dout = torch.randn_like(q)
-
- ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
- ref_out.backward(dout)
- ref_dv, v.grad = v.grad.clone(), None
- ref_dk, k.grad = k.grad.clone(), None
- ref_dq, q.grad = q.grad.clone(), None
-
- # triton implementation
- tri_out = triton_flash_attention(q, k, v, sm_scale)
- tri_out.backward(dout)
- tri_dv, v.grad = v.grad.clone(), None
- tri_dk, k.grad = k.grad.clone(), None
- tri_dq, q.grad = q.grad.clone(), None
- # compare
- assert torch.allclose(ref_out, tri_out, atol=1e-3)
- assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
- assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
- assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
-
-
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
-def test_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
- torch.manual_seed(20)
- q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- sm_scale = 0.3
- dout = torch.randn_like(q)
-
- # reference implementation
- ref_out = baseline_attention(Z, N_CTX, H, q, k, v, sm_scale)
- ref_out.backward(dout)
- ref_dv, v.grad = v.grad.clone(), None
- ref_dk, k.grad = k.grad.clone(), None
- ref_dq, q.grad = q.grad.clone(), None
-
- # flash implementation
- q, k, v = map(lambda x: rearrange(x, 'z h n d -> (z n) h d'), [q, k, v])
- dout = rearrange(dout, 'z h n d -> (z n) h d').detach()
- for i in range(3):
- if i == 0:
- tri_out = flash_attention_q_k_v(q, k, v, sm_scale, Z, N_CTX, N_CTX, causal=True)
- elif i == 1:
- kv = torch.cat((k.unsqueeze(1), v.unsqueeze(1)), dim=1)
- tri_out = flash_attention_q_kv(q, kv, sm_scale, Z, N_CTX, N_CTX, causal=True)
- else:
- qkv = torch.cat((q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)), dim=1)
- tri_out = flash_attention_qkv(qkv, sm_scale, Z, N_CTX, causal=True)
-
- tri_out.backward(dout, retain_graph=True)
-
- if i == 0:
- tri_dq, tri_dk, tri_dv, = torch.autograd.grad(tri_out, (q, k, v), dout)
- tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
- (tri_out, tri_dq, tri_dk, tri_dv))
- elif i == 1:
- tri_dq, tri_dkv, = torch.autograd.grad(tri_out, (q, kv), dout)
- tri_dk, tri_dv = torch.chunk(tri_dkv, 2, dim=1)
- tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
- (tri_out, tri_dq, tri_dk.squeeze(1), tri_dv.squeeze(1)))
- else:
- tri_dqkv, = torch.autograd.grad(tri_out, (qkv), dout)
- tri_dq, tri_dk, tri_dv = torch.chunk(tri_dqkv, 3, dim=1)
- tri_out, tri_dq, tri_dk, tri_dv = map(lambda x: rearrange(x, '(z n) h d -> z h n d', z=Z),
- (tri_out, tri_dq.squeeze(1), tri_dk.squeeze(1), tri_dv.squeeze(1)))
-
- # compare
- assert torch.allclose(ref_out, tri_out, atol=1e-3)
- assert torch.allclose(ref_dv, tri_dv, atol=1e-3)
- assert torch.allclose(ref_dk, tri_dk, atol=1e-3)
- assert torch.allclose(ref_dq, tri_dq, atol=1e-3)
-
-
-@pytest.mark.skipif(HAS_FLASH_ATTN == False, reason="flash is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 4, 2, 16)])
-def test_masked_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
- attn = MaskedFlashAttention(N_CTX, D_HEAD, 0.1)
-
- qkv = torch.randn((Z, H, 3 * N_CTX * D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- attention_mask = torch.randint(2, (Z, H)).cuda().bool()
-
- out = attn(qkv, attention_mask)
-
- dout = torch.rand_like(out)
- out.backward(dout)
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+def test_attention_gpt(B, S, H, D_HEAD, dtype=torch.float16):
+ D = H * D_HEAD
+
+ c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
+ attn = ColoAttention(D, H, dropout=0.1)
+
+ x = torch.randn((B, S, D), dtype=dtype, device="cuda")
+
+ qkv = c_attn(x)
+ q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
+ y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+ assert list(y.shape) == [B, S, D]
+
+ dy = torch.rand_like(y)
+ y.backward(dy)
@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(6, 8, 4, 16)])
-def test_memory_efficient_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
- attn = MemoryEfficientAttention(N_CTX * D_HEAD, N_CTX, 0.1)
+@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+def test_attention_bert(B, S, H, D_HEAD, dtype=torch.float16):
+ D = H * D_HEAD
+
+ c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
+ attn = ColoAttention(D, H, dropout=0.1)
+
+ x = torch.randn((B, S, D), dtype=dtype, device="cuda")
+ # attention mask of shape [B, S] with zero padding to max length S
+ mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+ mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+ qkv = c_attn(x)
+ q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+ y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.padding)
+
+ assert list(y.shape) == [B, S, D]
+
+ dy = torch.rand_like(y)
+ y.backward(dy)
+
+
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('B, S, H, D_HEAD', [(6, 8, 4, 16)])
+def test_attention_no_mask(B, S, H, D_HEAD, dtype=torch.float16):
+ D = H * D_HEAD
+
+ c_attn = torch.nn.Linear(D, 3 * D, dtype=dtype, device="cuda")
+ attn = ColoAttention(D, H, dropout=0.1)
+
+ x = torch.randn((B, S, D), dtype=dtype, device="cuda")
+ qkv = c_attn(x)
+ q, k, v = rearrange(qkv, 'b s (n h d) -> b s n h d', n=3, h=H).unbind(dim=2)
+ y = attn(q, k, v)
+
+ assert list(y.shape) == [B, S, D]
+
+ dy = torch.rand_like(y)
+ y.backward(dy)
+
+
+@pytest.mark.skipif(HAS_MEM_EFF_ATTN == False, reason="xformers is not available")
+@pytest.mark.parametrize('B, S, T, H, D_HEAD', [(6, 24, 8, 4, 16)])
+def test_cross_attention(B, S, T, H, D_HEAD, dtype=torch.float16):
+ D = H * D_HEAD
+
+ q_attn = torch.nn.Linear(D, D, dtype=dtype, device="cuda")
+ kv_attn = torch.nn.Linear(D, 2 * D, dtype=dtype, device="cuda")
- q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
- v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+ attn = ColoAttention(D, H, dropout=0.1)
- out = attn(q, k, v, attention_mask=LowerTriangularMask())
+ src = torch.randn((B, S, D), dtype=dtype, device="cuda")
+ tgt = torch.randn((B, T, D), dtype=dtype, device="cuda")
- dout = torch.rand_like(out)
- out.backward(dout)
+ q = q_attn(tgt)
+ kv = kv_attn(src)
+ q = rearrange(q, 'b s (h d) -> b s h d', h=H)
+ k, v = rearrange(kv, 'b s (n h d) -> b s n h d', n=2, h=H).unbind(dim=2)
+ y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+ assert list(y.shape) == [B, T, D]
-if __name__ == '__main__':
- test_flash_attention(3, 4, 2, 16)
+ dy = torch.rand_like(y)
+ y.backward(dy)
diff --git a/tests/test_utils/test_lazy_init/test_distribute.py b/tests/test_utils/test_lazy_init/test_distribute.py
new file mode 100644
index 000000000000..37b2c5da1efa
--- /dev/null
+++ b/tests/test_utils/test_lazy_init/test_distribute.py
@@ -0,0 +1,110 @@
+from functools import partial
+from typing import Optional
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.common import print_rank_0
+from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+from tests.kit.model_zoo import model_zoo
+
+# from utils import assert_dist_model_equal, set_seed
+
+
+def find_shard_dim(shape: torch.Size) -> Optional[int]:
+ for dim, size in enumerate(shape):
+ if size % 2 == 0:
+ return dim
+
+
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
+ shard_dim = find_shard_dim(original_tensor.shape)
+ dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
+ target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
+ layout = Layout(device_mesh=device_mesh,
+ device_type=torch.device('cuda'),
+ sharding_spec=target_sharding_spec,
+ entire_shape=original_tensor.shape)
+ return layout
+
+
+def _get_current_name(prefix: str, name: str) -> str:
+ return f'{prefix}.{name}'.lstrip('.')
+
+
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+ layout_dict = {}
+
+ @torch.no_grad()
+ def generate_recursively(module: nn.Module, prefix: str = ''):
+ # recursively initialize the module
+ for name, mod in module.named_children():
+ generate_recursively(mod, prefix=_get_current_name(prefix, name))
+
+ # initialize tensors directly attached to the current module
+ for name, param in module.named_parameters(recurse=False):
+ if isinstance(param, LazyTensor):
+ layout = make_layout(device_mesh, param)
+ layout_dict[_get_current_name(prefix, name)] = layout
+
+ for name, buf in module.named_buffers(recurse=False):
+ if isinstance(buf, LazyTensor):
+ layout = make_layout(device_mesh, buf)
+ layout_dict[_get_current_name(prefix, name)] = layout
+
+ generate_recursively(model)
+
+ return layout_dict
+
+
+@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
+def run_dist_lazy_init(subset, seed: int = 42):
+ sub_model_zoo = model_zoo.get_sub_registry(subset)
+ device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+ # FIXME(ver217): uncomment this line
+ # _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+ # LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+
+ for name, entry in sub_model_zoo.items():
+ # TODO(ver217): lazy init does not support weight norm, skip these models
+ if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+ continue
+ print_rank_0(name)
+ model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+ ctx = LazyInitContext(tensor_cls=_MyTensor)
+ with ctx:
+ model = model_fn()
+ ctx = LazyInitContext()
+ with ctx:
+ deferred_model = model_fn()
+ layout_dict = generate_layout_dict(deferred_model, device_mesh)
+ ctx.distribute(deferred_model, layout_dict, verbose=True)
+ # FIXME(ver217): uncomment this line
+ # assert_dist_model_equal(model, deferred_model, layout_dict)
+
+
+def run_dist(rank, world_size, port) -> None:
+ colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
+ run_dist_lazy_init()
+
+
+# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
+@pytest.mark.skip
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_lazy_init():
+ world_size = 4
+ run_func = partial(run_dist, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_dist_lazy_init()
diff --git a/tests/test_utils/test_lazy_init/test_models.py b/tests/test_utils/test_lazy_init/test_models.py
new file mode 100644
index 000000000000..9faddecbaca4
--- /dev/null
+++ b/tests/test_utils/test_lazy_init/test_models.py
@@ -0,0 +1,23 @@
+import pytest
+
+from tests.kit.model_zoo import model_zoo
+
+# FIXME(ver217): uncomment this line
+# from utils import check_lazy_init
+
+
+# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
+@pytest.mark.skip
+@pytest.mark.parametrize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
+def test_torchvision_models_lazy_init(subset):
+ sub_model_zoo = model_zoo.get_sub_registry(subset)
+ for name, entry in sub_model_zoo.items():
+ # TODO(ver217): lazy init does not support weight norm, skip these models
+ if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+ continue
+ # FIXME(ver217): uncomment this line
+ # check_lazy_init(entry, verbose=True)
+
+
+if __name__ == '__main__':
+ test_torchvision_models_lazy_init('torchvision')
diff --git a/tests/test_utils/test_lazy_init/utils.py b/tests/test_utils/test_lazy_init/utils.py
new file mode 100644
index 000000000000..a8aeb4c8930c
--- /dev/null
+++ b/tests/test_utils/test_lazy_init/utils.py
@@ -0,0 +1,85 @@
+import random
+from typing import Any, Callable, Optional, Tuple
+
+import numpy as np
+import torch
+
+from colossalai.tensor.d_tensor.layout_converter import to_global
+from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
+from tests.kit.model_zoo.registry import ModelAttribute
+
+# model_fn, data_gen_fn, output_transform_fn, model_attr
+TestingEntry = Tuple[Callable[[], torch.nn.Module], Callable[[], dict], Callable[[], dict], Optional[ModelAttribute]]
+
+
+def set_seed(seed: int) -> None:
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+
+def assert_model_eqaual(m1: torch.nn.Module, m2: torch.nn.Module) -> None:
+ s1 = m1.state_dict()
+ s2 = m2.state_dict()
+
+ assert len(s1) == len(s2), f'len {len(s1)} vs {len(s2)}'
+
+ for (n1, t1), (n2, t2) in zip(s1.items(), s2.items()):
+ assert n1 == n2
+ assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
+
+
+def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: Callable[[], dict],
+ output_transform_fn: Callable[[Any], dict]) -> None:
+ data = data_gen_fn()
+
+ m1.eval()
+ m2.eval()
+ # run forward
+ with torch.no_grad():
+ outputs1 = m1(**data)
+ outputs2 = m2(**data)
+
+ # compare output
+ transformed_out1 = output_transform_fn(outputs1)
+ transformed_out2 = output_transform_fn(outputs2)
+
+ assert len(transformed_out1) == len(transformed_out2)
+
+ for key, out1 in transformed_out1.items():
+ out2 = transformed_out2[key]
+ assert torch.allclose(out1, out2, atol=1e-5), \
+ f'{m1.__class__.__name__} has inconsistent outputs, {out1} vs {out2}'
+
+
+def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
+ model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+ _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+ LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+ ctx = LazyInitContext(tensor_cls=_MyTensor)
+ with ctx:
+ model = model_fn()
+ ctx = LazyInitContext()
+ with ctx:
+ deferred_model = model_fn()
+ deferred_model = ctx.materialize(deferred_model, verbose=verbose)
+ assert_model_eqaual(model, deferred_model)
+ if check_forward:
+ assert_forward_equal(model, deferred_model, data_gen_fn, output_transform_fn)
+ if verbose:
+ print(f'{model.__class__.__name__} pass')
+
+
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
+ state = model.state_dict()
+ distributed_state = distributed_model.state_dict()
+
+ assert len(state) == len(distributed_state), f'len {len(state)} vs {len(distributed_state)}'
+
+ for (n1, t1), (n2, t2) in zip(state.items(), distributed_state.items()):
+ assert n1 == n2
+ t1 = t1.cuda()
+ t2 = t2.cuda()
+ if n2 in layout_dict:
+ t2 = to_global(t2, layout_dict[n2])
+ assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py
index c23b3a3e8fd8..504df202e168 100644
--- a/tests/test_zero/low_level_zero/test_grad_acc.py
+++ b/tests/test_zero/low_level_zero/test_grad_acc.py
@@ -14,10 +14,10 @@
from colossalai.zero import LowLevelZeroOptimizer
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
def __init__(self):
- super(TestModel, self).__init__()
+ super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
@@ -32,9 +32,8 @@ def exam_zero_1_2_grad_acc():
seed_all(2009)
# create model
- zero1_model = TestModel().cuda()
+ zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
-
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1)
@@ -60,16 +59,16 @@ def fwd_bwd_func(number, cur_data):
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
- zero1_optimizer.backward(zero1_output.sum().float())
- zero2_optimizer.backward(zero2_output.sum().float())
+ zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False)
+ zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False)
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
- zero1_optimizer.sync_grad()
- zero2_optimizer.sync_grad()
+ zero1_optimizer._sync_grad()
+ zero2_optimizer._sync_grad()
fwd_bwd_func(0, input_data1)
fwd_bwd_func(1, input_data2)
@@ -89,9 +88,10 @@ def exam_zero_1_grad_acc():
seed_all(2008)
# create models
- zero_model = TestModel()
+ zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
+ seed_all(2008)
zero_model = zero_model.cuda()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
@@ -123,7 +123,7 @@ def fwd_bwd_func(number, cur_data, check_flag):
assert torch.equal(zero_output, torch_output)
# zero-dp backward
- zero_optimizer.backward(zero_output.sum().float())
+ zero_optimizer.backward(zero_output.sum().float(), sync_grad=False)
# torch-ddp backward
torch_output.sum().backward()
@@ -134,7 +134,7 @@ def fwd_bwd_func(number, cur_data, check_flag):
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, unscale_grad)
- zero_optimizer.sync_grad()
+ zero_optimizer._sync_grad()
fwd_bwd_func(0, input_data1, True)
fwd_bwd_func(1, input_data2, False)
@@ -153,7 +153,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_grad_acc()
- # exam_zero_1_2_grad_acc()
+ exam_zero_1_2_grad_acc()
@pytest.mark.dist
diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py
index b02d3a6a4486..930b6129174e 100644
--- a/tests/test_zero/low_level_zero/test_zero1_2.py
+++ b/tests/test_zero/low_level_zero/test_zero1_2.py
@@ -14,10 +14,10 @@
from colossalai.zero import LowLevelZeroOptimizer
-class TestModel(nn.Module):
+class MlpModel(nn.Module):
def __init__(self):
- super(TestModel, self).__init__()
+ super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.linear2 = nn.Linear(256, 512)
@@ -55,7 +55,7 @@ def exam_zero_1_2():
seed_all(2001)
# create model
- zero1_model = TestModel().cuda()
+ zero1_model = MlpModel().cuda()
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
@@ -78,16 +78,16 @@ def exam_zero_1_2():
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
- zero1_optimizer.backward(zero1_output.mean().float())
- zero2_optimizer.backward(zero2_output.mean().float())
+ zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False)
+ zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False)
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
if z2p.grad is not None:
# print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad)))
assert torch.equal(z1p.grad, z2p.grad)
- zero1_optimizer.sync_grad()
- zero2_optimizer.sync_grad()
+ zero1_optimizer._sync_grad()
+ zero2_optimizer._sync_grad()
# step
zero1_optimizer.step()
@@ -111,11 +111,11 @@ def exam_zero_1_torch_ddp():
seed_all(1453)
# create models
- zero_model = TestModel()
+ zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
zero_model = zero_model.cuda().half()
- # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
+ torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
torch_model = torch_model.cuda()
# for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
@@ -146,7 +146,7 @@ def exam_zero_1_torch_ddp():
half_close(zero_output, torch_output, loose=True)
# zero-dp backward
- zero_optimizer.backward(zero_output.mean().float())
+ zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
# torch-ddp backward
torch_output.mean().backward()
@@ -156,7 +156,7 @@ def exam_zero_1_torch_ddp():
half_close(p.grad, z1p.grad, loose=True)
# zero-dp step
- zero_optimizer.sync_grad()
+ zero_optimizer._sync_grad()
zero_optimizer.step()
# torch ddp step
diff --git a/tests/test_zero/low_level_zero/test_zero_init.py b/tests/test_zero/low_level_zero/test_zero_init.py
new file mode 100644
index 000000000000..1305da5df9c5
--- /dev/null
+++ b/tests/test_zero/low_level_zero/test_zero_init.py
@@ -0,0 +1,61 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import LowLevelZeroOptimizer
+
+
+class MlpModel(nn.Module):
+
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(128, 256)
+ self.linear2 = nn.Linear(256, 512)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ return x
+
+
+def exam_zero_init():
+ dp_2_tp_2_pg = ProcessGroup(dp_degree=2, tp_degree=2)
+ model1 = MlpModel().cuda()
+ with ColoInitContext(device=get_current_device(), default_pg=dp_2_tp_2_pg):
+ model2 = MlpModel()
+ optimizer1 = LowLevelZeroOptimizer(torch.optim.Adam(model1.parameters(), lr=1))
+ optimizer2 = LowLevelZeroOptimizer(torch.optim.Adam(model2.parameters(), lr=1))
+
+ assert optimizer1._local_rank == optimizer2._local_rank
+ assert optimizer1._world_size == optimizer2._world_size
+ assert optimizer1._dp_global_ranks == optimizer2._dp_global_ranks
+
+ mp_group1 = optimizer1._mp_torch_group
+ mp_group2 = optimizer2._mp_torch_group
+ assert dist.get_world_size(mp_group1) == dist.get_world_size(mp_group2)
+ assert dist.get_rank(mp_group1) == dist.get_rank(mp_group2)
+
+
+def run_dist(rank, world_size, port):
+ config_dict = dict(parallel=dict(data=2, tensor=dict(size=2, mode='1d')))
+ colossalai.launch(config=config_dict, rank=rank, world_size=world_size, port=port, host='localhost')
+ exam_zero_init()
+
+
+@pytest.mark.dist
+def test_zero_init():
+ world_size = 4
+ run_func = partial(run_dist, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_zero_init()
diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py
new file mode 100644
index 000000000000..15d3530ff90a
--- /dev/null
+++ b/tests/test_zero/low_level_zero/test_zero_tp.py
@@ -0,0 +1,99 @@
+from functools import partial
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.tensor import ProcessGroup
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port, get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import LowLevelZeroOptimizer
+from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal
+
+
+def strict_shard_equal(tensor, shard, tp_pg, rtol=1e-3, atol=1e-4):
+ return tensor_shard_equal(tensor, shard, tp_pg.tp_local_rank(), tp_pg.tp_world_size(), rtol, atol)
+
+
+class MlpModel(nn.Module):
+
+ def __init__(self):
+ super(MlpModel, self).__init__()
+ self.linear1 = nn.Linear(32, 128)
+ self.act = nn.GELU()
+ self.linear2 = nn.Linear(128, 32)
+
+ def forward(self, x):
+ y = self.linear1(x)
+ y = self.act(y)
+ y = self.linear2(y)
+ return x + y
+
+
+@parameterize("overlap_flag", [False, True])
+@parameterize("partition_flag", [False, True])
+def exam_zero_with_tp(overlap_flag, partition_flag):
+ set_seed(233010)
+ tp_pg = ProcessGroup(tp_degree=2)
+
+ with ColoInitContext(device=get_current_device(), default_pg=tp_pg):
+ hybrid_model = MlpModel()
+ torch_model = MlpModel().cuda()
+ for pt, ph in zip(torch_model.parameters(), hybrid_model.parameters()):
+ pt.data.copy_(ph.data)
+
+ for name, param in hybrid_model.named_parameters():
+ if 'linear1' in name:
+ split_param_row_tp1d(param, tp_pg)
+ param.compute_spec.set_output_replicate(False)
+ if 'linear2.weight' in name:
+ split_param_col_tp1d(param, tp_pg)
+
+ torch_model = DDP(torch_model, device_ids=[tp_pg.rank()], process_group=tp_pg.dp_process_group())
+ torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-2) # set to 1e-2 for torch-1.11
+ hybrid_optim = torch.optim.Adam(hybrid_model.parameters(), lr=1e-2)
+ hybrid_optim = LowLevelZeroOptimizer(hybrid_optim,
+ initial_scale=2,
+ clip_grad_norm=1.0,
+ overlap_communication=overlap_flag,
+ partition_grad=partition_flag)
+
+ dp_local_rank = tp_pg.dp_local_rank()
+ set_seed(255 + dp_local_rank)
+
+ data = torch.randn(8, 32, device=get_current_device())
+ torch_loss = torch_model(data).sum()
+ hybrid_loss = hybrid_model(data).sum()
+ assert_close(torch_loss, hybrid_loss)
+
+ torch_loss.backward()
+ torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0)
+ hybrid_optim.backward(hybrid_loss)
+
+ torch_optim.step()
+ hybrid_optim.step()
+
+ for (name, pt), ph in zip(torch_model.named_parameters(), hybrid_model.parameters()):
+ assert strict_shard_equal(pt.data, ph.data, tp_pg)
+
+
+def run_dist(rank, world_size, port):
+ colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
+ exam_zero_with_tp()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_zero_with_tp():
+ world_size = 4
+ run_func = partial(run_dist, world_size=world_size, port=free_port())
+ mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+ test_zero_with_tp()
diff --git a/version.txt b/version.txt
index 0ea3a944b399..b0032849c80b 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-0.2.0
+0.2.7